@@ -139,9 +139,9 @@ public:
139139 QOCOSettings *get_settings();
140140 PyQOCOSolution &get_solution();
141141
142- QOCOInt update_settings(const QOCOSettings &);
143- // QOCOInt update_vector_data(py::object, py::object, py::object);
144- // QOCOInt update_matrix_data(py::object, py::object, py::object);
142+ QOCOInt update_settings(const QOCOSettings &new_settings );
143+ void update_vector_data(py::object cnew , py::object bnew , py::object hnew );
144+ void update_matrix_data(py::object Pxnew , py::object Axnew , py::object Gxnew );
145145
146146 QOCOInt solve();
147147
@@ -241,6 +241,78 @@ QOCOInt PyQOCOSolver::update_settings(const QOCOSettings &new_settings)
241241 return qoco_update_settings(this->_solver, &new_settings);
242242}
243243
244+ void PyQOCOSolver::update_vector_data(py::object cnew, py::object bnew, py::object hnew)
245+ {
246+ QOCOFloat *cnew_ptr = nullptr;
247+ QOCOFloat *bnew_ptr = nullptr;
248+ QOCOFloat *hnew_ptr = nullptr;
249+
250+ if (cnew != py::none())
251+ {
252+ auto cnew_arr = cnew.cast<py::array_t<QOCOFloat>>();
253+ auto buf = cnew_arr.request();
254+ if (buf.shape[0] != this->n)
255+ throw std::runtime_error("cnew size must be n = " + std::to_string(this->n));
256+ cnew_ptr = (QOCOFloat *)buf.ptr;
257+ }
258+
259+ if (bnew != py::none())
260+ {
261+ auto bnew_arr = bnew.cast<py::array_t<QOCOFloat>>();
262+ auto buf = bnew_arr.request();
263+ if (buf.shape[0] != this->p)
264+ throw std::runtime_error("bnew size must be p = " + std::to_string(this->p));
265+ bnew_ptr = (QOCOFloat *)buf.ptr;
266+ }
267+
268+ if (hnew != py::none())
269+ {
270+ auto hnew_arr = hnew.cast<py::array_t<QOCOFloat>>();
271+ auto buf = hnew_arr.request();
272+ if (buf.shape[0] != this->m)
273+ throw std::runtime_error("hnew size must be m = " + std::to_string(this->m));
274+ hnew_ptr = (QOCOFloat *)buf.ptr;
275+ }
276+
277+ qoco_update_vector_data(this->_solver, cnew_ptr, bnew_ptr, hnew_ptr);
278+ }
279+
280+ void PyQOCOSolver::update_matrix_data(py::object Pxnew, py::object Axnew, py::object Gxnew)
281+ {
282+ QOCOFloat *Pxnew_ptr = nullptr;
283+ QOCOFloat *Axnew_ptr = nullptr;
284+ QOCOFloat *Gxnew_ptr = nullptr;
285+
286+ if (Pxnew != py::none())
287+ {
288+ auto Pxnew_arr = Pxnew.cast<py::array_t<QOCOFloat>>();
289+ auto buf = Pxnew_arr.request();
290+ if (buf.ndim != 1)
291+ throw std::runtime_error("Pxnew must be 1-D array");
292+ Pxnew_ptr = (QOCOFloat *)buf.ptr;
293+ }
294+
295+ if (Axnew != py::none())
296+ {
297+ auto Axnew_arr = Axnew.cast<py::array_t<QOCOFloat>>();
298+ auto buf = Axnew_arr.request();
299+ if (buf.ndim != 1)
300+ throw std::runtime_error("Axnew must be 1-D array");
301+ Axnew_ptr = (QOCOFloat *)buf.ptr;
302+ }
303+
304+ if (Gxnew != py::none())
305+ {
306+ auto Gxnew_arr = Gxnew.cast<py::array_t<QOCOFloat>>();
307+ auto buf = Gxnew_arr.request();
308+ if (buf.ndim != 1)
309+ throw std::runtime_error("Gxnew must be 1-D array");
310+ Gxnew_ptr = (QOCOFloat *)buf.ptr;
311+ }
312+
313+ qoco_update_matrix_data(this->_solver, Pxnew_ptr, Axnew_ptr, Gxnew_ptr);
314+ }
315+
244316PYBIND11_MODULE(@QOCO_EXT_MODULE_NAME@, m)
245317{
246318 // Enums.
@@ -308,6 +380,8 @@ PYBIND11_MODULE(@QOCO_EXT_MODULE_NAME@, m)
308380 .def(py::init<QOCOInt, QOCOInt, QOCOInt, const CSC &, const py::array_t<QOCOFloat>, const CSC &, const py::array_t<QOCOFloat>, const CSC &, const py::array_t<QOCOFloat>, QOCOInt, QOCOInt, const py::array_t<QOCOInt>, QOCOSettings *>(), "n"_a, "m"_a, "p"_a, "P"_a, "c"_a.noconvert(), "A"_a, "b"_a.noconvert(), "G"_a, "h"_a.noconvert(), "l"_a, "nsoc"_a, "q"_a.noconvert(), "settings"_a)
309381 .def_property_readonly("solution", &PyQOCOSolver::get_solution, py::return_value_policy::reference)
310382 .def("update_settings", &PyQOCOSolver::update_settings)
383+ .def("update_vector_data", &PyQOCOSolver::update_vector_data, "cnew"_a=py::none(), "bnew"_a=py::none(), "hnew"_a=py::none())
384+ .def("update_matrix_data", &PyQOCOSolver::update_matrix_data, "Pxnew"_a=py::none(), "Axnew"_a=py::none(), "Gxnew"_a=py::none())
311385 .def("solve", &PyQOCOSolver::solve)
312386 .def("get_settings", &PyQOCOSolver::get_settings, py::return_value_policy::reference);
313387}
0 commit comments