Skip to content

Commit bdab269

Browse files
iganichevtensorflower-gardener
authored andcommitted
Add append_hash_to_fn_name arg to TF_GraphToFunction
PiperOrigin-RevId: 170379490
1 parent 860b30b commit bdab269

7 files changed

Lines changed: 63 additions & 20 deletions

File tree

tensorflow/c/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ tf_cuda_library(
7272
"//tensorflow/core:framework",
7373
"//tensorflow/core:protos_all_cc",
7474
"//tensorflow/core:lib",
75+
"//tensorflow/core:lib_internal",
7576
],
7677
}),
7778
)

tensorflow/c/c_api.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,12 +1039,14 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
10391039
// fn_body - the graph whose operations (or subset of whose operations) will be
10401040
// converted to TF_Function.
10411041
// fn_name - the name of the new TF_Function. Should match the operation
1042-
// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
1043-
// from other operation names (at least those registered in graphs
1044-
// where this function will be used).
1045-
// TODO(iga): Allow null in here and have C API come up with
1046-
// a unique name with high probability (similarly to
1047-
// _create_hash_str in function.py)
1042+
// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*.
1043+
// If `append_hash_to_fn_name` is false, `fn_name` must be distinct
1044+
// from other function and operation names (at least those
1045+
// registered in graphs where this function will be used).
1046+
// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name
1047+
// of the function will be `fn_name` appended with
1048+
// '_<hash_of_this_function's_definition>'.
1049+
// If set to 0, the function's name will be `fn_name`.
10481050
// num_opers - `num_opers` contains the number of elements in the `opers` array
10491051
// or a special value of -1 meaning that no array is given.
10501052
// The distinction between an empty array of operations and no
@@ -1114,7 +1116,8 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
11141116
//
11151117
// On failure, null.
11161118
TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
1117-
const TF_Graph* fn_body, const char* fn_name, int num_opers,
1119+
const TF_Graph* fn_body, const char* fn_name,
1120+
unsigned char append_hash_to_fn_name, int num_opers,
11181121
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
11191122
int noutputs, const TF_Output* outputs, const char* const* output_names,
11201123
const TF_FunctionOptions* opts, const char* description, TF_Status* status);

tensorflow/c/c_api_function.cc

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "tensorflow/core/framework/node_def_util.h"
2626
#include "tensorflow/core/framework/types.h"
2727
#include "tensorflow/core/graph/graph.h"
28+
#include "tensorflow/core/lib/strings/base64.h"
2829
#include "tensorflow/core/lib/strings/strcat.h"
2930

3031
using tensorflow::errors::InvalidArgument;
@@ -232,6 +233,7 @@ Status FillFunctionBody(
232233
// Graph to FunctionDef conversion. This code is closely modeled on the Python
233234
// code in third_party/tensorflow/python/framework/function.py.
234235
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
236+
bool append_hash_to_fn_name,
235237
const std::vector<const Node*>& body_nodes,
236238
const std::vector<OutputTensor>& inputs,
237239
const std::vector<OutputTensor>& outputs,
@@ -241,7 +243,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
241243
DCHECK_EQ(output_names.size(), outputs.size());
242244
}
243245

244-
fdef->mutable_signature()->set_name(fn_name);
245246
if (description != nullptr) {
246247
fdef->mutable_signature()->set_description(description);
247248
}
@@ -328,7 +329,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
328329
// Remap return values.
329330
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
330331
const string& ret_name = fdef->signature().output_arg(r).name();
331-
332332
// We convert this flat tensor name to the nested value
333333
// (e.g. `add:z:1`) that we stored in tensor_renaming.
334334
const string& return_value =
@@ -343,6 +343,24 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
343343
(*fdef->mutable_ret())[ret_name] = iter->second;
344344
}
345345

346+
if (append_hash_to_fn_name) {
347+
const uint64 hash = FunctionDefHash(*fdef);
348+
string encoded;
349+
TF_RETURN_IF_ERROR(Base64Encode(
350+
StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
351+
&encoded));
352+
// Besides letters and digits our Base64 encoding uses '_' and '-'.
353+
// Dash is invalid in operation names and multiple underscores in random
354+
// places look strange. Since we never need to decode the hash back,
355+
// replace these chars with with 'a' and 'A'. Replacing with different
356+
// letters keeps more entropy.
357+
std::replace(encoded.begin(), encoded.end(), '-', 'a');
358+
std::replace(encoded.begin(), encoded.end(), '_', 'A');
359+
fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
360+
} else {
361+
fdef->mutable_signature()->set_name(fn_name);
362+
}
363+
346364
return Status::OK();
347365
}
348366

@@ -451,6 +469,7 @@ using tensorflow::Node;
451469
using tensorflow::string;
452470

453471
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
472+
unsigned char append_hash_to_fn_name,
454473
int num_opers, const TF_Operation* const* opers,
455474
int ninputs, const TF_Output* inputs,
456475
int noutputs, const TF_Output* outputs,
@@ -489,9 +508,11 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
489508

