Skip to content

Commit 35d4181

Browse files
committed
Dot product with Vector type; Passed vector Cartesian product tests
1 parent d5a78a9 commit 35d4181

6 files changed

Lines changed: 179 additions & 53 deletions

File tree

src/lambdacloud/test/TestBatchVec.java

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
package lambdacloud.test;
22

33
import static lambdacloud.test.TestUtils.assertEqual;
4-
import static symjava.symbolic.Symbol.x;
5-
import static symjava.symbolic.Symbol.y;
6-
import static symjava.symbolic.Symbol.z;
4+
import static symjava.symbolic.Symbol.*;
5+
import static symjava.math.SymMath.*;
76
import lambdacloud.core.lang.LCArray;
87
import lambdacloud.core.lang.LCInt;
98
import lambdacloud.core.lang.LCLoop;
@@ -13,8 +12,6 @@
1312
import symjava.bytecode.BytecodeBatchVecFunc;
1413
import symjava.bytecode.BytecodeSelect;
1514
import symjava.bytecode.BytecodeVecFunc;
16-
import symjava.math.Dot;
17-
import symjava.math.SymMath;
1815
import symjava.relational.Lt;
1916
import symjava.symbolic.Expr;
2017
import symjava.symbolic.Vector;
@@ -30,19 +27,15 @@ public static void main(String[] args) {
3027
//testVector();
3128
//testBatchVector();
3229

33-
testBatchVectorDot1();
34-
//TODO
35-
//1. dot product of vector type
36-
//2. Cartesian between vectors (Done)
37-
// [[1,2,3]] * [[4,5,6],[7,8,9]] =
38-
// [ [1,2,3] dot [4,5,6]; [1,2,3] dot [7,8,9] ]
39-
// 1,2,3,1,2,3
40-
// 4,5,6,7,8,9
30+
testBatchVectorDotWithLCArray();
4131
long begin = System.currentTimeMillis();
42-
testBatchVectorDot2();
32+
//Cartesian between vectors
33+
testBatchVectorDotWithLCArray2();
4334
System.out.println("Time: "+(System.currentTimeMillis() - begin)+"ms");
4435

45-
36+
//Dot product of Vector type
37+
testBatchVectorDot();
38+
testBatchVectorDot2();
4639
}
4740

