Skip to content

Commit 19abb27

Browse files
author
冷漠
committed
增加回归功能
1 parent 813e31d commit 19abb27

File tree

7 files changed

+192
-28
lines changed

7 files changed

+192
-28
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SCMath.pro.user.5cdfc2d

SCMath.pro.user

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<?xml version="1.0" encoding="UTF-8"?>
22
<!DOCTYPE QtCreatorProject>
3-
<!-- Written by QtCreator 4.5.0, 2019-04-05T21:43:53. -->
3+
<!-- Written by QtCreator 4.5.0, 2019-04-11T01:29:23. -->
44
<qtcreator>
55
<data>
66
<variable>EnvironmentId</variable>

main.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,22 @@ int main()
1515
cout << test.diff("e") << endl;
1616
cout << test.diff("f") << endl;
1717
cout << test.diff("g") << endl;
18-
cout << endl;
19-
2018
test.setVal("a", 1);
2119
cout << test.eval() << endl;
20+
test.setVal("a", 2);
21+
cout << test.eval() << endl;
22+
cout << endl;
23+
24+
string testRegression = "a*x+b";
25+
test = testRegression;
26+
cout << testRegression << endl;
27+
DataSet data(2, 2);
28+
data.m[0][0] = 2; data.m[0][1] = 2;//(2, 2)
29+
data.m[1][0] = 3; data.m[1][1] = 5;//(3, 5)
30+
VarSet varlist;
31+
varlist.push_back("x");
32+
ValWeight weight = test.regress(data, varlist, 10000);
33+
for(auto &i : weight)
34+
cout << i.first << '\t' << i.second << endl;
2235
return 0;
2336
}

matrix.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class vectorNode : public BasicNode
7878
}
7979

8080
~vectorNode() { delete[]v; }
81-
unsigned int getl() { return l; }
81+
unsigned int getl() const { return l; }
8282
};
8383

8484

@@ -136,8 +136,8 @@ class matrixNode : public BasicNode
136136
delete[] m;
137137
}
138138

139-
unsigned int getr() { return r; }
140-
unsigned int getc() { return c; }
139+
unsigned int getr() const { return r; }
140+
unsigned int getc() const { return c; }
141141