490509
// Do the actual function creation.
491510
TF_Function* tf_function = new TF_Function();
511+
DCHECK(append_hash_to_fn_name <= 1);
492512
status->status = tensorflow::GraphToFunctionDef(
493-
fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
494-
output_names_vec, description, &tf_function->fdef);
513+
fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
514+
input_tensors, output_tensors, output_names_vec, description,
515+
&tf_function->fdef);
495516
if (!status->status.ok()) {
496517
TF_DeleteFunction(tf_function);
497518
return nullptr;

tensorflow/c/c_api_function_test.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ class CApiFunctionTest : public ::testing::Test {
179179
bool expect_failure = false) {
180180
ASSERT_EQ(func_, nullptr);
181181
const char** output_names_ptr = ToArray(output_names);
182-
func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers,
182+
func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
183183
num_opers == -1 ? nullptr : opers.data(),
184184
inputs.size(), inputs.data(), outputs.size(),
185185
outputs.data(), output_names_ptr,
@@ -1200,7 +1200,8 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) {
12001200
}
12011201

12021202
void DefineFunction(const char* name, TF_Function** func,
1203-
const char* description = nullptr) {
1203+
const char* description = nullptr,
1204+
bool append_hash = false) {
12041205
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
12051206
TF_NewGraph(), TF_DeleteGraph);
12061207
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
@@ -1211,7 +1212,7 @@ void DefineFunction(const char* name, TF_Function** func,
12111212

12121213
TF_Output inputs[] = {{feed, 0}};
12131214
TF_Output outputs[] = {{neg, 0}};
1214-
*func = TF_GraphToFunction(func_graph.get(), name, -1,
1215+
*func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
12151216
/*opers=*/nullptr, 1, inputs, 1, outputs,
12161217
/*output_names=*/nullptr,
12171218
/*opts=*/nullptr, description, s.get());
@@ -1453,5 +1454,21 @@ TEST_F(CApiFunctionTest, Description) {
14531454
ASSERT_EQ(string("Return something"), fdef.signature().description());
14541455
}
14551456

1457+
TEST_F(CApiFunctionTest, Name) {
1458+
DefineFunction("long_func_name", &func_, "Return something",
1459+
/*append_hash=*/false);
1460+
tensorflow::FunctionDef fdef;
1461+
ASSERT_TRUE(GetFunctionDef(func_, &fdef));
1462+
ASSERT_EQ(string("long_func_name"), fdef.signature().name());
1463+
}
1464+
1465+
TEST_F(CApiFunctionTest, AppendHash) {
1466+
DefineFunction("func_name_base", &func_, "Return something",
1467+
/*append_hash=*/true);
1468+
tensorflow::FunctionDef fdef;
1469+
ASSERT_TRUE(GetFunctionDef(func_, &fdef));
1470+
ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
1471+
}
1472+
14561473
} // namespace
14571474
} // namespace tensorflow

tensorflow/python/client/tf_session_helper.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
348348
}
349349

350350
TF_Function* TF_GraphToFunction_wrapper(
351-
const TF_Graph* fn_body, const char* fn_name,
351+
const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
352352
const std::vector<TF_Operation*>* opers,
353353
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
354354
const NameVector& output_names, const TF_FunctionOptions* opts,
@@ -374,10 +374,10 @@ TF_Function* TF_GraphToFunction_wrapper(
374374
output_names.empty() ? nullptr
375375
: const_cast<const char**>(output_names.data());
376376

377-
return TF_GraphToFunction(fn_body, fn_name, nopers, opers_array,
378-
inputs.size(), inputs.data(), outputs.size(),
379-
outputs.data(), output_names_ptr, opts, description,
380-
out_status);
377+
return TF_GraphToFunction(fn_body, fn_name, append_hash_to_fn_name, nopers,
378+
opers_array, inputs.size(), inputs.data(),
379+
outputs.size(), outputs.data(), output_names_ptr,
380+
opts, description, out_status);
381381
}
382382

383383
} // namespace tensorflow

tensorflow/python/client/tf_session_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
153153
// `opers` equaling NULL are converted to `nopers = -1`.
154154
// `output_names` must be empty or have the same length as `outputs`.
155155
TF_Function* TF_GraphToFunction_wrapper(
156-
const TF_Graph* fn_body, const char* fn_name,
156+
const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
157157
const std::vector<TF_Operation*>* opers,
158158
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
159159
const NameVector& output_names, const TF_FunctionOptions* opts,

tensorflow/python/framework/function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def _create_definition_if_needed_impl(self):
363363
self._c_func = c_api.TF_GraphToFunction_wrapper(
364364
temp_graph._c_graph,
365365
self._func_name,
366+
False, # append_hash_to_fn_name
366367
None, # opers
367368
[t._as_tf_output() for t in inputs],
368369
[t._as_tf_output() for t in outputs],

0 commit comments

Comments
 (0)