@@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype):
127127 return color
128128
129129
130- def get_pydot_graph (caffe_net , rankdir , label_edges = True ):
130+ def get_pydot_graph (caffe_net , rankdir , label_edges = True , phase = None ):
131131 """Create a data structure which represents the `caffe_net`.
132132
133133 Parameters
@@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
137137 Direction of graph layout.
138138 label_edges : boolean, optional
139139 Label the edges (default is True).
140+ phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
141+ Include layers from this network phase. If None, include all layers.
142+ (the default is None)
140143
141144 Returns
142145 -------
@@ -148,6 +151,19 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
148151 pydot_nodes = {}
149152 pydot_edges = []
150153 for layer in caffe_net .layer :
154+ if phase is not None :
155+ included = False
156+ if len (layer .include ) == 0 :
157+ included = True
158+ if len (layer .include ) > 0 and len (layer .exclude ) > 0 :
159+ raise ValueError ('layer ' + layer .name + ' has both include '
160+ 'and exclude specified.' )
161+ for layer_phase in layer .include :
162+ included = included or layer_phase .phase == phase
163+ for layer_phase in layer .exclude :
164+ included = included and not layer_phase .phase == phase
165+ if not included :
166+ continue
151167 node_label = get_layer_label (layer , rankdir )
152168 node_name = "%s_%s" % (layer .name , layer .type )
153169 if (len (layer .bottom ) == 1 and len (layer .top ) == 1 and
@@ -186,7 +202,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
186202 return pydot_graph
187203
188204
189- def draw_net (caffe_net , rankdir , ext = 'png' ):
205+ def draw_net (caffe_net , rankdir , ext = 'png' , phase = None ):
190206 """Draws a caffe net and returns the image string encoded using the given
191207 extension.
192208
@@ -195,16 +211,19 @@ def draw_net(caffe_net, rankdir, ext='png'):
195211 caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
196212 ext : string, optional
197213 The image extension (the default is 'png').
214+ phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
215+ Include layers from this network phase. If None, include all layers.
216+ (the default is None)
198217
199218 Returns
200219 -------
201220 string :
202221 Postscript representation of the graph.
203222 """
204- return get_pydot_graph (caffe_net , rankdir ).create (format = ext )
223+ return get_pydot_graph (caffe_net , rankdir , phase = phase ).create (format = ext )
205224
206225
207- def draw_net_to_file (caffe_net , filename , rankdir = 'LR' ):
226+ def draw_net_to_file (caffe_net , filename , rankdir = 'LR' , phase = None ):
208227 """Draws a caffe net, and saves it to file using the format given as the
209228 file extension. Use '.raw' to output raw text that you can manually feed
210229 to graphviz to draw graphs.
@@ -216,7 +235,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
216235 The path to a file where the networks visualization will be stored.
217236 rankdir : {'LR', 'TB', 'BT'}
218237 Direction of graph layout.
238+ phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
239+ Include layers from this network phase. If None, include all layers.
240+ (the default is None)
219241 """
220242 ext = filename [filename .rfind ('.' )+ 1 :]
221243 with open (filename , 'wb' ) as fid :
222- fid .write (draw_net (caffe_net , rankdir , ext ))
244+ fid .write (draw_net (caffe_net , rankdir , ext , phase ))
0 commit comments