Skip to content

Commit f0b1a9e

Browse files
committed
Add phase support for draw net
1 parent f28f5ae commit f0b1a9e

2 files changed

Lines changed: 41 additions & 6 deletions

File tree

python/caffe/draw.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype):
127127
return color
128128

129129

130-
def get_pydot_graph(caffe_net, rankdir, label_edges=True):
130+
def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
131131
"""Create a data structure which represents the `caffe_net`.
132132
133133
Parameters
@@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
137137
Direction of graph layout.
138138
label_edges : boolean, optional
139139
Label the edges (default is True).
140+
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
141+
Include layers from this network phase. If None, include all layers.
142+
(the default is None)
140143
141144
Returns
142145
-------
@@ -148,6 +151,19 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
148151
pydot_nodes = {}
149152
pydot_edges = []
150153
for layer in caffe_net.layer:
154+
if phase is not None:
155+
included = False
156+
if len(layer.include) == 0:
157+
included = True
158+
if len(layer.include) > 0 and len(layer.exclude) > 0:
159+
raise ValueError('layer ' + layer.name + ' has both include '
160+
'and exclude specified.')
161+
for layer_phase in layer.include:
162+
included = included or layer_phase.phase == phase
163+
for layer_phase in layer.exclude:
164+
included = included and not layer_phase.phase == phase
165+
if not included:
166+
continue
151167
node_label = get_layer_label(layer, rankdir)
152168
node_name = "%s_%s" % (layer.name, layer.type)
153169
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
@@ -186,7 +202,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
186202
return pydot_graph
187203

188204

189-
def draw_net(caffe_net, rankdir, ext='png'):
205+
def draw_net(caffe_net, rankdir, ext='png', phase=None):
190206
"""Draws a caffe net and returns the image string encoded using the given
191207
extension.
192208
@@ -195,16 +211,19 @@ def draw_net(caffe_net, rankdir, ext='png'):
195211
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
196212
ext : string, optional
197213
The image extension (the default is 'png').
214+
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
215+
Include layers from this network phase. If None, include all layers.
216+
(the default is None)
198217
199218
Returns
200219
-------
201220
string :
202221
Postscript representation of the graph.
203222
"""
204-
return get_pydot_graph(caffe_net, rankdir).create(format=ext)
223+
return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)
205224

206225

207-
def draw_net_to_file(caffe_net, filename, rankdir='LR'):
226+
def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
208227
"""Draws a caffe net, and saves it to file using the format given as the
209228
file extension. Use '.raw' to output raw text that you can manually feed
210229
to graphviz to draw graphs.
@@ -216,7 +235,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
216235
The path to a file where the networks visualization will be stored.
217236
rankdir : {'LR', 'TB', 'BT'}
218237
Direction of graph layout.
238+
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
239+
Include layers from this network phase. If None, include all layers.
240+
(the default is None)
219241
"""
220242
ext = filename[filename.rfind('.')+1:]
221243
with open(filename, 'wb') as fid:
222-
fid.write(draw_net(caffe_net, rankdir, ext))
244+
fid.write(draw_net(caffe_net, rankdir, ext, phase))

python/draw_net.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def parse_args():
2828
'http://www.graphviz.org/doc/info/'
2929
'attrs.html#k:rankdir'),
3030
default='LR')
31+
parser.add_argument('--phase',
32+
help=('Which network phase to draw: can be TRAIN, '
33+
'TEST, or ALL. If ALL, then all layers are drawn '
34+
'regardless of phase.'),
35+
default="ALL")
3136

3237
args = parser.parse_args()
3338
return args
@@ -38,7 +43,15 @@ def main():
3843
net = caffe_pb2.NetParameter()
3944
text_format.Merge(open(args.input_net_proto_file).read(), net)
4045
print('Drawing net to %s' % args.output_image_file)
41-
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
46+
phase=None;
47+
if args.phase == "TRAIN":
48+
phase = caffe.TRAIN
49+
elif args.phase == "TEST":
50+
phase = caffe.TEST
51+
elif args.phase != "ALL":
52+
raise ValueError("Unknown phase: " + args.phase)
53+
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
54+
phase)
4255

4356

4457
if __name__ == '__main__':

0 commit comments

Comments
 (0)