Skip to content

Commit 9624d16

Browse files
iganichevtensorflower-gardener
authored andcommitted
Add function support to Tensorflow C API
This change adds minimal functionality. Support for FunctionOptions, attributes, output name rewriting, function name generation, etc is comming next. PiperOrigin-RevId: 167091238
1 parent 424aa9a commit 9624d16

18 files changed

Lines changed: 2072 additions & 29 deletions

tensorflow/c/BUILD

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@ tf_cuda_library(
4545

4646
tf_cuda_library(
4747
name = "c_api",
48-
srcs = ["c_api.cc"],
49-
hdrs = ["c_api.h"],
48+
srcs = [
49+
"c_api.cc",
50+
"c_api_function.cc",
51+
],
52+
hdrs = [
53+
"c_api.h",
54+
],
5055
copts = tf_copts(),
5156
visibility = ["//visibility:public"],
5257
deps = select({
@@ -157,6 +162,21 @@ tf_cc_test(
157162
],
158163
)
159164

165+
tf_cc_test(
166+
name = "c_api_function_test",
167+
size = "small",
168+
srcs = ["c_api_function_test.cc"],
169+
deps = [
170+
":c_api",
171+
":c_test_util",
172+
"//tensorflow/core:lib",
173+
"//tensorflow/core:lib_internal",
174+
"//tensorflow/core:protos_all_cc",
175+
"//tensorflow/core:test",
176+
"//tensorflow/core:test_main",
177+
],
178+
)
179+
160180
tf_cc_test(
161181
name = "while_loop_test",
162182
size = "small",

tensorflow/c/c_api.cc

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
165165
tensorflow::cpu_allocator()->DeallocateRaw(data);
166166
}
167167

168-
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
169-
TF_Buffer* out) {
170-
if (out->data != nullptr) {
171-
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
172-
}
173-
const auto proto_size = in.ByteSizeLong();
174-
void* buf = tensorflow::port::Malloc(proto_size);
175-
in.SerializeToArray(buf, proto_size);
176-
out->data = buf;
177-
out->length = proto_size;
178-
out->data_deallocator = [](void* data, size_t length) {
179-
tensorflow::port::Free(data);
180-
};
181-
return Status::OK();
182-
}
183-
184168
} // namespace
185169

186170
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
@@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
559543
dimvec.size(), base, size, DeleteArray, base);
560544
}
561545

546+
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
547+
TF_Buffer* out) {
548+
if (out->data != nullptr) {
549+
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
550+
}
551+
const size_t proto_size = in.ByteSizeLong();
552+
void* buf = tensorflow::port::Malloc(proto_size);
553+
if (buf == nullptr) {
554+
return tensorflow::errors::ResourceExhausted(
555+
"Failed to allocate memory to serialize message of type '",
556+
in.GetTypeName(), "' and size ", proto_size);
557+
}
558+
in.SerializeToArray(buf, proto_size);
559+
out->data = buf;
560+
out->length = proto_size;
561+
out->data_deallocator = [](void* data, size_t length) {
562+
tensorflow::port::Free(data);
563+
};
564+
return Status::OK();
565+
}
566+
562567
// Helpers for loading a TensorFlow plugin (a .so file).
563568
Status LoadLibrary(const char* library_filename, void** result,
564569
const void** buf, size_t* len);

tensorflow/c/c_api.h

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,14 @@ typedef struct TF_Output {
357357
int index; // The index of the output within oper.
358358
} TF_Output;
359359

360+
// TF_Function is a grouping of operations with defined inputs and outputs.
361+
// Once created and added to graphs, functions can be invoked by creating an
362+
// operation whose operation type matches the function name.
363+
typedef struct TF_Function TF_Function;
364+
365+
// Function definition options. TODO(iga): Define and implement
366+
typedef struct TF_FunctionOptions TF_FunctionOptions;
367+
360368
// Sets the shape of the Tensor referenced by `output` in `graph` to
361369
// the shape described by `dims` and `num_dims`.
362370
//
@@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
914922
TF_Graph* graph, const TF_Buffer* graph_def,
915923
const TF_ImportGraphDefOptions* options, TF_Status* status);
916924

925+
// Add `function` to graph `g`. Once `function` is added to `g`,
926+
// it can be called by creating an operation using the function's name.
927+
//
928+
// If successful, status is set to OK and function is added to g
929+
// Otherwise, status is set to the encountered error and g is unmodified
930+
TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g,
931+
const TF_Function* function,
932+
TF_Status* status);
933+
917934
// Note: The following function may fail on very large protos in the future.
918935

919936
TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper,
@@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
10011018
TF_Output* x, int nx, TF_Output* dx,
10021019
TF_Status* status, TF_Output* dy);
10031020

