2929FLAGS = None
3030
3131
32- def print_tensors_in_checkpoint_file (file_name , tensor_name , all_tensors ):
32+ def print_tensors_in_checkpoint_file (file_name , tensor_name , all_tensors ,
33+ all_tensor_names ):
3334 """Prints tensors in a checkpoint file.
3435
3536 If no `tensor_name` is provided, prints the tensor names and shapes
@@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
4142 file_name: Name of the checkpoint file.
4243 tensor_name: Name of the tensor in the checkpoint file to print.
4344 all_tensors: Boolean indicating whether to print all tensors.
45+ all_tensor_names: Boolean indicating whether to print all tensor names.
4446 """
4547 try :
4648 reader = pywrap_tensorflow .NewCheckpointReader (file_name )
47- if all_tensors :
49+ if all_tensors or all_tensor_names :
4850 var_to_shape_map = reader .get_variable_to_shape_map ()
4951 for key in sorted (var_to_shape_map ):
5052 print ("tensor_name: " , key )
51- print (reader .get_tensor (key ))
53+ if all_tensors :
54+ print (reader .get_tensor (key ))
5255 elif not tensor_name :
5356 print (reader .debug_string ().decode ("utf-8" ))
5457 else :
@@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str):
104107def main (unused_argv ):
105108 if not FLAGS .file_name :
106109 print ("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
107- "[--tensor_name=tensor_to_print]" )
110+ "[--tensor_name=tensor_to_print] "
111+ "[--all_tensors] "
112+ "[--all_tensor_names] "
113+ "[--printoptions]" )
108114 sys .exit (1 )
109115 else :
110116 print_tensors_in_checkpoint_file (FLAGS .file_name , FLAGS .tensor_name ,
111- FLAGS .all_tensors )
117+ FLAGS .all_tensors , FLAGS . all_tensor_names )
112118
113119
114120if __name__ == "__main__" :
@@ -130,6 +136,13 @@ def main(unused_argv):
130136 type = "bool" ,
131137 default = False ,
132138 help = "If True, print the values of all the tensors." )
139+ parser .add_argument (
140+ "--all_tensor_names" ,
141+ nargs = "?" ,
142+ const = True ,
143+ type = "bool" ,
144+ default = False ,
145+ help = "If True, print the names of all the tensors." )
133146 parser .add_argument (
134147 "--printoptions" ,
135148 nargs = "*" ,
0 commit comments