@@ -255,21 +255,24 @@ def compile(self, optimizer, clipnorm, loss='mse'):
255255 # TODO(KGF): check the following import taken from runner.py
256256 # Was not in this file, originally.
257257 from tensorflow .keras .optimizers import (
258- SGD , Adam , RMSprop , Nadam , TFOptimizer
258+ SGD , Adam , RMSprop , Nadam
259259 )
260260 if optimizer == 'sgd' :
261261 optimizer_class = SGD (lr = self .DUMMY_LR , clipnorm = clipnorm )
262262 elif optimizer == 'momentum_sgd' :
263263 optimizer_class = SGD (lr = self .DUMMY_LR , clipnorm = clipnorm ,
264264 decay = 1e-6 , momentum = 0.9 )
265265 elif optimizer == 'tf_momentum_sgd' :
266- optimizer_class = TFOptimizer (tf .train .MomentumOptimizer (
267- learning_rate = self .DUMMY_LR , momentum = 0.9 ))
266+ # TODO(KGF): removed TFOptimizer wrapper from here and below
267+ # may not work anymore? See
268+ # https://github.com/tensorflow/tensorflow/issues/22780
269+ optimizer_class = tf .train .MomentumOptimizer (
270+ learning_rate = self .DUMMY_LR , momentum = 0.9 )
268271 elif optimizer == 'adam' :
269272 optimizer_class = Adam (lr = self .DUMMY_LR , clipnorm = clipnorm )
270273 elif optimizer == 'tf_adam' :
271- optimizer_class = TFOptimizer ( tf .train .AdamOptimizer (
272- learning_rate = self .DUMMY_LR ))
274+ optimizer_class = tf .train .AdamOptimizer (
275+ learning_rate = self .DUMMY_LR )
273276 elif optimizer == 'rmsprop' :
274277 optimizer_class = RMSprop (lr = self .DUMMY_LR , clipnorm = clipnorm )
275278 elif optimizer == 'nadam' :
0 commit comments