Skip to content

Commit ebb2785

Browse files
MarkDaousttensorflower-gardener
authored andcommitted
Give clear errors for bad input names.
PiperOrigin-RevId: 155857515
1 parent 8a13f77 commit ebb2785

2 files changed

Lines changed: 31 additions & 7 deletions

File tree

tensorflow/python/tools/strip_unused_lib.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
4141
a list that specifies one value per input node name.
4242
4343
Returns:
44-
A GraphDef with all unnecessary ops removed.
44+
A `GraphDef` with all unnecessary ops removed.
45+
46+
Raises:
47+
ValueError: If any element in `input_node_names` refers to a tensor instead
48+
of an operation.
49+
KeyError: If any element in `input_node_names` is not found in the graph.
4550
"""
51+
for name in input_node_names:
52+
if ":" in name:
53+
raise ValueError("Name '%s' appears to refer to a Tensor, "
54+
"not a Operation." % name)
55+
4656
# Here we replace the nodes we're going to override as inputs with
4757
# placeholders so that any unused nodes that are inputs to them are
4858
# automatically stripped out by extract_sub_graph().
59+
not_found = {name for name in input_node_names}
4960
inputs_replaced_graph_def = graph_pb2.GraphDef()
5061
for node in input_graph_def.node:
5162
if node.name in input_node_names:
63+
not_found.remove(node.name)
5264
placeholder_node = node_def_pb2.NodeDef()
5365
placeholder_node.op = "Placeholder"
5466
placeholder_node.name = node.name
@@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
6779
else:
6880
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
6981

82+
if not_found:
83+
raise KeyError("The following input nodes were not found: %s\n" % not_found)
84+
7085
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
7186
output_node_names)
7287
return output_graph_def

tensorflow/python/tools/strip_unused_test.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,25 @@ def testStripUnused(self):
5858
# routine.
5959
input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
6060
input_binary = False
61-
input_node_names = "wanted_input_node"
6261
output_binary = True
6362
output_node_names = "output_node"
6463
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
6564

66-
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
67-
output_graph_path, output_binary,
68-
input_node_names,
69-
output_node_names,
70-
dtypes.float32.as_datatype_enum)
65+
def strip(input_node_names):
66+
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
67+
output_graph_path, output_binary,
68+
input_node_names,
69+
output_node_names,
70+
dtypes.float32.as_datatype_enum)
71+
72+
with self.assertRaises(KeyError):
73+
strip("does_not_exist")
74+
75+
with self.assertRaises(ValueError):
76+
strip("wanted_input_node:0")
77+
78+
input_node_names = "wanted_input_node"
79+
strip(input_node_names)
7180

7281
# Now we make sure the variable is now a constant, and that the graph still
7382
# produces the expected result.

0 commit comments

Comments
 (0)