Skip to content

Commit adf3189

Browse files
committed
python/draw_net.py and python/caffe/draw.py: Simplified code; added more docstrings; adjusted code according to PEP8
1 parent e8d93cb commit adf3189

2 files changed

Lines changed: 109 additions & 67 deletions

File tree

python/caffe/draw.py

Lines changed: 102 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""
22
Caffe network visualization: draw the NetParameter protobuffer.
33
4-
NOTE: this requires pydot>=1.0.2, which is not included in requirements.txt
5-
since it requires graphviz and other prerequisites outside the scope of the
6-
Caffe.
4+
5+
.. note::
6+
7+
This requires pydot>=1.0.2, which is not included in requirements.txt since
8+
it requires graphviz and other prerequisites outside the scope of the
9+
Caffe.
710
"""
811

912
from caffe.proto import caffe_pb2
10-
from google.protobuf import text_format
1113
import pydot
1214

1315
# Internal layer and blob styles.
@@ -32,24 +34,35 @@ def get_pooling_types_dict():
3234
return d
3335

3436

35-
def determine_edge_label_by_layertype(layer, layertype):
36-
"""Define edge label based on layer type
37+
def get_edge_label(layer):
38+
"""Define edge label based on layer type.
3739
"""
3840

39-
if layertype == 'Data':
41+
if layer.type == 'Data':
4042
edge_label = 'Batch ' + str(layer.data_param.batch_size)
41-
elif layertype == 'Convolution':
43+
elif layer.type == 'Convolution':
4244
edge_label = str(layer.convolution_param.num_output)
43-
elif layertype == 'InnerProduct':
45+
elif layer.type == 'InnerProduct':
4446
edge_label = str(layer.inner_product_param.num_output)
4547
else:
4648
edge_label = '""'
4749

4850
return edge_label
4951

5052

51-
def determine_node_label_by_layertype(layer, layertype, rankdir):
52-
"""Define node label based on layer type
53+
def get_layer_label(layer, rankdir):
54+
"""Define node label based on layer type.
55+
56+
Parameters
57+
----------
58+
layer : ?
59+
rankdir : {'LR', 'TB', 'BT'}
60+
Direction of graph layout.
61+
62+
Returns
63+
-------
64+
string :
65+
A label for the current layer
5366
"""
5467

5568
if rankdir in ('TB', 'BT'):
@@ -61,39 +74,39 @@ def determine_node_label_by_layertype(layer, layertype, rankdir):
6174
# horizontal space is not; separate words with newlines
6275
separator = '\n'
6376

64-
if layertype == 'Convolution':
77+
if layer.type == 'Convolution':
6578
# Outer double quotes needed or else colon characters don't parse
6679
# properly
6780
node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
6881
(layer.name,
6982
separator,
70-
layertype,
83+
layer.type,
7184
separator,
7285
layer.convolution_param.kernel_size,
7386
separator,
7487
layer.convolution_param.stride,
7588
separator,
7689
layer.convolution_param.pad)
77-
elif layertype == 'Pooling':
90+
elif layer.type == 'Pooling':
7891
pooling_types_dict = get_pooling_types_dict()
7992
node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
8093
(layer.name,
8194
separator,
8295
pooling_types_dict[layer.pooling_param.pool],
83-
layertype,
96+
layer.type,
8497
separator,
8598
layer.pooling_param.kernel_size,
8699
separator,
87100
layer.pooling_param.stride,
88101
separator,
89102
layer.pooling_param.pad)
90103
else:
91-
node_label = '"%s%s(%s)"' % (layer.name, separator, layertype)
104+
node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
92105
return node_label
93106

94107

95108
def choose_color_by_layertype(layertype):
96-
"""Define colors for nodes based on the layer type
109+
"""Define colors for nodes based on the layer type.
97110
"""
98111
color = '#6495ED' # Default
99112
if layertype == 'Convolution':
@@ -106,48 +119,62 @@ def choose_color_by_layertype(layertype):
106119

107120