142142
vectorNode getRVector(unsigned int rn)
143143
{

scmath.cpp

Lines changed: 141 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,42 @@
1-
#include "scmath.h"
1+
#include "scmath.h"
22

3-
MathFunc::MathFunc(BasicNode *func)
3+
MathFunc::MathFunc(BasicNode *func, Scope *sc)
44
{
5-
funScope = new Scope(&record::globalScope);
5+
if(sc == nullptr)
6+
funScope = new Scope(&record::globalScope);
7+
else
8+
{
9+
funScope = sc;
10+
ownerScope = false;
11+
}
612
stringstream ss;
713
ast::output(func, ss);
814
string funString;
915
ss >> funString;//因为域不同所以要这样做
1016
funPro = ast::toAST(funString, funScope);
1117
}
1218

13-
MathFunc::MathFunc(const string &func)
19+
MathFunc::MathFunc(const string &func, Scope *sc)
1420
{
15-
funScope = new Scope(&record::globalScope);
21+
if(sc == nullptr)
22+
funScope = new Scope(&record::globalScope);
23+
else
24+
{
25+
funScope = sc;
26+
ownerScope = false;
27+
}
1628
funPro = ast::toAST(func, funScope);
1729
}
1830

19-
MathFunc::MathFunc(const MathFunc &func)
31+
MathFunc::MathFunc(const MathFunc &func, Scope *sc)
2032
{
21-
funScope = new Scope(&record::globalScope);
33+
if(sc == nullptr)
34+
funScope = new Scope(&record::globalScope);
35+
else
36+
{
37+
funScope = sc;
38+
ownerScope = false;
39+
}
2240
stringstream ss;
2341
ss << func;
2442
string funString;
@@ -29,7 +47,8 @@ MathFunc::MathFunc(const MathFunc &func)
2947
MathFunc::~MathFunc()
3048
{
3149
delete funPro;
32-
delete funScope;
50+
if(ownerScope)
51+
delete funScope;
3352
}
3453

3554
void MathFunc::setVal(const string &value, BasicNode *num)
@@ -46,9 +65,49 @@ void MathFunc::setVal(const string &value, const double &num)
4665
funScope->findVariable(value, true)->setVal(new NumNode(num));
4766
}
4867

68+
void MathFunc::changeScope(Scope *sc)
69+
{
70+
stringstream ss;
71+
ss << *this;
72+
string fun;
73+
ss >> fun;
74+
delete funPro;
75+
if(ownerScope)
76+
delete funScope;
77+
if(sc == nullptr)
78+
{
79+
funScope = new Scope(&record::globalScope);
80+
ownerScope = true;
81+
}
82+
else
83+
{
84+
funScope = sc;
85+
ownerScope = false;
86+
}
87+
funPro = ast::toAST(fun, funScope);
88+
}
89+
4990
const MathFunc MathFunc::eval()
5091
{
51-
return MathFunc(funPro->eval());
92+
BasicNode *p = copyHelp::copyNode(funPro);
93+
MathFunc retn(funPro->eval());
94+
delete funPro;
95+
funPro = p;
96+
return retn;
97+
}
98+
99+
double MathFunc::getNum()
100+
{
101+
BasicNode *t = copyHelp::copyNode(funPro);
102+
BasicNode *p = funPro->eval();
103+
if(p->getType() != Num)
104+
throw string("有变量未赋值");
105+
//cout << ((NumNode*)p)->getNum() << endl;
106+
double retn = ((NumNode*)p)->getNum();
107+
delete funPro;
108+
//delete p;
109+
funPro = t;
110+
return retn;
52111
}
53112

54113
const MathFunc MathFunc::diff(const string &value)
@@ -59,6 +118,11 @@ const MathFunc MathFunc::diff(const string &value)
59118
return retn;
60119
}
61120

