Skip to content

Commit e4d2e4e

Browse files
committed
Add Serve TensorFlow Models in Java
1 parent 7c1cc34 commit e4d2e4e

5 files changed

Lines changed: 140 additions & 0 deletions

File tree

ml/serve_tf_model/model-server/.attach_pid14326

Whitespace-only changes.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
7+
<groupId>com.javahelps.tensorflow</groupId>
8+
<artifactId>model-server</artifactId>
9+
<version>1.0-SNAPSHOT</version>
10+
<build>
11+
<plugins>
12+
<plugin>
13+
<groupId>org.apache.maven.plugins</groupId>
14+
<artifactId>maven-compiler-plugin</artifactId>
15+
<configuration>
16+
<source>11</source>
17+
<target>11</target>
18+
</configuration>
19+
</plugin>
20+
</plugins>
21+
</build>
22+
23+
<dependencies>
24+
<dependency>
25+
<groupId>org.tensorflow</groupId>
26+
<artifactId>tensorflow</artifactId>
27+
<version>1.13.1</version>
28+
</dependency>
29+
</dependencies>
30+
31+
</project>
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.javahelps.tensorflow.moderserver;
2+
3+
import org.tensorflow.SavedModelBundle;
4+
import org.tensorflow.Session;
5+
import org.tensorflow.Tensor;
6+
import org.tensorflow.TensorFlowException;
7+
8+
import java.util.List;
9+
10+
public class ModelServer {
11+
12+
public static void main(String[] args) {
13+
try (SavedModelBundle savedModelBundle = SavedModelBundle.load("/tmp/tf_add_model", "serve")) {
14+
15+
try (Session session = savedModelBundle.session()) {
16+
Session.Runner runner = session.runner();
17+
runner.feed("x", Tensor.create(10));
18+
runner.feed("y", Tensor.create(20));
19+
20+
List<Tensor<?>> tensors = runner.fetch("ans").run();
21+
System.out.println("Answer is: " + tensors.get(0).intValue());
22+
}
23+
24+
} catch (TensorFlowException ex) {
25+
ex.printStackTrace();
26+
}
27+
}
28+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
�
2+
��
3+
:
4+
Add
5+
x"T
6+
y"T
7+
z"T"
8+
Ttype:
9+
2  
10+
C
11+
Placeholder
12+
output"dtype"
13+
dtypetype"
14+
shapeshape:"serve*1.10.02v1.10.0-0-g656e7a2b34�
15+
F
16+
x Placeholder*
17+
dtype0*
18+
_output_shapes
19+
:*
20+
shape:
21+
F
22+
y Placeholder*
23+
dtype0*
24+
_output_shapes
25+
:*
26+
shape:
27+
3
28+
ansAddxy*
29+
T0*
30+
_output_shapes
31+
:"*i
32+
serving_defaultV
33+

34+
x
35+
x:0
36+

37+
y
38+
y:0
39+
ans
40+
ans:0tensorflow/serving/predict
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#!/usr/bin/env python3
2+
import tensorflow as tf
3+
from tensorflow.python.saved_model import builder as saved_model_builder
4+
from tensorflow.python.saved_model import signature_constants
5+
from tensorflow.python.saved_model import signature_def_utils
6+
from tensorflow.python.saved_model import tag_constants
7+
from tensorflow.python.saved_model.utils import build_tensor_info
8+
9+
x = tf.placeholder(tf.int32, name='x')
10+
y = tf.placeholder(tf.int32, name='y')
11+
12+
# This is our model
13+
add = tf.add(x, y, name='ans')
14+
15+
with tf.Session() as sess:
16+
# Pick out the model input and output
17+
x_tensor = sess.graph.get_tensor_by_name('x:0')
18+
y_tensor = sess.graph.get_tensor_by_name('y:0')
19+
ans_tensor = sess.graph.get_tensor_by_name('ans:0')
20+
21+
x_info = build_tensor_info(x_tensor)
22+
y_info = build_tensor_info(y_tensor)
23+
ans_info = build_tensor_info(ans_tensor)
24+
25+
# Create a signature definition for tfserving
26+
signature_definition = signature_def_utils.build_signature_def(
27+
inputs={'x': x_info, 'y': y_info},
28+
outputs={'ans': ans_info},
29+
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
30+
31+
builder = saved_model_builder.SavedModelBuilder('/tmp/tf_add_model')
32+
33+
builder.add_meta_graph_and_variables(
34+
sess, [tag_constants.SERVING],
35+
signature_def_map={
36+
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
37+
signature_definition
38+
})
39+
40+
# Save the model so we can serve it with a model server :)
41+
builder.save()

0 commit comments

Comments
 (0)