Skip to content

Commit eb281da

Browse files
Fix unsuccessful opdef lookup when function was not created via Defun.
Change: 149826365
1 parent 49ecd2a commit eb281da

3 files changed

Lines changed: 68 additions & 7 deletions

File tree

tensorflow/python/framework/function.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,7 @@ def _get_node_def(op):
6868

6969

7070
def _get_op_def(op):
71-
# pylint: disable=protected-access
72-
if hasattr(op, "_sig"):
73-
return getattr(op, "_sig")
74-
else:
75-
return op_def_registry.get_registered_ops()[op.type]
76-
# pylint: enable=protected-access
71+
return op.op_def or op_def_registry.get_registered_ops()[op.type]
7772

7873

7974
def _is_in_placeholders(op, func_arg_placeholders):
@@ -248,8 +243,8 @@ def _call(sig, *inputs, **kwargs):
248243
output_types,
249244
name=name,
250245
attrs=attrs,
246+
op_def=sig,
251247
compute_shapes=False)
252-
setattr(op, "_sig", sig) # Remember the signature.
253248
if op.outputs:
254249
if len(op.outputs) == 1:
255250
ret = op.outputs[0]
@@ -452,6 +447,7 @@ def __init__(self,
452447
self._shape_func = shape_func
453448
self._extra_kwargs = kwargs
454449
self._definition = None # Constructed lazily.
450+
self._sub_functions = dict() # Constructed with definition.
455451

456452
self._args = []
457453
assert isinstance(input_types, (list, tuple))

tensorflow/python/framework/function_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,48 @@ def BN1(x):
526526
self.assertAllClose(vals[0], vals[1])
527527
self.assertAllClose(vals[2], vals[3])
528528

529+
def testDeclare(self):
530+
foo = function.Declare("Foo", [("x", dtypes.float32)],
531+
[("y", dtypes.float32)])
532+
533+
@function.Defun(dtypes.float32, func_name="Foo", out_names=["y"])
534+
def FooImpl(x):
535+
return x * x + 1
536+
537+
x = array_ops.placeholder(dtypes.float32)
538+
y = foo(x)
539+
540+
g = ops.get_default_graph()
541+
FooImpl.add_to_graph(g)
542+
543+
with self.test_session():
544+
rand = np.random.uniform(size=(3, 3))
545+
expected = rand * rand + 1.0
546+
self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
547+
548+
def testDeclareUsedInDefun(self):
549+
foo = function.Declare("Foo", [("x", dtypes.float32)],
550+
[("y", dtypes.float32)])
551+
552+
@function.Defun()
553+
def Bar(x):
554+
return foo(x)
555+
556+
@function.Defun(dtypes.float32, func_name="Foo", out_names=["y"])
557+
def FooImpl(x):
558+
return x * x + 1
559+
560+
x = array_ops.placeholder(dtypes.float32)
561+
y = Bar(x)
562+
563+
g = ops.get_default_graph()
564+
FooImpl.add_to_graph(g)
565+
566+
with self.test_session():
567+
rand = np.random.uniform(size=(3, 3))
568+
expected = rand * rand + 1.0
569+
self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
570+
529571
def testDeclareTypeMistake(self):
530572
foo = function.Declare("Foo", [("x", dtypes.float32)],
531573
[("y", dtypes.float32)])

tensorflow/python/framework/importer_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,29 @@ def InnerFunc(x):
981981
self.assertEqual(sess.run("external:0"), 11)
982982
self.assertEqual(sess.run("outer:0"), 21)
983983

984+
def testImportInsideDefun(self):
985+
g = ops.Graph()
986+
with g.as_default():
987+
@function.Defun()
988+
def Add2(x, y):
989+
return math_ops.add(x, y)
990+
991+
x = constant_op.constant(3.0, dtype=dtypes.float32)
992+
y = constant_op.constant(-5.0, dtype=dtypes.float32)
993+
z = Add2(x, y, name="z") # pylint: disable=unexpected-keyword-arg
994+
995+
gdef = g.as_graph_def()
996+
997+
@function.Defun()
998+
def TestFunc():
999+
return importer.import_graph_def(gdef, return_elements=["z:0"])[0]
1000+
1001+
z = TestFunc()
1002+
1003+
with self.test_session():
1004+
z_val = z.eval()
1005+
self.assertEqual(z_val, -2.0)
1006+
9841007

9851008
if __name__ == "__main__":
9861009
test.main()

0 commit comments

Comments
 (0)