Skip to content

Commit 020501c

Browse files
alextptensorflower-gardener
authored andcommitted
Hoist resource variable reads outside functions.
Change: 149360110
1 parent 6bac753 commit 020501c

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

tensorflow/python/framework/function.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tensorflow.python.framework import op_def_registry
3434
from tensorflow.python.framework import ops
3535
from tensorflow.python.ops import array_ops
36+
from tensorflow.python.ops import resource_variable_ops
3637
from tensorflow.python.ops import variable_scope as vs
3738
from tensorflow.python.util import compat
3839

@@ -324,6 +325,12 @@ def getvar(self,
324325
collections=collections,
325326
use_resource=use_resource)
326327
self.extra_vars.append(var)
328+
if isinstance(var, resource_variable_ops.ResourceVariable):
329+
# For resource-based variables read the variable outside the function
330+
# and pass in the value. This ensures that the function is pure and
331+
# differentiable. TODO(apassos) this may have performance problems if
332+
# the function will only do embedding lookups on the variable.
333+
return var.value()
327334
return var
328335

329336
def create_op(self, op_type, inputs, data_types, **kwargs):

tensorflow/python/framework/function_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,9 +1178,7 @@ def testBasic(self):
11781178
self._testSimpleModel(True)
11791179
self._testSimpleModel(False)
11801180

1181-
# TODO(b/35668241): disabled because resource variable handling inside
1182-
# functions does not work.
1183-
def DISABLED_testBasicResource(self):
1181+
def testBasicResource(self):
11841182
self._testSimpleModel(True, use_resource=True)
11851183
self._testSimpleModel(False, use_resource=True)
11861184

0 commit comments

Comments
 (0)