Skip to content

Commit 1ceadbf

Browse files
committed
Save TensorFlow SavedModel directory format of trained model too
1 parent 417048c commit 1ceadbf

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

plasma/models/builder.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,29 @@ def build_train_test_models(self):
360360
return self.build_model(False), self.build_model(True)
361361

362362
def save_model_weights(self, model, epoch):
363+
# Keras HDF5 weights only
363364
save_path = self.get_save_path(epoch)
364-
full_model_save_path = self.get_save_path(epoch, ext='hdf5')
365-
model.save(full_model_save_path, overwrite=True)
366365
model.save_weights(save_path, overwrite=True)
366+
# Keras light-weight HDF5 format: model arch, weights, compile info
367+
full_model_save_path = self.get_save_path(epoch, ext='hdf5')
368+
model.save(full_model_save_path,
369+
overwrite=True, # default
370+
include_optimizer=True, # default
371+
save_format=None, # default, 'h5' in r1.15. Else 'tf'
372+
signatures=None, # applicable to 'tf' SavedModel format only
373+
)
374+
# TensorFlow SavedModel format (full directory)
375+
full_moodel_save_dir = full_model_save_path.rsplit('.',1)[0]
376+
# TODO(KGF): model.save(..., save_format='tf') disabled in r1.15
377+
# Same with tf.keras.models.save_model(..., save_format="tf").
378+
# Need to use experimental API until r2.x
379+
# model.save(full_model_save_dir, overwrite=True, save_format='tf')
380+
tf.keras.experimental.export_saved_model(model, full_moodel_save_dir,
381+
custom_objects=None,
382+
as_text=False,
383+
input_signature=None,
384+
serving_only=False
385+
)
367386
# try:
368387
if _has_onnx:
369388
save_path = self.get_save_path(epoch, ext='onnx')
@@ -427,7 +446,8 @@ def extract_id_and_epoch_from_filename(self, filename):
427446
regex = re.compile(r'-?\d+')
428447
numbers = [int(x) for x in regex.findall(filename)]
429448
# TODO: should ignore any files that dont match our naming convention
430-
# in this directory, especially since we are now writing full .hdf5 too
449+
# in this directory, especially since we are now writing full .hdf5 too.
450+
# Will crash the program if, e.g., a .tgz file is in that directory
431451
if filename[-3:] == '.h5':
432452
assert len(numbers) == 3 # id, epoch number, and .h5 extension
433453
assert numbers[2] == 5 # .h5 extension

0 commit comments

Comments
 (0)