-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmodel_restore_helper.py
More file actions
executable file
·78 lines (67 loc) · 3.55 KB
/
model_restore_helper.py
File metadata and controls
executable file
·78 lines (67 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Dict, Any, Optional, Type
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from dpu_utils.utils import RichPath
from models import Model, NeuralBoWModel, SelfAttentionModel, ConvolutionalModel, ConvSelfAttentionModel
#from models import Model, NeuralBoWModel, RNNModel, SelfAttentionModel, ConvolutionalModel, ConvSelfAttentionModel
from models import NeuralBoWModel_V1, CrossAttentionModel,NeuralBoWModel_V2
#from models import NeuralBoWModel_V1, CrossAttentionModel, RNNModel_V1,NeuralBoWModel_V2
def get_model_class_from_name(model_name: str) -> Type[Model]:
    """Map a (case-insensitive) architecture name to its Model subclass.

    Args:
        model_name: Architecture name, e.g. 'neuralbow' or 'selfattentionmodel'.
            Matching is case-insensitive.

    Returns:
        The Model subclass registered under that name.

    Raises:
        ValueError: If the name matches no known model. (ValueError subclasses
            Exception, so callers catching the old generic Exception still work.)
    """
    model_name = model_name.lower()
    if model_name in {'neuralbow', 'neuralbowmodel'}:
        return NeuralBoWModel
    elif model_name in {'neuralbow_v1', 'neuralbowmodel_v1'}:
        return NeuralBoWModel_V1
    elif model_name in {'neuralbow_v2', 'neuralbowmodel_v2'}:
        return NeuralBoWModel_V2
    elif model_name in {'selfatt', 'selfattention', 'selfattentionmodel'}:
        return SelfAttentionModel
    elif model_name in {'1dcnn', 'convolutionalmodel'}:
        return ConvolutionalModel
    elif model_name in {'convselfatt', 'convselfattentionmodel'}:
        return ConvSelfAttentionModel
    elif model_name in {'crossatt', 'crossattention', 'crossattentionmodel'}:
        return CrossAttentionModel
    # NOTE: the former 'rnn'/'rnn_v1' branches were removed: RNNModel and
    # RNNModel_V1 imports are commented out at the top of this file, so those
    # branches could only raise NameError. Now they fall through to the
    # explicit unknown-model error below instead.
    else:
        raise ValueError("Unknown model '%s'!" % model_name)
def restore(path: RichPath, is_train: bool, hyper_overrides: Optional[Dict[str, Any]] = None) -> Model:
    """Rebuild a model from a pickled checkpoint and load its saved weights.

    Args:
        path: Location of the checkpoint pickle. Presumably written by the
            model's own save routine; expected to contain at least the keys
            'model_type', 'hyperparameters', 'query_metadata',
            'per_code_language_metadata' and 'weights' (variable name ->
            value), plus optionally 'run_name' -- TODO confirm against the
            saving code, which is not visible here.
        is_train: Forwarded to model.make_model(); builds the graph in
            training or inference mode.
        hyper_overrides: Optional hyperparameter values merged over the saved
            ones before the model is constructed.

    Returns:
        A Model with its graph built and weights assigned in its session.

    NOTE(review): read_as_pickle() unpickles the file; only restore
    checkpoints from trusted sources.
    """
    saved_data = path.read_as_pickle()

    # Apply overrides before construction so the graph is built with them.
    if hyper_overrides is not None:
        saved_data['hyperparameters'].update(hyper_overrides)

    model_class = get_model_class_from_name(saved_data['model_type'])
    model = model_class(
        saved_data['hyperparameters'], saved_data.get('run_name'))

    # Restore the vocabulary/metadata state that make_model() depends on.
    model.query_metadata.update(saved_data['query_metadata'])
    for (language, language_metadata) in saved_data['per_code_language_metadata'].items():
        model.per_code_language_metadata[language] = language_metadata
    model.make_model(is_train=is_train)

    # Variables present in the graph but absent from the checkpoint; these
    # are initialized fresh in one batched op at the end.
    variables_to_initialize = []
    with model.sess.graph.as_default():
        with tf.name_scope("restore"):
            restore_ops = []
            used_vars = set()
            # Sorted by name for deterministic log output and op ordering.
            for variable in sorted(model.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), key=lambda v: v.name):
                used_vars.add(variable.name)
                if variable.name in saved_data['weights']:
                    print('Initializing %s from saved value.' % variable.name)
                    restore_ops.append(variable.assign(
                        saved_data['weights'][variable.name]))
                else:
                    print(
                        'Freshly initializing %s since no saved value was found.' % variable.name)
                    variables_to_initialize.append(variable)
            # Warn about checkpoint weights with no matching graph variable,
            # silently skipping Adam optimizer slots/power accumulators,
            # which are expected to be absent at inference time.
            for var_name in sorted(saved_data['weights']):
                if var_name not in used_vars:
                    if var_name.endswith('Adam:0') or var_name.endswith('Adam_1:0') or var_name in ['beta1_power:0', 'beta2_power:0']:
                        continue
                    print('Saved weights for %s not used by model.' % var_name)
            restore_ops.append(
                tf.variables_initializer(variables_to_initialize))
            # Run all assigns plus the fresh-variable initializer together.
            model.sess.run(restore_ops)
    return model