Skip to content

Commit 1c2bcf9

Browse files
mrry authored and tensorflower-gardener committed
Fix bug in kernel creation with functions marked "stateful".
The CallOp kernel caches a handle for invoking the function. This handle is only valid in a single subgraph (it is scoped to the FunctionLibraryRuntime). Marking a function as stateful causes its CallOp kernel to be shared between multiple subgraphs. Therefore, this change overrides the kernel creation logic to ensure that each subgraph gets its own CallOp.

PiperOrigin-RevId: 178820064
1 parent d0a4a79 commit 1c2bcf9

3 files changed

Lines changed: 43 additions & 4 deletions

File tree

tensorflow/core/common_runtime/direct_session.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,8 +1201,14 @@ Status DirectSession::GetOrCreateExecutors(
12011201
auto opseg = device->op_segment();
12021202
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
12031203
OpKernel** kernel) {
1204-
// Caches the kernel only if the node is stateful.
1205-
if (!lib->IsStateful(ndef.op())) {
1204+
// We do not share the kernel via the OpSegment if the node is
1205+
// stateless, or a function.
1206+
// NOTE(mrry): We must not share function kernels (implemented
1207+
// using `CallOp`) between subgraphs, because `CallOp::handle_`
1208+
// is tied to a particular subgraph. Even if the function itself
1209+
// is stateful, the `CallOp` that invokes it is not.
1210+
if (!lib->IsStateful(ndef.op()) ||
1211+
lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
12061212
return lib->CreateKernel(ndef, kernel);
12071213
}
12081214
auto create_fn = [lib, &ndef](OpKernel** kernel) {

tensorflow/core/distributed_runtime/graph_mgr.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,14 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
228228
params.function_library = lib;
229229
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
230230
OpKernel** kernel) {
231-
// Caches the kernel only if the node is stateful.
232-
if (!lib->IsStateful(ndef.op())) {
231+
// We do not share the kernel via the OpSegment if the node is
232+
// stateless, or a function.
233+
// NOTE(mrry): We must not share function kernels (implemented
234+
// using `CallOp`) between subgraphs, because `CallOp::handle_`
235+
// is tied to a particular subgraph. Even if the function itself
236+
// is stateful, the `CallOp` that invokes it is not.
237+
if (!lib->IsStateful(ndef.op()) ||
238+
lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
233239
return lib->CreateKernel(ndef, kernel);
234240
}
235241
auto create_fn = [lib, &ndef](OpKernel** kernel) {

tensorflow/python/framework/function_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,33 @@ def Foo(t, x):
914914
np.array([1.0, 0.0]).astype(np.float32),
915915
sess.run(dinp, {inp: x}))
916916

917+
def testFunctionMarkedStateful(self):
918+
919+
@function.Defun(dtypes.int32, dtypes.float32)
920+
def Foo(t, x):
921+
return x[t]
922+
923+
@function.Defun(dtypes.int64)
924+
def Bar(x):
925+
return x
926+
927+
# NOTE(mrry): All functions are currently considered stateless by the
928+
# runtime, so we simulate a "stateful" function.
929+
# TODO(b/70565970): Remove this hack when we are able to build stateful
930+
# functions using the API.
931+
# pylint: disable=protected-access
932+
Foo._signature.is_stateful = True
933+
Bar._signature.is_stateful = True
934+
# pylint: enable=protected-access
935+
936+
result_1 = Foo(3, [1.0, 2.0, 3.0, 4.0])
937+
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
938+
939+
with session.Session() as sess:
940+
self.assertEqual(4.0, sess.run(result_1))
941+
self.assertEqual(100, sess.run(result_2))
942+
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
943+
917944

918945
@test_util.with_c_api
919946
class FunctionsFromProtos(test.TestCase):

0 commit comments

Comments (0)