Skip to content

Commit 19adc7a

Browse files
committed
Add level and stages to pycaffe
Uses Boost.Python's pattern matching to differentiate between constructors Also adds Python tests for all-in-one nets
1 parent 66e84d7 commit 19adc7a

2 files changed

Lines changed: 263 additions & 9 deletions

File tree

python/caffe/_caffe.cpp

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,42 @@ void CheckContiguousArray(PyArrayObject* arr, string name,
8686
}
8787
}
8888

89-
// Net constructor for passing phase as int
90-
shared_ptr<Net<Dtype> > Net_Init(
91-
string param_file, int phase) {
92-
CheckFile(param_file);
89+
// Net constructor
90+
shared_ptr<Net<Dtype> > Net_Init(string network_file, int phase,
91+
const int level, const bp::object& stages,
92+
const bp::object& weights) {
93+
CheckFile(network_file);
94+
95+
// Convert stages from list to vector
96+
vector<string> stages_vector;
97+
if (!stages.is_none()) {
98+
for (int i = 0; i < len(stages); i++) {
99+
stages_vector.push_back(bp::extract<string>(stages[i]));
100+
}
101+
}
102+
103+
// Initialize net
104+
shared_ptr<Net<Dtype> > net(new Net<Dtype>(network_file,
105+
static_cast<Phase>(phase), level, &stages_vector));
106+
107+
// Load weights
108+
if (!weights.is_none()) {
109+
std::string weights_file_str = bp::extract<std::string>(weights);
110+
CheckFile(weights_file_str);
111+
net->CopyTrainedLayersFrom(weights_file_str);
112+
}
93113

94-
shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
95-
static_cast<Phase>(phase)));
96114
return net;
97115
}
98116

