Skip to content

Commit 4d0103b

Browse files
committed
Merge pull request #2037 from shelhamer/expose-solver-restore
Expose Solver::Restore() as public for restoring without solving
2 parents 65d84a5 + c0219cc commit 4d0103b

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

include/caffe/solver.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ class Solver {
2727
virtual void Solve(const char* resume_file = NULL);
2828
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
2929
void Step(int iters);
30+
// The Restore function implements how one should restore the solver to a
31+
// previously snapshotted state. You should implement the RestoreSolverState()
32+
// function that restores the state from a SolverState protocol buffer.
33+
void Restore(const char* resume_file);
3034
virtual ~Solver() {}
3135
inline shared_ptr<Net<Dtype> > net() { return net_; }
3236
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
@@ -46,10 +50,6 @@ class Solver {
4650
void TestAll();
4751
void Test(const int test_net_id = 0);
4852
virtual void SnapshotSolverState(SolverState* state) = 0;
49-
// The Restore function implements how one should restore the solver to a
50-
// previously snapshotted state. You should implement the RestoreSolverState()
51-
// function that restores the state from a SolverState protocol buffer.
52-
void Restore(const char* resume_file);
5353
virtual void RestoreSolverState(const SolverState& state) = 0;
5454
void DisplayOutputBlobs(const int net_id);
5555

python/caffe/_caffe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ BOOST_PYTHON_MODULE(_caffe) {
261261
.add_property("iter", &Solver<Dtype>::iter)
262262
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
263263
&Solver<Dtype>::Solve), SolveOverloads())
264-
.def("step", &Solver<Dtype>::Step);
264+
.def("step", &Solver<Dtype>::Step)
265+
.def("restore", &Solver<Dtype>::Restore);
265266

266267
bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
267268
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(

0 commit comments

Comments
 (0)