Skip to content

Commit f9d4fa7

Browse files
committed
Update attribute and import for tf.keras compatibility
Still need to find out counterpart of "uses_learning_phase" attribute
1 parent 64a9c77 commit f9d4fa7

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

plasma/models/mpi_runner.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919
This file trains a deep learning model to predict
2020
disruptions on time series data from plasma discharges.
2121
22-
Dependencies:
23-
conf.py: configuration of model,training,paths, and data
24-
model_builder.py: logic to construct the ML architecture
25-
data_processing.py: classes to handle data processing
26-
2722
Author: Julian Kates-Harbeck, [email protected]
2823
2924
This work was supported by the DOE CSGF program.
@@ -94,7 +89,9 @@
9489
print('[{}] importing Keras'.format(g.task_index))
9590
import tensorflow.keras.backend as K
9691
from tensorflow.keras.utils import Progbar
97-
import tensorflow.keras.callbacks as cbks
92+
# TODO(KGF): instead of tensorflow.keras.callbacks.CallbackList()
93+
# until API added in tf-nightly in v2.2.0
94+
import tensorflow.python.keras.callbacks as cbks
9895

9996
g.flush_all_inorder()
10097
g.pprint_unique(conf)
@@ -1092,7 +1089,10 @@ def on_epoch_end(self, val_generator, val_steps, epoch, logs=None):
10921089
self.writer.add_summary(summary, epoch)
10931090
self.writer.flush()
10941091

1095-
tensors = (self.model.inputs + self.model.targets
1092+
# print(type(self.model))
1093+
# print(dir(self.model))
1094+
# KGF: targets attribute of Model class moved to private in tf.keras
1095+
tensors = (self.model.inputs + self.model._targets
10961096
+ self.model.sample_weights)
10971097

10981098
if self.model.uses_learning_phase:

0 commit comments

Comments
 (0)