forked from PaddlePaddle/models
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer.py
More file actions
79 lines (70 loc) · 2.64 KB
/
infer.py
File metadata and controls
79 lines (70 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gzip
import argparse
import itertools
import paddle.v2 as paddle
import network_conf
from train import dnn_layer_dims
import reader
from utils import logger, ModelType
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--model_gz_path',
type=str,
required=True,
help="path of model parameters gz file")
parser.add_argument(
'--data_path', type=str, required=True, help="path of the dataset to infer")
parser.add_argument(
'--prediction_output_path',
type=str,
required=True,
help="path to output the prediction")
parser.add_argument(
'--data_meta_path',
type=str,
default="./data.meta",
help="path of trainset's meta info, default is ./data.meta")
parser.add_argument(
'--model_type',
type=int,
required=True,
default=ModelType.CLASSIFICATION,
help='model type, classification: %d, regression %d (default classification)'
% (ModelType.CLASSIFICATION, ModelType.REGRESSION))
args = parser.parse_args()
paddle.init(use_gpu=False, trainer_count=1)
class CTRInferer(object):
def __init__(self, param_path):
logger.info("create CTR model")
dnn_input_dim, lr_input_dim = reader.load_data_meta(args.data_meta_path)
# create the mdoel
self.ctr_model = network_conf.CTRmodel(
dnn_layer_dims,
dnn_input_dim,
lr_input_dim,
model_type=ModelType(args.model_type),
is_infer=True)
# load parameter
logger.info("load model parameters from %s" % param_path)
self.parameters = paddle.parameters.Parameters.from_tar(
gzip.open(param_path, 'r'))
self.inferer = paddle.inference.Inference(
output_layer=self.ctr_model.model,
parameters=self.parameters, )
def infer(self, data_path):
logger.info("infer data...")
dataset = reader.Dataset()
infer_reader = paddle.batch(
dataset.infer(args.data_path), batch_size=1000)
logger.warning('write predictions to %s' % args.prediction_output_path)
output_f = open(args.prediction_output_path, 'w')
for id, batch in enumerate(infer_reader()):
res = self.inferer.infer(input=batch)
predictions = [x for x in itertools.chain.from_iterable(res)]
assert len(batch) == len(
predictions), "predict error, %d inputs, but %d predictions" % (
len(batch), len(predictions))
output_f.write('\n'.join(map(str, predictions)) + '\n')
if __name__ == '__main__':
ctr_inferer = CTRInferer(args.model_gz_path)
ctr_inferer.infer(args.data_path)