@@ -228,6 +228,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
228228 return bp::object ();
229229}
230230
231+ template <typename Dtype>
232+ class PythonCallback : public Solver <Dtype>::Callback {
233+ protected:
234+ bp::object on_start_, on_gradients_ready_;
235+
236+ public:
237+ PythonCallback (bp::object on_start, bp::object on_gradients_ready)
238+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
239+ virtual void on_gradients_ready () {
240+ on_gradients_ready_ ();
241+ }
242+ virtual void on_start () {
243+ on_start_ ();
244+ }
245+ };
246+ template <typename Dtype>
247+ void Solver_add_callback (Solver<Dtype> * solver, bp::object on_start,
248+ bp::object on_gradients_ready) {
249+ solver->add_callback (new PythonCallback<Dtype>(on_start, on_gradients_ready));
250+ }
251+
231252BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS (SolveOverloads, Solve, 0 , 1 );
232253
233254BOOST_PYTHON_MODULE (_caffe) {
@@ -317,6 +338,7 @@ BOOST_PYTHON_MODULE(_caffe) {
317338 .add_property (" test_nets" , bp::make_function (&Solver<Dtype>::test_nets,
318339 bp::return_internal_reference<>()))
319340 .add_property (" iter" , &Solver<Dtype>::iter)
341+ .def (" add_callback" , &Solver_add_callback<Dtype>)
320342 .def (" solve" , static_cast <void (Solver<Dtype>::*)(const char *)>(
321343 &Solver<Dtype>::Solve), SolveOverloads ())
322344 .def (" step" , &Solver<Dtype>::Step)
0 commit comments