11from __future__ import division , print_function
22import plasma .global_vars as g
33# KGF: the first time Keras is ever imported via mpi_learn.py -> mpi_runner.py
4- import keras .backend as K
4+ # -> builder.py (here)
5+ import tensorflow as tf
56# KGF: see below synchronization--- output is launched here
6- from keras . models import Model # , Sequential
7+ #
78# KGF: (was used only in hyper_build_model())
8- from keras .layers import Input
9- from keras . layers . core import (
9+ from tensorflow . keras .layers import (
10+ Input ,
1011 Dense , Activation , Dropout , Lambda ,
1112 Reshape , Flatten , Permute , # RepeatVector
13+ LSTM , CuDNNLSTM , SimpleRNN , BatchNormalization ,
14+ Convolution1D , MaxPooling1D , TimeDistributed ,
15+ Concatenate
1216 )
13- from keras .layers import LSTM , CuDNNLSTM , SimpleRNN , BatchNormalization
14- from keras .layers .convolutional import Convolution1D
15- from keras .layers .pooling import MaxPooling1D
16- # from keras.utils.data_utils import get_file
17- from keras .layers .wrappers import TimeDistributed
18- from keras .layers .merge import Concatenate
19- from keras .callbacks import Callback
20- from keras .regularizers import l2 # l1, l1_l2
17+ from tensorflow .keras .callbacks import Callback
18+ from tensorflow .keras .regularizers import l2 # l1, l1_l2
2119
2220import re
2321import os
@@ -275,7 +273,7 @@ def slicer_output_shape(input_shape, indices):
275273 bias_regularizer = l2 (dense_regularization ),
276274 activity_regularizer = l2 (dense_regularization ))(pre_rnn )
277275
278- pre_rnn_model = Model (inputs = pre_rnn_input , outputs = pre_rnn )
276+ pre_rnn_model = tf . keras . Model (inputs = pre_rnn_input , outputs = pre_rnn )
279277 # TODO(KGF): uncomment following lines to get summary of pre-RNN model
280278 # from mpi4py import MPI
281279 # comm = MPI.COMM_WORLD
@@ -344,16 +342,17 @@ def slicer_output_shape(input_shape, indices):
344342 # x_out = TimeDistributed(Dense(100,activation='tanh')) (x_in)
345343 x_out = TimeDistributed (
346344 Dense (1 , activation = output_activation ))(x_in )
347- model = Model (inputs = x_input , outputs = x_out )
345+ model = tf . keras . Model (inputs = x_input , outputs = x_out )
348346 # bug with tensorflow/Keras
349347 # TODO(KGF): what is this bug? this is the only direct "tensorflow"
350348 # import outside of mpi_runner.py and runner.py
351- if (conf ['model' ]['backend' ] == 'tf'
352- or conf ['model' ]['backend' ] == 'tensorflow' ):
353- first_time = "tensorflow" not in sys .modules
354- import tensorflow as tf
355- if first_time :
356- K .get_session ().run (tf .global_variables_initializer ())
349+ # if (conf['model']['backend'] == 'tf'
350+ # or conf['model']['backend'] == 'tensorflow'):
351+ # first_time = "tensorflow" not in sys.modules
352+ # import tensorflow as tf
353+ # if first_time:
354+ # tf.compat.v1.keras.backend.get_session().run(
355+ # tf.global_variables_initializer())
357356 model .reset_states ()
358357 return model
359358
@@ -362,6 +361,8 @@ def build_train_test_models(self):
362361
363362 def save_model_weights (self , model , epoch ):
364363 save_path = self .get_save_path (epoch )
364+ full_model_save_path = self .get_save_path (epoch , ext = 'hdf5' )
365+ model .save (full_model_save_path , overwrite = True )
365366 model .save_weights (save_path , overwrite = True )
366367 # try:
367368 if _has_onnx :
@@ -425,6 +426,8 @@ def load_model_weights(self, model, custom_path=None):
425426 def extract_id_and_epoch_from_filename (self , filename ):
426427 regex = re .compile (r'-?\d+' )
427428 numbers = [int (x ) for x in regex .findall (filename )]
429+ # TODO: should ignore any files that dont match our naming convention
430+ # in this directory, especially since we are now writing full .hdf5 too
428431 if filename [- 3 :] == '.h5' :
429432 assert len (numbers ) == 3 # id, epoch number, and .h5 extension
430433 assert numbers [2 ] == 5 # .h5 extension
@@ -438,8 +441,8 @@ def get_all_saved_files(self):
438441 filenames = [name for name in os .listdir (path )
439442 if os .path .isfile (os .path .join (path , name ))]
440443 epochs = []
441- for file in filenames :
442- curr_id , epoch = self .extract_id_and_epoch_from_filename (file )
444+ for fname in filenames :
445+ curr_id , epoch = self .extract_id_and_epoch_from_filename (fname )
443446 if curr_id == unique_id :
444447 epochs .append (epoch )
445448 return epochs
0 commit comments