@@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
4141 a list that specifies one value per input node name.
4242
4343 Returns:
44- A GraphDef with all unnecessary ops removed.
44+ A `GraphDef` with all unnecessary ops removed.
45+
46+ Raises:
47+ ValueError: If any element in `input_node_names` refers to a tensor instead
48+ of an operation.
49+ KeyError: If any element in `input_node_names` is not found in the graph.
4550 """
51+ for name in input_node_names :
52+ if ":" in name :
53+ raise ValueError ("Name '%s' appears to refer to a Tensor, "
54+ "not a Operation." % name )
55+
4656 # Here we replace the nodes we're going to override as inputs with
4757 # placeholders so that any unused nodes that are inputs to them are
4858 # automatically stripped out by extract_sub_graph().
59+ not_found = {name for name in input_node_names }
4960 inputs_replaced_graph_def = graph_pb2 .GraphDef ()
5061 for node in input_graph_def .node :
5162 if node .name in input_node_names :
63+ not_found .remove (node .name )
5264 placeholder_node = node_def_pb2 .NodeDef ()
5365 placeholder_node .op = "Placeholder"
5466 placeholder_node .name = node .name
@@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
6779 else :
6880 inputs_replaced_graph_def .node .extend ([copy .deepcopy (node )])
6981
82+ if not_found :
83+ raise KeyError ("The following input nodes were not found: %s\n " % not_found )
84+
7085 output_graph_def = graph_util .extract_sub_graph (inputs_replaced_graph_def ,
7186 output_node_names )
7287 return output_graph_def
0 commit comments