Skip to content

Commit 7ad8e25

Browse files
iganichevtensorflower-gardener
authored andcommitted
Add attribute setting and getting support to TF_Function
PiperOrigin-RevId: 169337159
1 parent ed89a2b commit 7ad8e25

5 files changed

Lines changed: 101 additions & 5 deletions

File tree

tensorflow/c/c_api.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,24 @@ TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
11361136
TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef(
11371137
const TF_Buffer* func_def, TF_Status* status);
11381138

1139+
// Sets function attribute named `attr_name` to value stored in `proto`.
1140+
// If this attribute is already set to another value, it is overriden.
1141+
// `proto` should point to a sequence of bytes of length `proto_len`
1142+
// representing a binary serialization of an AttrValue protocol
1143+
// buffer.
1144+
TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func,
1145+
const char* attr_name,
1146+
const void* proto,
1147+
size_t proto_len,
1148+
TF_Status* status);
1149+
1150+
// Sets `output_attr_value` to the binary-serialized AttrValue proto
1151+
// representation of the value of the `attr_name` attr of `func`.
1152+
// If `attr_name` attribute is not present, status is set to an error.
1153+
TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto(
1154+
TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value,
1155+
TF_Status* status);
1156+
11391157
// Frees the memory used by the `func` struct.
11401158
// TF_DeleteFunction is a noop if `func` is null.
11411159
// Deleting a function does not remove it from any graphs it was copied to.

tensorflow/c/c_api_function.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,4 +545,31 @@ TF_Function* TF_FunctionImportFunctionDef(const TF_Buffer* func_def,
545545
return func;
546546
}
547547

548+
void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
549+
const void* proto, size_t proto_len,
550+
TF_Status* status) {
551+
tensorflow::AttrValue attr_value;
552+
if (!attr_value.ParseFromArray(proto, proto_len)) {
553+
status->status = InvalidArgument(
554+
"Unparseable AttrValue proto passed to "
555+
"TF_FunctionSetAttrValueProto");
556+
return;
557+
}
558+
(*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
559+
status->status = tensorflow::Status::OK();
560+
}
561+
562+
void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
563+
TF_Buffer* output_attr_value,
564+
TF_Status* status) {
565+
const auto& it = func->fdef.attr().find(attr_name);
566+
if (it == func->fdef.attr().end()) {
567+
status->status =
568+
InvalidArgument("Function '", func->fdef.signature().name(),
569+
"' has no attr named '", attr_name, "'.");
570+
return;
571+
}
572+
status->status = MessageToBuffer(it->second, output_attr_value);
573+
}
574+
548575
void TF_DeleteFunction(TF_Function* func) { delete func; }

tensorflow/c/c_api_function_test.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test {
372372
TF_DeleteBuffer(buf);
373373
}
374374

375+
void GetAttr(const char* attr_name, AttrValue* out_attr) {
376+
TF_Buffer* attr_buf = TF_NewBuffer();
377+
TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
378+
ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
379+
TF_DeleteBuffer(attr_buf);
380+
}
381+
375382
const char* func_name_ = "MyFunc";
376383
const char* func_node_name_ = "MyFunc_0";
377384
TF_Status* s_;
@@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
14061413
string(TF_Message(s_)));
14071414
}
14081415

1416+
TEST_F(CApiFunctionTest, Attribute) {
1417+
DefineFunction(func_name_, &func_);
1418+
1419+
// Get non existent attribute
1420+
TF_Buffer* attr_buf = TF_NewBuffer();
1421+
TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
1422+
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
1423+
EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
1424+
string(TF_Message(s_)));
1425+
TF_DeleteBuffer(attr_buf);
1426+
1427+
// Set attr
1428+
tensorflow::AttrValue attr;
1429+
attr.set_s("test_attr_value");
1430+
string bytes;
1431+
attr.SerializeToString(&bytes);
1432+
TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
1433+
bytes.size(), s_);
1434+
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1435+
1436+
// Get attr
1437+
AttrValue read_attr;
1438+
GetAttr("test_attr_name", &read_attr);
1439+
ASSERT_EQ(attr.DebugString(), read_attr.DebugString());
1440+
1441+
// Retrieve the same attr after save/restore
1442+
Reincarnate();
1443+
AttrValue read_attr2;
1444+
GetAttr("test_attr_name", &read_attr2);
1445+
ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
1446+
}
1447+
14091448
} // namespace
14101449
} // namespace tensorflow

tensorflow/python/framework/function.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,25 @@ def _create_definition_if_needed_impl(self):
422422
output_names,
423423
None, # opts
424424
status)
425+
self._set_c_attrs(kwargs_attr)
425426
# pylint: enable=protected-access
426427

428+
def _set_c_attrs(self, attrs):
429+
"""Sets `attrs` as attributes of self._c_func.
430+
431+
Requires that self._c_func is not None.
432+
433+
Args:
434+
attrs: a dictionary from attribute name to attribute proto value
435+
"""
436+
for name, attr_value in attrs.items():
437+
serialized = attr_value.SerializeToString()
438+
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
439+
# It might be worth creating a convenient way to re-use the same status.
440+
with errors.raise_exception_on_not_ok_status() as status:
441+
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
442+
serialized, status)
443+
427444
def _create_hash_str(self, input_arg, output_arg, node_def):
428445
"""Creates an 8-character string unique to this input.
429446

tensorflow/python/framework/function_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,6 @@ def Forward(x):
211211
out, = sess.run(dx, feed)
212212
self.assertAllClose(1 - np.square(np.tanh(inp)), out)
213213

214-
# C API functions don't support all optimizer options on cuda yet
215-
@test_util.skip_if(test_util.c_api_and_cuda_enabled)
216214
def testCustomGradient(self):
217215
dtype = dtypes.float32
218216

@@ -285,9 +283,6 @@ def testSymGradShape(self):
285283
self.assertEqual(x.get_shape(), dx.get_shape())
286284
self.assertEqual(y.get_shape(), dy.get_shape())
287285

288-
# C API functions don't support attributes yet (i.e. noinline).
289-
# This attribute is required to run sucessfully with cuda.
290-
@test_util.skip_if(test_util.c_api_and_cuda_enabled)
291286
def testSymGradAttr(self):
292287

293288
@function.Defun(noinline=True)

0 commit comments

Comments
 (0)