File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -61,6 +61,11 @@ class Solver {
6161 // RestoreSolverStateFrom___ protected methods. You should implement these
6262 // methods to restore the state from the appropriate snapshot type.
6363 void Restore (const char * resume_file);
64+ // The Solver::Snapshot function implements the basic snapshotting utility
65+ // that stores the learned net. You should implement the SnapshotSolverState()
66+ // function that produces a SolverState protocol buffer that needs to be
67+ // written to disk together with the learned net.
68+ void Snapshot ();
6469 virtual ~Solver () {}
6570 inline const SolverParameter& param () const { return param_; }
6671 inline shared_ptr<Net<Dtype> > net () { return net_; }
@@ -92,11 +97,6 @@ class Solver {
9297 protected:
9398 // Make and apply the update value for the current iteration.
9499 virtual void ApplyUpdate () = 0;
95- // The Solver::Snapshot function implements the basic snapshotting utility
96- // that stores the learned net. You should implement the SnapshotSolverState()
97- // function that produces a SolverState protocol buffer that needs to be
98- // written to disk together with the learned net.
99- void Snapshot ();
100100 string SnapshotFilename (const string extension);
101101 string SnapshotToBinaryProto ();
102102 string SnapshotToHDF5 ();
Original file line number Diff line number Diff line change @@ -287,7 +287,8 @@ BOOST_PYTHON_MODULE(_caffe) {
287287 .def (" solve" , static_cast <void (Solver<Dtype>::*)(const char *)>(
288288 &Solver<Dtype>::Solve), SolveOverloads ())
289289 .def (" step" , &Solver<Dtype>::Step)
290- .def (" restore" , &Solver<Dtype>::Restore);
290+ .def (" restore" , &Solver<Dtype>::Restore)
291+ .def (" snapshot" , &Solver<Dtype>::Snapshot);
291292
292293 bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
293294 shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
Original file line number Diff line number Diff line change @@ -16,7 +16,8 @@ def setUp(self):
1616 f .write ("""net: '""" + net_f + """'
1717 test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
1818 weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
19- display: 100 max_iter: 100 snapshot_after_train: false""" )
19+ display: 100 max_iter: 100 snapshot_after_train: false
20+ snapshot_prefix: "model" """ )
2021 f .close ()
2122 self .solver = caffe .SGDSolver (f .name )
2223 # also make sure get_solver runs
@@ -51,3 +52,11 @@ def test_net_memory(self):
5152 total += p .data .sum () + p .diff .sum ()
5253 for bl in six .itervalues (net .blobs ):
5354 total += bl .data .sum () + bl .diff .sum ()
55+
56+ def test_snapshot (self ):
57+ self .solver .snapshot ()
58+ # Check that these files exist and then remove them
59+ files = ['model_iter_0.caffemodel' , 'model_iter_0.solverstate' ]
60+ for fn in files :
61+ assert os .path .isfile (fn )
62+ os .remove (fn )
You can’t perform that action at this time.
0 commit comments