
Commit 64a9c77

Remove TFOptimizer wrapper
1 parent: ce68366

File tree: 2 files changed (+9, -6 lines)

plasma/models/mpi_runner.py (8 additions, 5 deletions)

@@ -255,21 +255,24 @@ def compile(self, optimizer, clipnorm, loss='mse'):
         # TODO(KGF): check the following import taken from runner.py
         # Was not in this file, originally.
         from tensorflow.keras.optimizers import (
-            SGD, Adam, RMSprop, Nadam, TFOptimizer
+            SGD, Adam, RMSprop, Nadam
         )
         if optimizer == 'sgd':
             optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm)
         elif optimizer == 'momentum_sgd':
             optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm,
                                   decay=1e-6, momentum=0.9)
         elif optimizer == 'tf_momentum_sgd':
-            optimizer_class = TFOptimizer(tf.train.MomentumOptimizer(
-                learning_rate=self.DUMMY_LR, momentum=0.9))
+            # TODO(KGF): removed TFOptimizer wrapper from here and below
+            # may not work anymore? See
+            # https://github.com/tensorflow/tensorflow/issues/22780
+            optimizer_class = tf.train.MomentumOptimizer(
+                learning_rate=self.DUMMY_LR, momentum=0.9)
         elif optimizer == 'adam':
             optimizer_class = Adam(lr=self.DUMMY_LR, clipnorm=clipnorm)
         elif optimizer == 'tf_adam':
-            optimizer_class = TFOptimizer(tf.train.AdamOptimizer(
-                learning_rate=self.DUMMY_LR))
+            optimizer_class = tf.train.AdamOptimizer(
+                learning_rate=self.DUMMY_LR)
         elif optimizer == 'rmsprop':
             optimizer_class = RMSprop(lr=self.DUMMY_LR, clipnorm=clipnorm)
         elif optimizer == 'nadam':
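For context, the following is a minimal sketch (not code from this repository, assuming TensorFlow 1.x graph mode) of handing a native tf.train optimizer straight to tf.keras's compile(), which is what dropping the TFOptimizer wrapper amounts to; the toy model, learning rate, and momentum values are illustrative placeholders:

import tensorflow as tf

# Toy model standing in for the FRNN model; shapes are arbitrary.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Before this commit the optimizer was wrapped:
#     tf.keras.optimizers.TFOptimizer(
#         tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9))
# After this commit the raw TF optimizer is used directly:
optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9)

# Note: unlike the Keras SGD/Adam classes in the diff above, the native
# tf.train optimizers take no clipnorm argument, so gradient clipping by
# norm is not applied on the 'tf_momentum_sgd' and 'tf_adam' branches.
model.compile(optimizer=optimizer, loss='mse')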

plasma/utils/state_reset.py (1 addition, 1 deletion)

@@ -7,7 +7,7 @@ def get_states(model):
         if hasattr(layer, "states"):
             layer_states = []
             for state in layer.states:
-                import keras.backend as K
+                import tensorflow.keras.backend as K
                 layer_states.append(K.get_value(state))
             all_states.append(layer_states)
     return all_states
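A hypothetical usage sketch (assuming a stateful tf.keras LSTM, not code from this repository) of the snapshot/restore pattern that get_states() above relies on, now going through the tensorflow.keras backend:

import tensorflow as tf
import tensorflow.keras.backend as K

# Stateful LSTM so that layer.states holds backend state variables.
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(8, stateful=True, batch_input_shape=(1, 10, 3)),
    tf.keras.layers.Dense(1),
])

lstm = model.layers[0]
# Snapshot: the same K.get_value() call that get_states() makes per state.
saved = [K.get_value(state) for state in lstm.states]
# ... running predictions here would mutate lstm.states ...
# Restore: presumably the mirror-image operation used when resetting states.
for state_var, value in zip(lstm.states, saved):
    K.set_value(state_var, value)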