1021+
// Create a TF_Function from a TF_Graph
1022+
//
1023+
// Params:
1024+
// fn_body - the graph whose operations (or subset of whose operations) will be
1025+
// converted to TF_Function.
1026+
// fn_name - the name of the new TF_Function. Should match the operation
1027+
// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
1028+
// from other operation names (at least those registered in graphs
1029+
// where this function will be used).
1030+
// TODO(iga): Allow null in here and have C API come up with
1031+
// a unique name with high probability (similarly to
1032+
// _create_hash_str in function.py)
1033+
// num_opers - `num_opers` contains the number of elements in the `opers` array
1034+
// or a special value of -1 meaning that no array is given.
1035+
// The distinction between an empty array of operations and no
1036+
// array of operations is necessary to distinguish the case of
1037+
// creating a function with no body (e.g. identity or permutation)
1038+
// and the case of creating a function whose body contains all
1039+
// the nodes in the graph (except for the automatic skipping, see
1040+
// below).
1041+
// opers - Array of operations to become the body of the function or null.
1042+
// - If no array is given (`num_opers` = -1), all the
1043+
// operations in `fn_body` will become part of the function
1044+
// except operations referenced in `inputs`. These operations
1045+
// must have a single output (these operations are typically
1046+
// placeholders created for the sole purpose of representing
1047+
// an input. We can relax this constraint if there are
1048+
// compelling use cases).
1049+
// - If an array is given (`num_opers` >= 0), all operations
1050+
// in it will become part of the function. In particular, no
1051+
// automatic skipping of dummy input operations is performed.
1052+
// ninputs - number of elements in `inputs` array
1053+
// inputs - array of TF_Outputs that specify the inputs to the function.
1054+
// If `ninputs` is zero (the function takes no inputs), `inputs`
1055+
// can be null. The names used for function inputs are normalized
1056+
// names of the operations (usually placeholders) pointed to by
1057+
// `inputs`. These operation names should start with a letter.
1058+
// Normalization will convert all letters to lowercase and
1059+
// non-alphanumeric characters to '_' to make resulting names match
1060+
// the "[a-z][a-z0-9_]*" pattern for operation argument names.
1061+
// `inputs` cannot contain the same tensor twice.
1062+
// noutputs - number of elements in `outputs` array
1063+
// outputs - array of TF_Outputs that specify the outputs of the function.
1064+
// If `noutputs` is zero (the function returns no outputs), `outputs`
1065+
// can be null. `outputs` can contain the same tensor more than once.
1066+
// output_names - The names of the function's outputs. `output_names` array
1067+
// must either have the same length as `outputs`
1068+
// (i.e. `noutputs`) or be null. In the former case,
1069+
// the names should match the regular expression for ArgDef
1070+
// names - "[a-z][a-z0-9_]*". In the latter case,
1071+
// names for outputs will be generated automatically.
1072+
// opts - various options for the function, e.g. XLA's inlining control.
1073+
// status - Set to OK on success and an appropriate error on failure.
1074+
//
1075+
// Note that when the same TF_Output is listed as both an input and an output,
1076+
// the corresponding function's output will equal to this input,
1077+
// instead of the original node's output.
1078+
//
1079+
// Callers must also satisfy the following constraints:
1080+
// - `inputs` cannot refer to TF_Outputs within a control flow context. For
1081+
// example, one cannot use the output of "switch" node as input.
1082+
// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`)
1083+
// is allowed to have a reference type. Reference types are not exposed
1084+
// through C API and are being deprecated.
1085+
// - Every node in the function's body must have all of its inputs (including
1086+
// control inputs). In other words, for every node in the body, each input
1087+
// must be either listed in `inputs` or must come from another node in
1088+
// the body. In particular, it is an error to have a control edge going from
1089+
// a node outside of the body into a node in the body. This applies to control
1090+
// edges going from nodes referenced in `inputs` to nodes in the body when
1091+
// the former nodes are not in the body (automatically skipped or not
1092+
// included in explicitly specified body).
1093+
//
1094+
// Returns:
1095+
// On successful, a newly created TF_Function instance. It must be deleted by
1096+
// calling TF_DeleteFunction.
1097+
//
1098+
// On failure, null.
1099+
//
1100+
// TODO(iga): Add input_names argument and get output_names working (they are
1101+
// currently ignored)
1102+
TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
1103+
const TF_Graph* fn_body, const char* fn_name, int num_opers,
1104+
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
1105+
int noutputs, const TF_Output* outputs, const char* const* output_names,
1106+
const TF_FunctionOptions* opts, TF_Status* status);
1107+
1108+
// Write out a serialized representation of `func` (as a FunctionDef protocol
1109+
// message) to `output_func_def` (allocated by TF_NewBuffer()).
1110+
// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
1111+
// is called.
1112+
//
1113+
// May fail on very large graphs in the future.
1114+
TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
1115+
TF_Buffer* output_func_def,
1116+
TF_Status* status);
1117+
1118+
TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*);
1119+
10041120
// TODO(josh11b): Register OpDef, available to all operations added
10051121
// to this graph.
10061122

0 commit comments

Comments
 (0)