@@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test {
372372 TF_DeleteBuffer (buf);
373373 }
374374
375+ void GetAttr (const char * attr_name, AttrValue* out_attr) {
376+ TF_Buffer* attr_buf = TF_NewBuffer ();
377+ TF_FunctionGetAttrValueProto (func_, attr_name, attr_buf, s_);
378+ ASSERT_TRUE (out_attr->ParseFromArray (attr_buf->data , attr_buf->length ));
379+ TF_DeleteBuffer (attr_buf);
380+ }
381+
375382 const char * func_name_ = " MyFunc" ;
376383 const char * func_node_name_ = " MyFunc_0" ;
377384 TF_Status* s_;
@@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
14061413 string (TF_Message (s_)));
14071414}
14081415
1416+ TEST_F (CApiFunctionTest, Attribute) {
1417+ DefineFunction (func_name_, &func_);
1418+
1419+ // Get non existent attribute
1420+ TF_Buffer* attr_buf = TF_NewBuffer ();
1421+ TF_FunctionGetAttrValueProto (func_, " foo_attr" , attr_buf, s_);
1422+ EXPECT_EQ (TF_INVALID_ARGUMENT, TF_GetCode (s_));
1423+ EXPECT_EQ (string (" Function 'MyFunc' has no attr named 'foo_attr'." ),
1424+ string (TF_Message (s_)));
1425+ TF_DeleteBuffer (attr_buf);
1426+
1427+ // Set attr
1428+ tensorflow::AttrValue attr;
1429+ attr.set_s (" test_attr_value" );
1430+ string bytes;
1431+ attr.SerializeToString (&bytes);
1432+ TF_FunctionSetAttrValueProto (func_, " test_attr_name" , bytes.data (),
1433+ bytes.size (), s_);
1434+ ASSERT_EQ (TF_OK, TF_GetCode (s_)) << TF_Message (s_);
1435+
1436+ // Get attr
1437+ AttrValue read_attr;
1438+ GetAttr (" test_attr_name" , &read_attr);
1439+ ASSERT_EQ (attr.DebugString (), read_attr.DebugString ());
1440+
1441+ // Retrieve the same attr after save/restore
1442+ Reincarnate ();
1443+ AttrValue read_attr2;
1444+ GetAttr (" test_attr_name" , &read_attr2);
1445+ ASSERT_EQ (attr.DebugString (), read_attr2.DebugString ());
1446+ }
1447+
14091448} // namespace
14101449} // namespace tensorflow
0 commit comments