Skip to content

Commit 1c42cc4

Browse files
goat000tensorflower-gardener
authored andcommitted
Make import_pb_to_tensorboard executable
Add main, FLAGS, and BUILD rules. PiperOrigin-RevId: 160434643
1 parent 6485741 commit 1c42cc4

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

tensorflow/python/tools/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ py_binary(
4444
],
4545
)
4646

47+
py_binary(
48+
name = "import_pb_to_tensorboard",
49+
srcs = ["import_pb_to_tensorboard.py"],
50+
srcs_version = "PY2AND3",
51+
deps = [
52+
"//tensorflow/core:protos_all_py",
53+
"//tensorflow/python:client",
54+
"//tensorflow/python:framework",
55+
"//tensorflow/python:framework_ops",
56+
"//tensorflow/python:platform",
57+
"//tensorflow/python:summary",
58+
],
59+
)
60+
4761
py_test(
4862
name = "freeze_graph_test",
4963
size = "small",

tensorflow/python/tools/import_pb_to_tensorboard.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import argparse
22+
import sys
23+
2124
from tensorflow.core.framework import graph_pb2
2225
from tensorflow.python.client import session
2326
from tensorflow.python.framework import importer
2427
from tensorflow.python.framework import ops
28+
from tensorflow.python.platform import app
2529
from tensorflow.python.platform import gfile
2630
from tensorflow.python.summary import summary
2731

@@ -48,3 +52,23 @@ def import_to_tensorboard(model_dir, log_dir):
4852
pb_visual_writer.add_graph(sess.graph)
4953
print("Model Imported. Visualize by running: "
5054
"> tensorboard --logdir={}".format(log_dir))
55+
56+
57+
def main(unused_args):
58+
import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir)
59+
60+
if __name__ == "__main__":
61+
parser = argparse.ArgumentParser()
62+
parser.register("type", "bool", lambda v: v.lower() == "true")
63+
parser.add_argument(
64+
"--model_dir",
65+
type=str,
66+
default="",
67+
help="The location of the protobuf (\'pb\') model to visualize.")
68+
parser.add_argument(
69+
"--log_dir",
70+
type=str,
71+
default="",
72+
help="The location for the Tensorboard log to begin visualization from.")
73+
FLAGS, unparsed = parser.parse_known_args()
74+
app.run(main=main, argv=[sys.argv[0]] + unparsed)

0 commit comments

Comments
 (0)