1- #include " scmath.h"
1+ #include " scmath.h"
22
3- MathFunc::MathFunc (BasicNode *func)
3+ MathFunc::MathFunc (BasicNode *func, Scope *sc )
44{
5- funScope = new Scope (&record::globalScope);
5+ if (sc == nullptr )
6+ funScope = new Scope (&record::globalScope);
7+ else
8+ {
9+ funScope = sc;
10+ ownerScope = false ;
11+ }
612 stringstream ss;
713 ast::output (func, ss);
814 string funString;
915 ss >> funString;// 因为域不同所以要这样做
1016 funPro = ast::toAST (funString, funScope);
1117}
1218
13- MathFunc::MathFunc (const string &func)
19+ MathFunc::MathFunc (const string &func, Scope *sc )
1420{
15- funScope = new Scope (&record::globalScope);
21+ if (sc == nullptr )
22+ funScope = new Scope (&record::globalScope);
23+ else
24+ {
25+ funScope = sc;
26+ ownerScope = false ;
27+ }
1628 funPro = ast::toAST (func, funScope);
1729}
1830
19- MathFunc::MathFunc (const MathFunc &func)
31+ MathFunc::MathFunc (const MathFunc &func, Scope *sc )
2032{
21- funScope = new Scope (&record::globalScope);
33+ if (sc == nullptr )
34+ funScope = new Scope (&record::globalScope);
35+ else
36+ {
37+ funScope = sc;
38+ ownerScope = false ;
39+ }
2240 stringstream ss;
2341 ss << func;
2442 string funString;
@@ -29,7 +47,8 @@ MathFunc::MathFunc(const MathFunc &func)
2947MathFunc::~MathFunc ()
3048{
3149 delete funPro;
32- delete funScope;
50+ if (ownerScope)
51+ delete funScope;
3352}
3453
3554void MathFunc::setVal (const string &value, BasicNode *num)
@@ -46,9 +65,49 @@ void MathFunc::setVal(const string &value, const double &num)
4665 funScope->findVariable (value, true )->setVal (new NumNode (num));
4766}
4867
68+ void MathFunc::changeScope (Scope *sc)
69+ {
70+ stringstream ss;
71+ ss << *this ;
72+ string fun;
73+ ss >> fun;
74+ delete funPro;
75+ if (ownerScope)
76+ delete funScope;
77+ if (sc == nullptr )
78+ {
79+ funScope = new Scope (&record::globalScope);
80+ ownerScope = true ;
81+ }
82+ else
83+ {
84+ funScope = sc;
85+ ownerScope = false ;
86+ }
87+ funPro = ast::toAST (fun, funScope);
88+ }
89+
4990const MathFunc MathFunc::eval ()
5091{
51- return MathFunc (funPro->eval ());
92+ BasicNode *p = copyHelp::copyNode (funPro);
93+ MathFunc retn (funPro->eval ());
94+ delete funPro;
95+ funPro = p;
96+ return retn;
97+ }
98+
99+ double MathFunc::getNum ()
100+ {
101+ BasicNode *t = copyHelp::copyNode (funPro);
102+ BasicNode *p = funPro->eval ();
103+ if (p->getType () != Num)
104+ throw string (" 有变量未赋值" );
105+ // cout << ((NumNode*)p)->getNum() << endl;
106+ double retn = ((NumNode*)p)->getNum ();
107+ delete funPro;
108+ // delete p;
109+ funPro = t;
110+ return retn;
52111}
53112
54113const MathFunc MathFunc::diff (const string &value)
@@ -59,6 +118,11 @@ const MathFunc MathFunc::diff(const string &value)
59118 return retn;
60119}
61120
121+ const ValWeight MathFunc::regress (const DataSet &data, const VarSet var, int n)
122+ {
123+ return regression (funPro, data, var, n);
124+ }
125+
62126MathFunc& MathFunc::operator =(const string &st)
63127{
64128 delete funPro;
@@ -352,3 +416,71 @@ BasicNode* Derivation(BasicNode *now, const string &value){
352416 __Simplificate (retn);
353417 return retn;
354418}
419+
420+ const ValWeight regression (BasicNode *func, const DataSet &data, const VarSet var, int n)
421+ {
422+ // 参数分别为 目标函数,数据,要回归的变量,迭代次数
423+ // 数据最后一列为函数值
424+ if (data.getr () != var.size () + 1 )
425+ throw string (" 维度不对" );
426+ ValWeight weight;
427+ unique_ptr<Scope> sc (new Scope (&record::globalScope)) ;
428+ map<string, MathFunc> grad;
429+
430+ stack<BasicNode*> te;
431+ te.push (func);
432+ while (!te.empty ())
433+ {
434+ BasicNode *t = te.top ();
435+ te.pop ();
436+ if (t->getType () == Var)
437+ {
438+ string name = ((VarNode*)t)->NAME ;
439+ if (weight.count (name) == 0 )
440+ {
441+ for (auto &i : var)
442+ if (i == name)
443+ goto lable;
444+ weight[name] = 1.0 ;
445+ }
446+ }
447+ if (t->getType () == Fun)
448+ {
449+ FunNode *temp = (FunNode*)t;
450+ for (auto &i :temp->sonNode )
451+ te.push (i);
452+ }
453+ lable:;
454+ }
455+
456+ string IndependentValue = (*(weight.rbegin ())).first + " y" ;// 防止重名
457+ stringstream ss;
458+ ast::output (func, ss);
459+ string tfunc;
460+ ss >> tfunc;
461+ string tloss = " (" + tfunc + " -" + IndependentValue + " )^2" ;
462+ MathFunc loss = tloss;
463+ loss.changeScope (sc.get ());
464+ for (auto i : weight)
465+ {
466+ grad[i.first ] = loss.diff (i.first );
467+ grad[i.first ].changeScope (loss.getScope ());
468+ // cout <<i.first << '\t' << grad[i.first] << endl;
469+ }
470+
471+ for (int time = 0 ; time < n; time++)// 第time次遍历
472+ {
473+ for (int i = 0 ; i < data.getc (); i++)// 对于每个数据
474+ {
475+ double alpha = 0.01 ;
476+ for (int p = 0 ; p < var.size (); p++)// 把数据点放到函数里面
477+ loss.setVal (var[p], data.m [i][p]);
478+ loss.setVal (IndependentValue, data.m [i][data.getr () - 1 ]);
479+ for (auto &p : weight)// 把目前的权放到函数里面
480+ loss.setVal (p.first , p.second );
481+ for (auto &p : weight)// 每个变量
482+ p.second -= alpha * grad[p.first ].getNum ();
483+ }
484+ }
485+ return weight;
486+ }
0 commit comments