99-
// Net construct-and-load convenience constructor
117+
// Legacy Net construct-and-load convenience constructor
100118
shared_ptr<Net<Dtype> > Net_Init_Load(
101119
string param_file, string pretrained_param_file, int phase) {
120+
LOG(WARNING) << "DEPRECATION WARNING - deprecated use of Python interface";
121+
LOG(WARNING) << "Use this instead (with the named \"weights\""
122+
<< " parameter):";
123+
LOG(WARNING) << "Net('" << param_file << "', " << phase
124+
<< ", weights='" << pretrained_param_file << "')";
102125
CheckFile(param_file);
103126
CheckFile(pretrained_param_file);
104127

@@ -245,7 +268,12 @@ BOOST_PYTHON_MODULE(_caffe) {
245268

246269
bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
247270
bp::no_init)
248-
.def("__init__", bp::make_constructor(&Net_Init))
271+
// Constructor
272+
.def("__init__", bp::make_constructor(&Net_Init,
273+
bp::default_call_policies(), (bp::arg("network_file"), "phase",
274+
bp::arg("level")=0, bp::arg("stages")=bp::object(),
275+
bp::arg("weights")=bp::object())))
276+
// Legacy constructor
249277
.def("__init__", bp::make_constructor(&Net_Init_Load))
250278
.def("_forward", &Net<Dtype>::ForwardFromTo)
251279
.def("_backward", &Net<Dtype>::BackwardFromTo)

python/caffe/test/test_net.py

Lines changed: 227 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@ def test_save_and_read(self):
7272
f.close()
7373
self.net.save(f.name)
7474
net_file = simple_net_file(self.num_output)
75-
net2 = caffe.Net(net_file, f.name, caffe.TRAIN)
75+
# Test legacy constructor
76+
# should print deprecation warning
77+
caffe.Net(net_file, f.name, caffe.TRAIN)
78+
# Test named constructor
79+
net2 = caffe.Net(net_file, caffe.TRAIN, weights=f.name)
7680
os.remove(net_file)
7781
os.remove(f.name)
7882
for name in self.net.params:
@@ -93,3 +97,225 @@ def test_save_hdf5(self):
9397
for i in range(len(self.net.params[name])):
9498
self.assertEqual(abs(self.net.params[name][i].data
9599
- net2.params[name][i].data).sum(), 0)
100+
101+
class TestLevels(unittest.TestCase):
102+
103+
TEST_NET = """
104+
layer {
105+
name: "data"
106+
type: "DummyData"
107+
top: "data"
108+
dummy_data_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } }
109+
}
110+
layer {
111+
name: "NoLevel"
112+
type: "InnerProduct"
113+
bottom: "data"
114+
top: "NoLevel"
115+
inner_product_param { num_output: 1 }
116+
}
117+
layer {
118+
name: "Level0Only"
119+
type: "InnerProduct"
120+
bottom: "data"
121+
top: "Level0Only"
122+
include { min_level: 0 max_level: 0 }
123+
inner_product_param { num_output: 1 }
124+
}
125+
layer {
126+
name: "Level1Only"
127+
type: "InnerProduct"
128+
bottom: "data"
129+
top: "Level1Only"
130+
include { min_level: 1 max_level: 1 }
131+
inner_product_param { num_output: 1 }
132+
}
133+
layer {
134+
name: "Level>=0"
135+
type: "InnerProduct"
136+
bottom: "data"
137+
top: "Level>=0"
138+
include { min_level: 0 }
139+
inner_product_param { num_output: 1 }
140+
}
141+
layer {
142+
name: "Level>=1"
143+
type: "InnerProduct"
144+
bottom: "data"
145+
top: "Level>=1"
146+
include { min_level: 1 }
147+
inner_product_param { num_output: 1 }
148+
}
149+
"""
150+
151+
def setUp(self):
152+
self.f = tempfile.NamedTemporaryFile(mode='w+')
153+
self.f.write(self.TEST_NET)
154+
self.f.flush()
155+
156+
def tearDown(self):
157+
self.f.close()
158+
159+
def check_net(self, net, blobs):
160+
net_blobs = [b for b in net.blobs.keys() if 'data' not in b]
161+
self.assertEqual(net_blobs, blobs)
162+
163+
def test_0(self):
164+
net = caffe.Net(self.f.name, caffe.TEST)
165+
self.check_net(net, ['NoLevel', 'Level0Only', 'Level>=0'])
166+
167+
def test_1(self):
168+
net = caffe.Net(self.f.name, caffe.TEST, level=1)
169+
self.check_net(net, ['NoLevel', 'Level1Only', 'Level>=0', 'Level>=1'])
170+
171+
172+
class TestStages(unittest.TestCase):
173+
174+
TEST_NET = """
175+
layer {
176+
name: "data"
177+
type: "DummyData"
178+
top: "data"
179+
dummy_data_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } }
180+
}
181+
layer {
182+
name: "A"
183+
type: "InnerProduct"
184+
bottom: "data"
185+
top: "A"
186+
include { stage: "A" }
187+
inner_product_param { num_output: 1 }
188+
}
189+
layer {
190+
name: "B"
191+
type: "InnerProduct"
192+
bottom: "data"
193+
top: "B"
194+
include { stage: "B" }
195+
inner_product_param { num_output: 1 }
196+
}
197+
layer {
198+
name: "AorB"
199+
type: "InnerProduct"
200+
bottom: "data"
201+
top: "AorB"
202+
include { stage: "A" }
203+
include { stage: "B" }
204+
inner_product_param { num_output: 1 }
205+
}
206+
layer {
207+
name: "AandB"
208+
type: "InnerProduct"
209+
bottom: "data"
210+
top: "AandB"
211+
include { stage: "A" stage: "B" }
212+
inner_product_param { num_output: 1 }
213+
}
214+
"""
215+
216+
def setUp(self):
217+
self.f = tempfile.NamedTemporaryFile(mode='w+')
218+
self.f.write(self.TEST_NET)
219+
self.f.flush()
220+
221+
def tearDown(self):
222+
self.f.close()
223+
224+
def check_net(self, net, blobs):
225+
net_blobs = [b for b in net.blobs.keys() if 'data' not in b]
226+
self.assertEqual(net_blobs, blobs)
227+
228+
def test_A(self):
229+
net = caffe.Net(self.f.name, caffe.TEST, stages=['A'])
230+
self.check_net(net, ['A', 'AorB'])
231+
232+
def test_B(self):
233+
net = caffe.Net(self.f.name, caffe.TEST, stages=['B'])
234+
self.check_net(net, ['B', 'AorB'])
235+
236+
def test_AandB(self):
237+
net = caffe.Net(self.f.name, caffe.TEST, stages=['A', 'B'])
238+
self.check_net(net, ['A', 'B', 'AorB', 'AandB'])
239+
240+
241+
class TestAllInOne(unittest.TestCase):
242+
243+
TEST_NET = """
244+
layer {
245+
name: "train_data"
246+
type: "DummyData"
247+
top: "data"
248+
top: "label"
249+
dummy_data_param {
250+
shape { dim: 1 dim: 1 dim: 10 dim: 10 }
251+
shape { dim: 1 dim: 1 dim: 1 dim: 1 }
252+
}
253+
include { phase: TRAIN stage: "train" }
254+
}
255+
layer {
256+
name: "val_data"
257+
type: "DummyData"
258+
top: "data"
259+
top: "label"
260+
dummy_data_param {
261+
shape { dim: 1 dim: 1 dim: 10 dim: 10 }
262+
shape { dim: 1 dim: 1 dim: 1 dim: 1 }
263+
}
264+
include { phase: TEST stage: "val" }
265+
}
266+
layer {
267+
name: "deploy_data"
268+
type: "Input"
269+
top: "data"
270+
input_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } }
271+
include { phase: TEST stage: "deploy" }
272+
}
273+
layer {
274+
name: "ip"
275+
type: "InnerProduct"
276+
bottom: "data"
277+
top: "ip"
278+
inner_product_param { num_output: 2 }
279+
}
280+
layer {
281+
name: "loss"
282+
type: "SoftmaxWithLoss"
283+
bottom: "ip"
284+
bottom: "label"
285+
top: "loss"
286+
include: { phase: TRAIN stage: "train" }
287+
include: { phase: TEST stage: "val" }
288+
}
289+
layer {
290+
name: "pred"
291+
type: "Softmax"
292+
bottom: "ip"
293+
top: "pred"
294+
include: { phase: TEST stage: "deploy" }
295+
}
296+
"""
297+
298+
def setUp(self):
299+
self.f = tempfile.NamedTemporaryFile(mode='w+')
300+
self.f.write(self.TEST_NET)
301+
self.f.flush()
302+
303+
def tearDown(self):
304+
self.f.close()
305+
306+
def check_net(self, net, outputs):
307+
self.assertEqual(list(net.blobs['data'].shape), [1,1,10,10])
308+
self.assertEqual(net.outputs, outputs)
309+
310+
def test_train(self):
311+
net = caffe.Net(self.f.name, caffe.TRAIN, stages=['train'])
312+
self.check_net(net, ['loss'])
313+
314+
def test_val(self):
315+
net = caffe.Net(self.f.name, caffe.TEST, stages=['val'])
316+
self.check_net(net, ['loss'])
317+
318+
def test_deploy(self):
319+
net = caffe.Net(self.f.name, caffe.TEST, stages=['deploy'])
320+
self.check_net(net, ['pred'])
321+

0 commit comments

Comments
 (0)