-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathpredict.py
More file actions
50 lines (39 loc) · 1.69 KB
/
predict.py
File metadata and controls
50 lines (39 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import tensorflow as tf
import numpy as np
import cv2
import os
def predict(path):
# 预测部分的初始化要跟模型部分的初始化一致
image_size = 64
num_channels = 3
images = []
# 待测试的图片,放到和程序同一目录下
image = cv2.imread(path)
image = cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
images.append(image)
images = np.array(images, dtype=np.uint8)
images = images.astype('float32')
images = np.multiply(images, 1.0 / 255.0)
# tf要求输入是4维的:第一个参数是1因为每次预测的就是一张图 1 64 64 3
x_batch = images.reshape(1, image_size, image_size, num_channels)
# 开启一个对话
sess = tf.Session()
# 拿到刚才保存的模型:直接输文件名,要.meta结尾的
saver = tf.train.import_meta_graph('./model/savemodel.ckpt-4960.meta')
# 拿刚才保存模型的权重参数:直接输文件名,后缀最长的那个
saver.restore(sess, './model/savemodel.ckpt-4960')
# writer=tf.compat.v1.summary.FileWriter('./improved_graph2',sess.graph)
# 同样导入标签:1和0
graph = tf.get_default_graph()
y_pred = graph.get_tensor_by_name("y_pred:0")
x = graph.get_tensor_by_name("x:0")
y_true = graph.get_tensor_by_name("y_true:0")
y_test_images = np.zeros((1, 4))
# 两个值:谁的概率大就是谁
feed_dict_testing = {x: x_batch, y_true: y_test_images}
# 预测开始:
result = sess.run(y_pred, feed_dict=feed_dict_testing)
# 把1和0又回到其对应的种类
res_label = ['沙漠', '公路', '房屋', '岛屿']
print(res_label[result.argmax()])
return res_label[result.argmax()]