121+
const ValWeight MathFunc::regress(const DataSet &data, const VarSet var, int n)
122+
{
123+
return regression(funPro, data, var, n);
124+
}
125+
62126
MathFunc& MathFunc::operator=(const string &st)
63127
{
64128
delete funPro;
@@ -352,3 +416,71 @@ BasicNode* Derivation(BasicNode *now, const string &value){
352416
__Simplificate(retn);
353417
return retn;
354418
}
419+
420+
const ValWeight regression(BasicNode *func, const DataSet &data, const VarSet var, int n)
421+
{
422+
//参数分别为 目标函数,数据,要回归的变量,迭代次数
423+
//数据最后一列为函数值
424+
if(data.getr() != var.size() + 1)
425+
throw string("维度不对");
426+
ValWeight weight;
427+
unique_ptr<Scope> sc(new Scope(&record::globalScope)) ;
428+
map<string, MathFunc> grad;
429+
430+
stack<BasicNode*> te;
431+
te.push(func);
432+
while(!te.empty())
433+
{
434+
BasicNode *t = te.top();
435+
te.pop();
436+
if(t->getType() == Var)
437+
{
438+
string name = ((VarNode*)t)->NAME;
439+
if(weight.count(name) == 0)
440+
{
441+
for(auto &i : var)
442+
if(i == name)
443+
goto lable;
444+
weight[name] = 1.0;
445+
}
446+
}
447+
if(t->getType() == Fun)
448+
{
449+
FunNode *temp = (FunNode*)t;
450+
for(auto &i :temp->sonNode)
451+
te.push(i);
452+
}
453+
lable:;
454+
}
455+
456+
string IndependentValue = (*(weight.rbegin())).first + "y";//防止重名
457+
stringstream ss;
458+
ast::output(func, ss);
459+
string tfunc;
460+
ss >> tfunc;
461+
string tloss = "(" + tfunc + "-" + IndependentValue + ")^2";
462+
MathFunc loss = tloss;
463+
loss.changeScope(sc.get());
464+
for(auto i : weight)
465+
{
466+
grad[i.first] = loss.diff(i.first);
467+
grad[i.first].changeScope(loss.getScope());
468+
//cout <<i.first << '\t' << grad[i.first] << endl;
469+
}
470+
471+
for(int time = 0; time < n; time++)//第time次遍历
472+
{
473+
for(int i = 0; i < data.getc(); i++)//对于每个数据
474+
{
475+
double alpha = 0.01;
476+
for(int p = 0; p < var.size(); p++)//把数据点放到函数里面
477+
loss.setVal(var[p], data.m[i][p]);
478+
loss.setVal(IndependentValue, data.m[i][data.getr() - 1]);
479+
for(auto &p : weight)//把目前的权放到函数里面
480+
loss.setVal(p.first, p.second);
481+
for(auto &p : weight)//每个变量
482+
p.second -= alpha * grad[p.first].getNum();
483+
}
484+
}
485+
return weight;
486+
}

scmath.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,34 @@
22
#include "ast.h"
33
#include "nodetype.h"
44
#include<sstream>
5+
#include <memory>
6+
7+
typedef matrixNode DataSet;
8+
typedef vector<string> VarSet;
9+
typedef map<string, double> ValWeight;
510

611
class MathFunc
712
{
813
private:
9-
BasicNode *funPro;
10-
Scope* funScope;
14+
BasicNode *funPro = nullptr;
15+
Scope* funScope = nullptr;
16+
bool ownerScope = true;
1117
public:
12-
MathFunc(const string &);
13-
MathFunc(BasicNode *);
14-
MathFunc(const MathFunc &);
18+
MathFunc(){}
19+
MathFunc(const string &, Scope* = nullptr);
20+
MathFunc(BasicNode *, Scope* = nullptr);
21+
MathFunc(const MathFunc &, Scope* = nullptr);
1522
~MathFunc();
1623

17-
void setVal(const string&, const double&);
24+
void setVal(const string &, const double&);
1825
void setVal(const string &, BasicNode*);
26+
void changeScope(Scope * = nullptr);
1927
const MathFunc eval();
28+
double getNum();
2029
const MathFunc diff(const string &);
30+
const ValWeight regress(const DataSet &, const VarSet, int = 500);
31+
32+
Scope* getScope(){return funScope;}
2133

2234
MathFunc& operator=(const string &);
2335
MathFunc& operator=(BasicNode *);
@@ -29,4 +41,4 @@ class MathFunc
2941
void __Simplificate(BasicNode *&);
3042
BasicNode* __Derivation(BasicNode* , const string &);
3143
BasicNode* Derivation(BasicNode*, const string &);
32-
44+
const ValWeight regression(BasicNode*, const DataSet &, const VarSet, int);

scope.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,25 @@ void Scope::deleteVariable(string name)
5757

5858
void Scope::deleteVariable(Variable *var)
5959
{
60-
for(auto p:this->variableList)
60+
auto p = variableList.begin();
61+
while(p != variableList.end())
6162
{
62-
if(p.second==var)
63-
this->variableList.erase(p.first);
63+
if(p->second == var)
64+
variableList.erase(p++);
65+
else
66+
p++;
6467
}
6568
}
6669

6770
void Scope::deleteFunction(Function *fun)
6871
{
69-
for(auto p:this->functionList)
72+
auto p = functionList.begin();
73+
while(p != functionList.end())
7074
{
71-
if(p.second==fun)
72-
this->functionList.erase(p.first);
75+
if(p->second == fun)
76+
functionList.erase(p++);
77+
else
78+
p++;
7379
}
7480
}
7581

0 commit comments

Comments
 (0)