11"""
22Caffe network visualization: draw the NetParameter protobuffer.
33
4- NOTE: this requires pydot>=1.0.2, which is not included in requirements.txt
5- since it requires graphviz and other prerequisites outside the scope of the
6- Caffe.
4+
5+ .. note::
6+
7+ This requires pydot>=1.0.2, which is not included in requirements.txt since
8+ it requires graphviz and other prerequisites outside the scope of the
9+ Caffe.
710"""
811
912from caffe .proto import caffe_pb2
10- from google .protobuf import text_format
1113import pydot
1214
1315# Internal layer and blob styles.
@@ -32,24 +34,35 @@ def get_pooling_types_dict():
3234 return d
3335
3436
35- def determine_edge_label_by_layertype (layer , layertype ):
36- """Define edge label based on layer type
37+ def get_edge_label (layer ):
38+ """Define edge label based on layer type.
3739 """
3840
39- if layertype == 'Data' :
41+ if layer . type == 'Data' :
4042 edge_label = 'Batch ' + str (layer .data_param .batch_size )
41- elif layertype == 'Convolution' :
43+ elif layer . type == 'Convolution' :
4244 edge_label = str (layer .convolution_param .num_output )
43- elif layertype == 'InnerProduct' :
45+ elif layer . type == 'InnerProduct' :
4446 edge_label = str (layer .inner_product_param .num_output )
4547 else :
4648 edge_label = '""'
4749
4850 return edge_label
4951
5052
51- def determine_node_label_by_layertype (layer , layertype , rankdir ):
52- """Define node label based on layer type
53+ def get_layer_label (layer , rankdir ):
54+ """Define node label based on layer type.
55+
56+ Parameters
57+ ----------
58+ layer : ?
59+ rankdir : {'LR', 'TB', 'BT'}
60+ Direction of graph layout.
61+
62+ Returns
63+ -------
64+ string :
65+ A label for the current layer
5366 """
5467
5568 if rankdir in ('TB' , 'BT' ):
@@ -61,39 +74,39 @@ def determine_node_label_by_layertype(layer, layertype, rankdir):
6174 # horizontal space is not; separate words with newlines
6275 separator = '\n '
6376
64- if layertype == 'Convolution' :
77+ if layer . type == 'Convolution' :
6578 # Outer double quotes needed or else colon characters don't parse
6679 # properly
6780 node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' % \
6881 (layer .name ,
6982 separator ,
70- layertype ,
83+ layer . type ,
7184 separator ,
7285 layer .convolution_param .kernel_size ,
7386 separator ,
7487 layer .convolution_param .stride ,
7588 separator ,
7689 layer .convolution_param .pad )
77- elif layertype == 'Pooling' :
90+ elif layer . type == 'Pooling' :
7891 pooling_types_dict = get_pooling_types_dict ()
7992 node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' % \
8093 (layer .name ,
8194 separator ,
8295 pooling_types_dict [layer .pooling_param .pool ],
83- layertype ,
96+ layer . type ,
8497 separator ,
8598 layer .pooling_param .kernel_size ,
8699 separator ,
87100 layer .pooling_param .stride ,
88101 separator ,
89102 layer .pooling_param .pad )
90103 else :
91- node_label = '"%s%s(%s)"' % (layer .name , separator , layertype )
104+ node_label = '"%s%s(%s)"' % (layer .name , separator , layer . type )
92105 return node_label
93106
94107
95108def choose_color_by_layertype (layertype ):
96- """Define colors for nodes based on the layer type
109+ """Define colors for nodes based on the layer type.
97110 """
98111 color = '#6495ED' # Default
99112 if layertype == 'Convolution' :
@@ -106,48 +119,62 @@ def choose_color_by_layertype(layertype):
106119
107120
108121def get_pydot_graph (caffe_net , rankdir , label_edges = True ):
109- pydot_graph = pydot .Dot (caffe_net .name , graph_type = 'digraph' , rankdir = rankdir )
110- pydot_nodes = {}
111- pydot_edges = []
112- for layer in caffe_net .layer :
113- name = layer .name
114- layertype = layer .type
115- node_label = determine_node_label_by_layertype (layer , layertype , rankdir )
116- if (len (layer .bottom ) == 1 and len (layer .top ) == 1 and
117- layer .bottom [0 ] == layer .top [0 ]):
118- # We have an in-place neuron layer.
119- pydot_nodes [name + '_' + layertype ] = pydot .Node (
120- node_label , ** NEURON_LAYER_STYLE )
121- else :
122- layer_style = LAYER_STYLE_DEFAULT
123- layer_style ['fillcolor' ] = choose_color_by_layertype (layertype )
124- pydot_nodes [name + '_' + layertype ] = pydot .Node (
125- node_label , ** layer_style )
126- for bottom_blob in layer .bottom :
127- pydot_nodes [bottom_blob + '_blob' ] = pydot .Node (
128- '%s' % (bottom_blob ), ** BLOB_STYLE )
129- edge_label = '""'
130- pydot_edges .append ({'src' : bottom_blob + '_blob' ,
131- 'dst' : name + '_' + layertype ,
132- 'label' : edge_label })
133- for top_blob in layer .top :
134- pydot_nodes [top_blob + '_blob' ] = pydot .Node (
135- '%s' % (top_blob ))
136- if label_edges :
137- edge_label = determine_edge_label_by_layertype (layer , layertype )
138- else :
139- edge_label = '""'
140- pydot_edges .append ({'src' : name + '_' + layertype ,
141- 'dst' : top_blob + '_blob' ,
142- 'label' : edge_label })
143- # Now, add the nodes and edges to the graph.
144- for node in pydot_nodes .values ():
145- pydot_graph .add_node (node )
146- for edge in pydot_edges :
147- pydot_graph .add_edge (
148- pydot .Edge (pydot_nodes [edge ['src' ]], pydot_nodes [edge ['dst' ]],
149- label = edge ['label' ]))
150- return pydot_graph
122+ """Create a data structure which represents the `caffe_net`.
123+
124+ Parameters
125+ ----------
126+ caffe_net : object
127+ rankdir : {'LR', 'TB', 'BT'}
128+ Direction of graph layout.
129+ label_edges : boolean, optional
130+ Label the edges (default is True).
131+
132+ Returns
133+ -------
134+ pydot graph object
135+ """
136+ pydot_graph = pydot .Dot (caffe_net .name ,
137+ graph_type = 'digraph' ,
138+ rankdir = rankdir )
139+ pydot_nodes = {}
140+ pydot_edges = []
141+ for layer in caffe_net .layer :
142+ node_label = get_layer_label (layer , rankdir )
143+ node_name = "%s_%s" % (layer .name , layer .type )
144+ if (len (layer .bottom ) == 1 and len (layer .top ) == 1 and
145+ layer .bottom [0 ] == layer .top [0 ]):
146+ # We have an in-place neuron layer.
147+ pydot_nodes [node_name ] = pydot .Node (node_label ,
148+ ** NEURON_LAYER_STYLE )
149+ else :
150+ layer_style = LAYER_STYLE_DEFAULT
151+ layer_style ['fillcolor' ] = choose_color_by_layertype (layer .type )
152+ pydot_nodes [node_name ] = pydot .Node (node_label , ** layer_style )
153+ for bottom_blob in layer .bottom :
154+ pydot_nodes [bottom_blob + '_blob' ] = pydot .Node ('%s' % bottom_blob ,
155+ ** BLOB_STYLE )
156+ edge_label = '""'
157+ pydot_edges .append ({'src' : bottom_blob + '_blob' ,
158+ 'dst' : node_name ,
159+ 'label' : edge_label })
160+ for top_blob in layer .top :
161+ pydot_nodes [top_blob + '_blob' ] = pydot .Node ('%s' % (top_blob ))
162+ if label_edges :
163+ edge_label = get_edge_label (layer )
164+ else :
165+ edge_label = '""'
166+ pydot_edges .append ({'src' : node_name ,
167+ 'dst' : top_blob + '_blob' ,
168+ 'label' : edge_label })
169+ # Now, add the nodes and edges to the graph.
170+ for node in pydot_nodes .values ():
171+ pydot_graph .add_node (node )
172+ for edge in pydot_edges :
173+ pydot_graph .add_edge (
174+ pydot .Edge (pydot_nodes [edge ['src' ]],
175+ pydot_nodes [edge ['dst' ]],
176+ label = edge ['label' ]))
177+ return pydot_graph
151178
152179
153180def draw_net (caffe_net , rankdir , ext = 'png' ):
@@ -156,8 +183,14 @@ def draw_net(caffe_net, rankdir, ext='png'):
156183
157184 Parameters
158185 ----------
159- caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
160- ext: the image extension. Default 'png'.
186+ caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
187+ ext : string, optional
188+ The image extension (the default is 'png').
189+
190+ Returns
191+ -------
192+ string :
193+ Postscript representation of the graph.
161194 """
162195 return get_pydot_graph (caffe_net , rankdir ).create (format = ext )
163196
@@ -166,6 +199,14 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
166199 """Draws a caffe net, and saves it to file using the format given as the
167200 file extension. Use '.raw' to output raw text that you can manually feed
168201 to graphviz to draw graphs.
202+
203+ Parameters
204+ ----------
205+ caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
206+ filename : string
207+ The path to a file where the networks visualization will be stored.
208+ rankdir : {'LR', 'TB', 'BT'}
209+ Direction of graph layout.
169210 """
170211 ext = filename [filename .rfind ('.' )+ 1 :]
171212 with open (filename , 'wb' ) as fid :
0 commit comments