forked from joyspark/TensorFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloadPythonModel.java
More file actions
123 lines (105 loc) · 2.89 KB
/
loadPythonModel.java
File metadata and controls
123 lines (105 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
/**
 * Loads a TensorFlow SavedModel exported from Python and runs it on data read from a CSV file.
 *
 * <p>The CSV at {@code ./data/test.csv} is parsed into a {@code float[ROW][FEATURE]} matrix, fed to
 * the graph input named {@code "x"}, and the output named {@code "h"} is fetched and printed one
 * value per row.
 */
public class LoadPythonModel {
    /** Number of rows found by the most recent {@link #getDataSize(String)} call. */
    static int ROW = 0;
    /** Widest row (field count) found by the most recent {@link #getDataSize(String)} call. */
    static int FEATURE = 0;

    public static void main(String[] args) throws IOException {
        System.out.println("TensorFlow version : " + TensorFlow.version());
        String filePath = "./data/test.csv";

        // Determine the shape of the CSV data (sets ROW and FEATURE).
        getDataSize(filePath);
        System.out.print("[number of row] ==> " + ROW);
        System.out.println(" / [number of feature] ==> " + FEATURE);

        // Copy the CSV values into a ROW x FEATURE matrix.
        float[][] testInput = new float[ROW][FEATURE];
        csvToMtrx(filePath, testInput);
        printMatrix(testInput);

        // Load the SavedModel exported from Python under the "serve" tag.
        try (SavedModelBundle b = SavedModelBundle.load("/tmp/fromPython", "serve")) {
            // The session is owned by the bundle and is released when the bundle is closed.
            Session sess = b.session();
            // Tensor is AutoCloseable; close both the input and the fetched output to free their resources.
            try (Tensor x = Tensor.create(testInput);
                 Tensor h = sess.runner().feed("x", x).fetch("h").run().get(0)) {
                // NOTE(review): assumes output "h" has shape [ROW, 1] — confirm against the exported graph.
                // The explicit cast compiles whether the TF version's Tensor is generic (raw use erases
                // copyTo's return type to Object) or non-generic.
                float[][] y = (float[][]) h.copyTo(new float[ROW][1]);
                for (float[] row : y) {
                    System.out.println(row[0]);
                }
            }
        }
    }

    /**
     * Measures the row and column size of a CSV file and stores them in {@link #ROW} and
     * {@link #FEATURE}.
     *
     * <p>Every line counts as a row. The column count is the largest number of comma-separated
     * fields on any line, so the resulting matrix can hold every row. Both values are overwritten
     * (not accumulated) on each call; an empty file yields {@code ROW = 0} and {@code FEATURE = 0}.
     *
     * @param filePath path of the CSV file
     * @throws IOException if the file cannot be opened (including when it does not exist) or read
     */
    public static void getDataSize(String filePath) throws IOException {
        int rows = 0;
        int features = 0;
        try (BufferedReader br = new BufferedReader(new FileReader(new File(filePath)))) {
            String line;
            while ((line = br.readLine()) != null) {
                features = Math.max(features, line.split(",").length);
                rows++;
            }
        }
        // Assign only after a successful read so repeated calls do not accumulate.
        ROW = rows;
        FEATURE = features;
    }

    /**
     * Copies CSV data into the given matrix, one line per matrix row.
     *
     * <p>Line {@code i} of the file is split on commas and parsed with {@link Float#parseFloat}
     * into {@code mtrx[i]}. If the file has fewer lines than the matrix has rows, the remaining
     * rows are left untouched; rows with fewer fields than {@code mtrx[i].length} keep their
     * existing values in the trailing columns.
     *
     * @param filePath path of the CSV file
     * @param mtrx destination matrix, typically sized by {@link #getDataSize(String)}
     * @throws IOException if the file cannot be opened (including when it does not exist) or read
     * @throws NumberFormatException if a field is not a valid float
     * @throws IllegalArgumentException if a line has more fields than its matrix row can hold
     */
    public static void csvToMtrx(String filePath, float[][] mtrx) throws IOException {
        try (BufferedReader br = new BufferedReader(new FileReader(new File(filePath)))) {
            for (int i = 0; i < mtrx.length; i++) {
                String line = br.readLine();
                if (line == null) {
                    break; // file exhausted; remaining rows keep their current values
                }
                String[] fields = line.split(",");
                if (fields.length > mtrx[i].length) {
                    throw new IllegalArgumentException("Row " + i + " has " + fields.length
                            + " fields but the matrix row holds only " + mtrx[i].length);
                }
                for (int j = 0; j < fields.length; j++) {
                    mtrx[i][j] = Float.parseFloat(fields[j]);
                }
            }
        }
    }

    /**
     * Prints the matrix to standard output for visual inspection.
     *
     * <p>Each element is wrapped in brackets, rows are separated by line breaks, and the whole
     * matrix is wrapped in one outer pair of brackets. An empty matrix prints {@code []}.
     *
     * @param mtrx matrix to print
     */
    public static void printMatrix(float[][] mtrx) {
        System.out.println("============ARRAY VALUES============");
        System.out.print("[");
        for (int i = 0; i < mtrx.length; i++) {
            if (i > 0) {
                System.out.println();
            }
            for (int j = 0; j < mtrx[i].length; j++) {
                System.out.print("[" + mtrx[i][j] + "]");
            }
        }
        System.out.println("]");
    }
}