108121
def get_pydot_graph(caffe_net, rankdir, label_edges=True):
109-
pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph', rankdir=rankdir)
110-
pydot_nodes = {}
111-
pydot_edges = []
112-
for layer in caffe_net.layer:
113-
name = layer.name
114-
layertype = layer.type
115-
node_label = determine_node_label_by_layertype(layer, layertype, rankdir)
116-
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
117-
layer.bottom[0] == layer.top[0]):
118-
# We have an in-place neuron layer.
119-
pydot_nodes[name + '_' + layertype] = pydot.Node(
120-
node_label, **NEURON_LAYER_STYLE)
121-
else:
122-
layer_style = LAYER_STYLE_DEFAULT
123-
layer_style['fillcolor'] = choose_color_by_layertype(layertype)
124-
pydot_nodes[name + '_' + layertype] = pydot.Node(
125-
node_label, **layer_style)
126-
for bottom_blob in layer.bottom:
127-
pydot_nodes[bottom_blob + '_blob'] = pydot.Node(
128-
'%s' % (bottom_blob), **BLOB_STYLE)
129-
edge_label = '""'
130-
pydot_edges.append({'src': bottom_blob + '_blob',
131-
'dst': name + '_' + layertype,
132-
'label': edge_label})
133-
for top_blob in layer.top:
134-
pydot_nodes[top_blob + '_blob'] = pydot.Node(
135-
'%s' % (top_blob))
136-
if label_edges:
137-
edge_label = determine_edge_label_by_layertype(layer, layertype)
138-
else:
139-
edge_label = '""'
140-
pydot_edges.append({'src': name + '_' + layertype,
141-
'dst': top_blob + '_blob',
142-
'label': edge_label})
143-
# Now, add the nodes and edges to the graph.
144-
for node in pydot_nodes.values():
145-
pydot_graph.add_node(node)
146-
for edge in pydot_edges:
147-
pydot_graph.add_edge(
148-
pydot.Edge(pydot_nodes[edge['src']], pydot_nodes[edge['dst']],
149-
label=edge['label']))
150-
return pydot_graph
122+
"""Create a data structure which represents the `caffe_net`.
123+
124+
Parameters
125+
----------
126+
caffe_net : object
127+
rankdir : {'LR', 'TB', 'BT'}
128+
Direction of graph layout.
129+
label_edges : boolean, optional
130+
Label the edges (default is True).
131+
132+
Returns
133+
-------
134+
pydot graph object
135+
"""
136+
pydot_graph = pydot.Dot(caffe_net.name,
137+
graph_type='digraph',
138+
rankdir=rankdir)
139+
pydot_nodes = {}
140+
pydot_edges = []
141+
for layer in caffe_net.layer:
142+
node_label = get_layer_label(layer, rankdir)
143+
node_name = "%s_%s" % (layer.name, layer.type)
144+
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
145+
layer.bottom[0] == layer.top[0]):
146+
# We have an in-place neuron layer.
147+
pydot_nodes[node_name] = pydot.Node(node_label,
148+
**NEURON_LAYER_STYLE)
149+
else:
150+
layer_style = LAYER_STYLE_DEFAULT
151+
layer_style['fillcolor'] = choose_color_by_layertype(layer.type)
152+
pydot_nodes[node_name] = pydot.Node(node_label, **layer_style)
153+
for bottom_blob in layer.bottom:
154+
pydot_nodes[bottom_blob + '_blob'] = pydot.Node('%s' % bottom_blob,
155+
**BLOB_STYLE)
156+
edge_label = '""'
157+
pydot_edges.append({'src': bottom_blob + '_blob',
158+
'dst': node_name,
159+
'label': edge_label})
160+
for top_blob in layer.top:
161+
pydot_nodes[top_blob + '_blob'] = pydot.Node('%s' % (top_blob))
162+
if label_edges:
163+
edge_label = get_edge_label(layer)
164+
else:
165+
edge_label = '""'
166+
pydot_edges.append({'src': node_name,
167+
'dst': top_blob + '_blob',
168+
'label': edge_label})
169+
# Now, add the nodes and edges to the graph.
170+
for node in pydot_nodes.values():
171+
pydot_graph.add_node(node)
172+
for edge in pydot_edges:
173+
pydot_graph.add_edge(
174+
pydot.Edge(pydot_nodes[edge['src']],
175+
pydot_nodes[edge['dst']],
176+
label=edge['label']))
177+
return pydot_graph
151178

152179

153180
def draw_net(caffe_net, rankdir, ext='png'):
@@ -156,8 +183,14 @@ def draw_net(caffe_net, rankdir, ext='png'):
156183
157184
Parameters
158185
----------
159-
caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
160-
ext: the image extension. Default 'png'.
186+
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
187+
ext : string, optional
188+
The image extension (the default is 'png').
189+
190+
Returns
191+
-------
192+
string :
193+
Postscript representation of the graph.
161194
"""
162195
return get_pydot_graph(caffe_net, rankdir).create(format=ext)
163196

@@ -166,6 +199,14 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
166199
"""Draws a caffe net, and saves it to file using the format given as the
167200
file extension. Use '.raw' to output raw text that you can manually feed
168201
to graphviz to draw graphs.
202+
203+
Parameters
204+
----------
205+
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
206+
filename : string
207+
The path to a file where the networks visualization will be stored.
208+
rankdir : {'LR', 'TB', 'BT'}
209+
Direction of graph layout.
169210
"""
170211
ext = filename[filename.rfind('.')+1:]
171212
with open(filename, 'wb') as fid:

python/draw_net.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Draw a graph of the net architecture.
44
"""
5-
import argparse
5+
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
66
from google.protobuf import text_format
77

88
import caffe
@@ -14,18 +14,19 @@ def parse_args():
1414
"""Parse input arguments
1515
"""
1616

17-
parser = argparse.ArgumentParser(description='Draw a network graph')
17+
parser = ArgumentParser(description=__doc__,
18+
formatter_class=ArgumentDefaultsHelpFormatter)
1819

1920
parser.add_argument('input_net_proto_file',
2021
help='Input network prototxt file')
2122
parser.add_argument('output_image_file',
2223
help='Output image file')
2324
parser.add_argument('--rankdir',
2425
help=('One of TB (top-bottom, i.e., vertical), '
25-
'RL (right-left, i.e., horizontal), or another'
26-
'valid dot option; see'
27-
'http://www.graphviz.org/doc/info/attrs.html#k:rankdir'
28-
'(default: LR)'),
26+
'RL (right-left, i.e., horizontal), or another '
27+
'valid dot option; see '
28+
'http://www.graphviz.org/doc/info/'
29+
'attrs.html#k:rankdir'),
2930
default='LR')
3031

3132
args = parser.parse_args()

0 commit comments

Comments
 (0)