Skip to content

Commit f5fd18b

Browse files
committed
Merge pull request #3082 from gustavla/pycaffe-snapshot
Expose `Solver::Snapshot` to pycaffe
2 parents ca4e342 + 19d9927 commit f5fd18b

3 files changed

Lines changed: 17 additions & 7 deletions

File tree

include/caffe/solver.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ class Solver {
6161
// RestoreSolverStateFrom___ protected methods. You should implement these
6262
// methods to restore the state from the appropriate snapshot type.
6363
void Restore(const char* resume_file);
64+
// The Solver::Snapshot function implements the basic snapshotting utility
65+
// that stores the learned net. You should implement the SnapshotSolverState()
66+
// function that produces a SolverState protocol buffer that needs to be
67+
// written to disk together with the learned net.
68+
void Snapshot();
6469
virtual ~Solver() {}
6570
inline const SolverParameter& param() const { return param_; }
6671
inline shared_ptr<Net<Dtype> > net() { return net_; }
@@ -92,11 +97,6 @@ class Solver {
9297
protected:
9398
// Make and apply the update value for the current iteration.
9499
virtual void ApplyUpdate() = 0;
95-
// The Solver::Snapshot function implements the basic snapshotting utility
96-
// that stores the learned net. You should implement the SnapshotSolverState()
97-
// function that produces a SolverState protocol buffer that needs to be
98-
// written to disk together with the learned net.
99-
void Snapshot();
100100
string SnapshotFilename(const string extension);
101101
string SnapshotToBinaryProto();
102102
string SnapshotToHDF5();

python/caffe/_caffe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ BOOST_PYTHON_MODULE(_caffe) {
287287
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
288288
&Solver<Dtype>::Solve), SolveOverloads())
289289
.def("step", &Solver<Dtype>::Step)
290-
.def("restore", &Solver<Dtype>::Restore);
290+
.def("restore", &Solver<Dtype>::Restore)
291+
.def("snapshot", &Solver<Dtype>::Snapshot);
291292

292293
bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
293294
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(

python/caffe/test/test_solver.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def setUp(self):
1616
f.write("""net: '""" + net_f + """'
1717
test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
1818
weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
19-
display: 100 max_iter: 100 snapshot_after_train: false""")
19+
display: 100 max_iter: 100 snapshot_after_train: false
20+
snapshot_prefix: "model" """)
2021
f.close()
2122
self.solver = caffe.SGDSolver(f.name)
2223
# also make sure get_solver runs
@@ -51,3 +52,11 @@ def test_net_memory(self):
5152
total += p.data.sum() + p.diff.sum()
5253
for bl in six.itervalues(net.blobs):
5354
total += bl.data.sum() + bl.diff.sum()
55+
56+
def test_snapshot(self):
57+
self.solver.snapshot()
58+
# Check that these files exist and then remove them
59+
files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
60+
for fn in files:
61+
assert os.path.isfile(fn)
62+
os.remove(fn)

0 commit comments

Comments
 (0)