4841
public static void testBatchVecFunc() {
@@ -132,7 +125,7 @@ public static void testBatchVector() {
132125
assertEqual(new double[]{5,6,7,8,7,8,9,10,9,10,11,12}, outAry2);
133126
}
134127

135-
public static void testBatchVectorDot1() {
128+
public static void testBatchVectorDotWithLCArray() {
136129
LCStatements lcs = new LCStatements();
137130

138131
LCArray x = LCArray.getDoubleArray("x");
@@ -162,7 +155,7 @@ public static void testBatchVectorDot1() {
162155
assertEqual(new double[]{15,30,45}, outAry2);
163156
}
164157

165-
public static void testBatchVectorDot2() {
158+
public static void testBatchVectorDotWithLCArray2() {
166159
LCStatements lcs = new LCStatements();
167160

168161
LCArray x = LCArray.getDoubleArray("x");
@@ -182,6 +175,7 @@ public static void testBatchVectorDot2() {
182175
int dim = 3;
183176
// The length of the return value of dot product is 1.
184177
BytecodeBatchVecFunc ff = new BytecodeBatchVecFunc(func, dim, 1);
178+
//Cartesian with vector length = dim
185179
double[][] args2 = BytecodeSelect.cartesian(dim, args);
186180
int outLen = args2[0].length/dim;
187181
double[] outAry2 = new double[outLen];
@@ -192,4 +186,39 @@ public static void testBatchVectorDot2() {
192186
assertEqual(new double[]{6,12,15,30}, outAry2);
193187
}
194188

189+
public static void testBatchVectorDot() {
190+
int dim = 3;
191+
Vector x = new Vector("x", dim);
192+
Vector y = new Vector("y", dim);
193+
194+
BytecodeVecFunc func = CompileUtils.compileVecFunc(new LCReturn(dot(x,y)));
195+
double[] outAry = new double[1];
196+
double[][] args = { {1,2,3}, {4,5,6} };
197+
func.apply(outAry, 0, args);
198+
for(double d : outAry) {
199+
System.out.println(d);
200+
}
201+
assertEqual(new double[]{32}, outAry);
202+
}
203+
204+
public static void testBatchVectorDot2() {
205+
int dim = 3;
206+
Vector x = new Vector("x", dim);
207+
Vector y = new Vector("y", dim);
208+
209+
BytecodeVecFunc func = CompileUtils.compileVecFunc(new LCReturn(dot(x,y)));
210+
211+
double[][][] args = { {{1,2,3,4,5,6}}, {{1,1,1,2,2,2}} };
212+
// The length of the return value of dot product is 1.
213+
BytecodeBatchVecFunc ff = new BytecodeBatchVecFunc(func, dim, 1);
214+
//Cartesian with vector length = dim
215+
double[][] args2 = BytecodeSelect.cartesian(dim, args);
216+
int outLen = args2[0].length/dim;
217+
double[] outAry2 = new double[outLen];
218+
ff.apply(outAry2, 0, args2);
219+
for(double d : outAry2) {
220+
System.out.println(d);
221+
}
222+
assertEqual(new double[]{6,12,15,30}, outAry2);
223+
}
195224
}

src/symjava/math/Dot.java

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,70 @@
11
package symjava.math;
22

3+
import static com.sun.org.apache.bcel.internal.generic.InstructionConstants.DADD;
4+
import static com.sun.org.apache.bcel.internal.generic.InstructionConstants.FADD;
5+
import static com.sun.org.apache.bcel.internal.generic.InstructionConstants.IADD;
6+
import static com.sun.org.apache.bcel.internal.generic.InstructionConstants.LADD;
7+
38
import java.util.ArrayList;
49
import java.util.List;
10+
import java.util.Map;
11+
12+
import com.sun.org.apache.bcel.internal.Constants;
13+
import com.sun.org.apache.bcel.internal.generic.ConstantPoolGen;
14+
import com.sun.org.apache.bcel.internal.generic.InstructionFactory;
15+
import com.sun.org.apache.bcel.internal.generic.InstructionHandle;
16+
import com.sun.org.apache.bcel.internal.generic.InstructionList;
17+
import com.sun.org.apache.bcel.internal.generic.MethodGen;
18+
import com.sun.org.apache.bcel.internal.generic.ObjectType;
19+
import com.sun.org.apache.bcel.internal.generic.Type;
520

621
import symjava.matrix.SymVector;
722
import symjava.symbolic.Add;
823
import symjava.symbolic.Expr;
924
import symjava.symbolic.SymReal;
1025
import symjava.symbolic.TypeInfo;
26+
import symjava.symbolic.Vector;
27+
import symjava.symbolic.Expr.TYPE;
28+
import symjava.symbolic.arity.BinaryOp;
29+
import symjava.symbolic.utils.BytecodeUtils;
1130
import symjava.symbolic.utils.Utils;
1231

1332
/**
1433
* Dot Product of two vectors
1534
*
1635
*/
17-
public class Dot extends Expr {
18-
protected SymVector left;
19-
protected SymVector right;
36+
public class Dot extends BinaryOp {
2037
protected Expr expr = null;
2138
public Dot(SymVector l, SymVector r) {
39+
super(l,r);
2240
if(l.dim() != r.dim())
2341
throw new IllegalArgumentException("The size of the two vector must be the same!");
24-
left = l;
25-
right = r;
26-
if(left instanceof Grad && right instanceof Grad) {
27-
label = left + " \\cdot " + right;
42+
arg1 = l;
43+
arg2 = r;
44+
if(arg1 instanceof Grad && arg2 instanceof Grad) {
45+
label = arg1 + " \\cdot " + arg2;
2846
sortKey = label;
2947
return;
3048
}
3149
List<Expr> list = new ArrayList<Expr>();
32-
for(int i=0; i<left.dim(); i++) {
33-
list.add(left.get(i).multiply(right.get(i)));
50+
for(int i=0; i<l.dim(); i++) {
51+
list.add(l.get(i).multiply(r.get(i)));
3452
}
3553
expr = Utils.addListToExpr(list).simplify();
3654
label = expr.toString();
3755
sortKey = label;
3856
}
3957

58+
public Dot(Vector l, Vector r) {
59+
super(l,r);
60+
if(l.dim() != r.dim())
61+
throw new IllegalArgumentException("The size of the two vector must be the same!");
62+
arg1 = l;
63+
arg2 = r;
64+
label = "dot(" + arg1 + ", " + arg2 + ")";
65+
sortKey = label;
66+
}
67+
4068
public static Expr apply(SymVector l, SymVector r) {
4169
List<Expr> list = new ArrayList<Expr>();
4270
for(int i=0; i<l.dim(); i++) {
@@ -51,20 +79,24 @@ public static Expr apply(SymVector l, SymVector r) {
5179
return dot;
5280
}
5381

82+
public static Expr apply(Vector l, Vector r) {
83+
return new Dot(l, r);
84+
}
85+
5486
@Override
5587
public Expr diff(Expr expr) {
5688
if(this.expr == null) {
57-
Grad lg = (Grad)left;
58-
Grad rg = (Grad)right;
89+
Grad lg = (Grad)arg1;
90+
Grad rg = (Grad)arg2;
5991
if(lg.isAbstract() && rg.isAbstract()) {
6092
Expr d1 = Dot.apply(new Grad(lg.getFunc().diff(expr), lg.getFunc().args), rg);
6193
Expr d2 = Dot.apply(lg, new Grad(rg.getFunc().diff(expr), rg.getFunc().args));
6294
return Add.simplifiedIns(d1, d2);
6395
}
6496
}
65-
if(left instanceof Grad && right instanceof Grad) {
66-
Grad lg = (Grad)left;
67-
Grad rg = (Grad)right;
97+
if(arg1 instanceof Grad && arg2 instanceof Grad) {
98+
Grad lg = (Grad)arg1;
99+
Grad rg = (Grad)arg2;
68100
Expr d1 = Dot.apply(lg.diff(expr), rg);
69101
Expr d2 = Dot.apply(lg, rg.diff(expr));
70102
return Add.simplifiedIns(d1, d2);
@@ -75,11 +107,11 @@ public Expr diff(Expr expr) {
75107
@Override
76108
public Expr fdiff(Expr f, Expr df) {
77109
if(expr == null) {
78-
Grad lg = (Grad)left;
79-
Grad rg = (Grad)right;
110+
Grad lg = (Grad)arg1;
111+
Grad rg = (Grad)arg2;
80112
if(lg.isAbstract() && rg.isAbstract()) {
81113
Expr d1 = Dot.apply(new Grad(lg.getFunc().fdiff(f, df)), rg);
82-
Expr d2 = Dot.apply(lg, new Grad(rg.getFunc().fdiff(f, df)));
114+
Expr d2 = Dot.apply(lg, new Grad(rg.getFunc().fdiff(f, df)));
83115
return Add.shallowSimplifiedIns(d1, d2);
84116
}
85117
}
@@ -125,7 +157,7 @@ public Expr subs(Expr from, Expr to) {
125157
if(Utils.symCompare(this, from))
126158
return to;
127159
if(expr == null)
128-
return new Dot(left.subs(from, to), right.subs(from, to));
160+
return new Dot((SymVector)arg1.subs(from, to), (SymVector)arg2.subs(from, to));
129161
else
130162
return expr.subs(from, to);
131163
}
@@ -135,25 +167,36 @@ public Expr getExpr() {
135167
return this.expr;
136168
else {
137169
List<Expr> list = new ArrayList<Expr>();
138-
for(int i=0; i<left.dim(); i++) {
139-
list.add(left.get(i).multiply(right.get(i)));
170+
for(int i=0; i<arg1.dim(); i++) {
171+
list.add(arg1.get(i).multiply(arg2.get(i)));
140172
}
141173
return Utils.addListToExpr(list).simplify();
142174
}
143175
}
144-
145-
@Override
146-
public Expr[] args() {
147-
// TODO Auto-generated method stub
148-
return null;
149-
}
150-
176+
151177
@Override
152-
public TypeInfo getTypeInfo() {
153-
// TODO Auto-generated method stub
154-
return null;
178+
public InstructionHandle bytecodeGen(String clsName, MethodGen mg,
179+
ConstantPoolGen cp, InstructionFactory factory,
180+
InstructionList il, Map<String, Integer> argsMap, int argsStartPos,
181+
Map<Expr, Integer> funcRefsMap) {
182+
InstructionHandle startPos = arg1.bytecodeGen(clsName, mg, cp, factory, il, argsMap, argsStartPos, funcRefsMap);
183+
if(arg1.getType() == TYPE.VECTOR && arg2.getType() == TYPE.VECTOR) {
184+
arg2.bytecodeGen(clsName, mg, cp, factory, il, argsMap, argsStartPos, funcRefsMap);
185+
il.append(factory.createInvoke("symjava.symbolic.utils.BytecodeOpSupport", "dot",
186+
//Type.DOUBLE, new Type[] { new ObjectType("Jama.Matrix"),new ObjectType("Jama.Matrix") },
187+
new ObjectType("Jama.Matrix"), new Type[] { new ObjectType("Jama.Matrix"),new ObjectType("Jama.Matrix") },
188+
Constants.INVOKESTATIC));
189+
return startPos;
190+
}
191+
//TODO
192+
return startPos;
155193
}
156194

195+
// @Override
196+
// public TypeInfo getTypeInfo() {
197+
// return TypeInfo.tiDouble;
198+
// }
199+
157200
@Override
158201
public void updateLabel() {
159202
// TODO Auto-generated method stub

src/symjava/math/SymMath.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import symjava.symbolic.SymConst;
1616
import symjava.symbolic.SymRandom;
1717
import symjava.symbolic.Tan;
18+
import symjava.symbolic.Vector;
1819
import symjava.symbolic.utils.Utils;
1920

2021
public class SymMath {
@@ -118,6 +119,10 @@ public static Expr random() {
118119
return new SymRandom();
119120
}
120121

122+
public static Expr dot(Vector l, Vector r) {
123+
return Dot.apply(l, r);
124+
}
125+
121126
public static Expr dot(SymVector l, SymVector r) {
122127
return Dot.apply(l, r);
123128
}

src/symjava/symbolic/Expr.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,35 @@ public Expr setArg(int index, Expr arg) {
551551
abstract public void updateLabel();
552552

553553
/**
554-
* Parent expression, for example sub-matrix
554+
* Parent expression, for example: a sub-matrix or sub-vector has a parent
555555
*/
556556
public Expr getParent() {
557-
return null;
557+
throw new UnsupportedOperationException();
558+
}
559+
560+
/**
561+
* Return the dimension of a vector
562+
* @return
563+
*/
564+
public int dim() {
565+
throw new UnsupportedOperationException();
566+
}
567+
568+
/**
569+
* return the dimension of a matrix or tensor
570+
* @return
571+
*/
572+
public int[] dims() {
573+
throw new UnsupportedOperationException();
574+
}
575+
576+
/**
577+
* return the element at index
578+
* @param index
579+
* @return
580+
*/
581+
public Expr get(int index) {
582+
throw new UnsupportedOperationException();
558583
}
559584
}
560585

src/symjava/symbolic/Vector.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ public Expr getParent() {
114114
return this.parent;
115115
}
116116

117+
public int dim() {
118+
return this.nDim;
119+
}
120+
117121
public static void main(String[] args) {
118122
Vector v = new Vector("A",8);
119123
SymVector sv = v.split(3);

0 commit comments

Comments
 (0)