Skip to content

Commit c2a1ecd

Browse files
authored
Merge pull request #4343 from nitnelave/python/top_names
improve top_names and bottom_names in pycaffe
2 parents 2f49fd2 + 7c50a2c commit c2a1ecd

2 files changed

Lines changed: 38 additions & 15 deletions

File tree

python/caffe/pycaffe.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -292,21 +292,31 @@ def _Net_batch(self, blobs):
292292
padding])
293293
yield padded_batch
294294

295-
296-
class _Net_IdNameWrapper:
297-
"""
298-
A simple wrapper that allows the ids propery to be accessed as a dict
299-
indexed by names. Used for top and bottom names
295+
def _Net_get_id_name(func, field):
300296
"""
301-
def __init__(self, net, func):
302-
self.net, self.func = net, func
297+
Generic property that maps func to the layer names into an OrderedDict.
298+
299+
Used for top_names and bottom_names.
303300
304-
def __getitem__(self, name):
305-
# Map the layer name to id
306-
ids = self.func(self.net, list(self.net._layer_names).index(name))
307-
# Map the blob id to name
308-
id_to_name = list(self.net.blobs)
309-
return [id_to_name[i] for i in ids]
301+
Parameters
302+
----------
303+
func: function id -> [id]
304+
field: implementation field name (cache)
305+
306+
Returns
307+
------
308+
A one-parameter function that can be set as a property.
309+
"""
310+
@property
311+
def get_id_name(self):
312+
if not hasattr(self, field):
313+
id_to_name = list(self.blobs)
314+
res = OrderedDict([(self._layer_names[i],
315+
[id_to_name[j] for j in func(self, i)])
316+
for i in range(len(self.layers))])
317+
setattr(self, field, res)
318+
return getattr(self, field)
319+
return get_id_name
310320

311321
# Attach methods to Net.
312322
Net.blobs = _Net_blobs
@@ -320,5 +330,5 @@ def __getitem__(self, name):
320330
Net._batch = _Net_batch
321331
Net.inputs = _Net_inputs
322332
Net.outputs = _Net_outputs
323-
Net.top_names = property(lambda n: _Net_IdNameWrapper(n, Net._top_ids))
324-
Net.bottom_names = property(lambda n: _Net_IdNameWrapper(n, Net._bottom_ids))
333+
Net.top_names = _Net_get_id_name(Net._top_ids, "_top_names")
334+
Net.bottom_names = _Net_get_id_name(Net._bottom_ids, "_bottom_names")

python/caffe/test/test_net.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import numpy as np
55
import six
6+
from collections import OrderedDict
67

78
import caffe
89

@@ -78,6 +79,18 @@ def test_inputs_outputs(self):
7879
self.assertEqual(self.net.inputs, [])
7980
self.assertEqual(self.net.outputs, ['loss'])
8081

82+
def test_top_bottom_names(self):
83+
self.assertEqual(self.net.top_names,
84+
OrderedDict([('data', ['data', 'label']),
85+
('conv', ['conv']),
86+
('ip', ['ip']),
87+
('loss', ['loss'])]))
88+
self.assertEqual(self.net.bottom_names,
89+
OrderedDict([('data', []),
90+
('conv', ['data']),
91+
('ip', ['conv']),
92+
('loss', ['ip', 'label'])]))
93+
8194
def test_save_and_read(self):
8295
f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
8396
f.close()

0 commit comments

Comments
 (0)