-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathinput_converter.py
More file actions
44 lines (36 loc) · 1.28 KB
/
input_converter.py
File metadata and controls
44 lines (36 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
import ast
import numpy as np
import msgpack
def _find_input_shapes(model_config):
    """Recover the shape of every model input from the layers that read it.

    The config lists the model's input tensor indices under ``'inp_idxes'``
    but not their shapes; each layer, however, pairs its own ``'inp_idxes'``
    entries with ``'inp_shapes'`` entries. If several layers consume the same
    model input, the shape from the last such layer in ``'layers'`` is used.

    Args:
        model_config: Decoded model config mapping with ``'inp_idxes'`` and
            ``'layers'`` entries.

    Returns:
        A list of shapes, one per entry of ``model_config['inp_idxes']`` and
        in the same order.

    Raises:
        ValueError: If some model input is not consumed by any layer, so its
            shape cannot be determined.
    """
    input_idxes = model_config['inp_idxes']
    wanted = set(input_idxes)
    shape_by_idx = {}
    for layer in model_config['layers']:
        for layer_inp_idx, layer_shape in zip(layer['inp_idxes'], layer['inp_shapes']):
            if layer_inp_idx in wanted:
                # Later layers overwrite earlier ones ("last layer wins").
                shape_by_idx[layer_inp_idx] = layer_shape
    missing = [idx for idx in input_idxes if idx not in shape_by_idx]
    if missing:
        raise ValueError(
            'no layer consumes model input(s) {}, so their shape cannot be '
            'determined'.format(missing))
    return [shape_by_idx[idx] for idx in input_idxes]


def _quantize_input(path, shape, idx, scale_factor):
    """Load one ``.npy`` input and encode it as a fixed-point tensor record.

    The array is reshaped to ``shape``, multiplied by ``scale_factor``,
    rounded to the nearest integer (NumPy rounds exact halves to even) and
    stored as int64.

    Args:
        path: Path of the ``.npy`` file to load.
        shape: Target shape for the loaded array.
        idx: Model tensor index this input feeds.
        scale_factor: Multiplier applied before rounding.

    Returns:
        A dict ``{'idx': idx, 'shape': shape, 'data': [...]}`` where ``data``
        is the flattened int64 tensor as a Python list.

    Raises:
        ValueError: If the array cannot be reshaped to ``shape``, or if it
            contains NaN/inf after scaling (these have no int64 value).
    """
    tensor = np.load(path).reshape(shape)
    scaled = tensor * scale_factor
    if not np.all(np.isfinite(scaled)):
        raise ValueError(
            'input {!r} contains NaN or infinite values'.format(path))
    quantized = np.round(scaled).astype(np.int64)
    return {
        'idx': idx,
        'shape': shape,
        'data': quantized.ravel().tolist(),
    }


def main():
    """Convert ``.npy`` inputs into the msgpack tensor list a model expects.

    Command-line arguments:
        --model_config  msgpack-encoded model config providing
                        ``'inp_idxes'``, ``'global_sf'`` and ``'layers'``.
        --inputs        comma-separated ``.npy`` paths, one per model input,
                        in the same order as the config's ``'inp_idxes'``.
        --output        destination file for the packed tensor list.

    Each output record is ``{'idx', 'shape', 'data'}``, where ``data`` is the
    flattened input scaled by ``global_sf`` and rounded to int64.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_config', type=str, required=True)
    parser.add_argument('--inputs', type=str, required=True)
    parser.add_argument('--output', type=str, required=True)
    args = parser.parse_args()
    # Tolerate spaces after commas and a trailing comma in the path list.
    inputs = [p.strip() for p in args.inputs.split(',') if p.strip()]

    with open(args.model_config, 'rb') as f:
        # raw=False decodes msgpack strings to str even on msgpack < 1.0
        # (where raw=True was the default), so the str-key lookups work.
        model_config = msgpack.unpackb(f.read(), raw=False)
    input_idxes = model_config['inp_idxes']
    scale_factor = model_config['global_sf']

    # zip() below would silently drop extra files or leave model inputs
    # unset, so require an exact one-to-one match.
    if len(inputs) != len(input_idxes):
        parser.error(
            'model expects {} input(s) but --inputs lists {}'.format(
                len(input_idxes), len(inputs)))

    input_shapes = _find_input_shapes(model_config)
    tensors = [
        _quantize_input(path, shape, idx, scale_factor)
        for path, shape, idx in zip(inputs, input_shapes, input_idxes)
    ]

    packed = msgpack.packb(tensors, use_bin_type=True)
    with open(args.output, 'wb') as f:
        f.write(packed)


if __name__ == '__main__':
    main()