@@ -360,10 +360,29 @@ def build_train_test_models(self):
360360 return self .build_model (False ), self .build_model (True )
361361
362362 def save_model_weights (self , model , epoch ):
363+ # Keras HDF5 weights only
363364 save_path = self .get_save_path (epoch )
364- full_model_save_path = self .get_save_path (epoch , ext = 'hdf5' )
365- model .save (full_model_save_path , overwrite = True )
366365 model .save_weights (save_path , overwrite = True )
366+ # Keras light-weight HDF5 format: model arch, weights, compile info
367+ full_model_save_path = self .get_save_path (epoch , ext = 'hdf5' )
368+ model .save (full_model_save_path ,
369+ overwrite = True , # default
370+ include_optimizer = True , # default
371+ save_format = None , # default, 'h5' in r1.15. Else 'tf'
372+ signatures = None , # applicable to 'tf' SavedModel format only
373+ )
374+ # TensorFlow SavedModel format (full directory)
375+ full_moodel_save_dir = full_model_save_path .rsplit ('.' ,1 )[0 ]
376+ # TODO(KGF): model.save(..., save_format='tf') disabled in r1.15
377+ # Same with tf.keras.models.save_model(..., save_format="tf").
378+ # Need to use experimental API until r2.x
379+ # model.save(full_model_save_dir, overwrite=True, save_format='tf')
380+ tf .keras .experimental .export_saved_model (model , full_moodel_save_dir ,
381+ custom_objects = None ,
382+ as_text = False ,
383+ input_signature = None ,
384+ serving_only = False
385+ )
367386 # try:
368387 if _has_onnx :
369388 save_path = self .get_save_path (epoch , ext = 'onnx' )
@@ -427,7 +446,8 @@ def extract_id_and_epoch_from_filename(self, filename):
427446 regex = re .compile (r'-?\d+' )
428447 numbers = [int (x ) for x in regex .findall (filename )]
429448 # TODO: should ignore any files that dont match our naming convention
430- # in this directory, especially since we are now writing full .hdf5 too
449+ # in this directory, especially since we are now writing full .hdf5 too.
450+ # Will crash the program if, e.g., a .tgz file is in that directory
431451 if filename [- 3 :] == '.h5' :
432452 assert len (numbers ) == 3 # id, epoch number, and .h5 extension
433453 assert numbers [2 ] == 5 # .h5 extension
0 commit comments