Skip to content

Commit 2ba3417

Browse files
tensorflower-gardeneraselle
authored andcommitted
Add a --all_tensor_names option, which is useful if I only want to know all tensor names. It is especially useful in cases whether some of the tensors has huge size. Also update the usage description.
PiperOrigin-RevId: 175074541
1 parent 6f7cf68 commit 2ba3417

1 file changed

Lines changed: 18 additions & 5 deletions

File tree

tensorflow/python/tools/inspect_checkpoint.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
FLAGS = None
3030

3131

32-
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
32+
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
33+
all_tensor_names):
3334
"""Prints tensors in a checkpoint file.
3435
3536
If no `tensor_name` is provided, prints the tensor names and shapes
@@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
4142
file_name: Name of the checkpoint file.
4243
tensor_name: Name of the tensor in the checkpoint file to print.
4344
all_tensors: Boolean indicating whether to print all tensors.
45+
all_tensor_names: Boolean indicating whether to print all tensor names.
4446
"""
4547
try:
4648
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
47-
if all_tensors:
49+
if all_tensors or all_tensor_names:
4850
var_to_shape_map = reader.get_variable_to_shape_map()
4951
for key in sorted(var_to_shape_map):
5052
print("tensor_name: ", key)
51-
print(reader.get_tensor(key))
53+
if all_tensors:
54+
print(reader.get_tensor(key))
5255
elif not tensor_name:
5356
print(reader.debug_string().decode("utf-8"))
5457
else:
@@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str):
104107
def main(unused_argv):
105108
if not FLAGS.file_name:
106109
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
107-
"[--tensor_name=tensor_to_print]")
110+
"[--tensor_name=tensor_to_print] "
111+
"[--all_tensors] "
112+
"[--all_tensor_names] "
113+
"[--printoptions]")
108114
sys.exit(1)
109115
else:
110116
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
111-
FLAGS.all_tensors)
117+
FLAGS.all_tensors, FLAGS.all_tensor_names)
112118

113119

114120
if __name__ == "__main__":
@@ -130,6 +136,13 @@ def main(unused_argv):
130136
type="bool",
131137
default=False,
132138
help="If True, print the values of all the tensors.")
139+
parser.add_argument(
140+
"--all_tensor_names",
141+
nargs="?",
142+
const=True,
143+
type="bool",
144+
default=False,
145+
help="If True, print the names of all the tensors.")
133146
parser.add_argument(
134147
"--printoptions",
135148
nargs="*",

0 commit comments

Comments
 (0)