|
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include "../include/feather.h"

namespace py = pybind11;

| 8 | +PYBIND11_MODULE(feather_py, m) { |
| 9 | + m.doc() = "Feather: SQLite for Vectors"; |
| 10 | + |
| 11 | + py::class_<feather::DB, std::unique_ptr<feather::DB, py::nodelete>>(m, "DB") |
| 12 | + .def_static("open", &feather::DB::open, py::arg("path"), py::arg("dim") = 768) |
| 13 | + .def("add", [](feather::DB& db, uint64_t id, py::array_t<float> vec) { |
| 14 | + auto buf = vec.request(); |
| 15 | + if (buf.size != db.dim()) throw std::runtime_error("Dimension mismatch"); |
| 16 | + const float* ptr = static_cast<const float*>(buf.ptr); |
| 17 | + std::vector<float> vec_copy(ptr, ptr + buf.size); |
| 18 | + db.add(id, vec_copy); |
| 19 | + }) |
| 20 | + .def("search", [](const feather::DB& db, py::array_t<float> q, size_t k = 5) { |
| 21 | + auto buf = q.request(); |
| 22 | + if (buf.size != db.dim()) throw std::runtime_error("Query dimension mismatch"); |
| 23 | + const float* ptr = static_cast<const float*>(buf.ptr); |
| 24 | + std::vector<float> query(ptr, ptr + buf.size); |
| 25 | + auto results = db.search(query, k); |
| 26 | + |
| 27 | + py::array_t<uint64_t> ids(results.size()); |
| 28 | + py::array_t<float> distances(results.size()); |
| 29 | + auto ids_ptr = ids.mutable_data(); |
| 30 | + auto dist_ptr = distances.mutable_data(); |
| 31 | + |
| 32 | + for (size_t i = 0; i < results.size(); ++i) { |
| 33 | + auto [id, dist] = results[i]; |
| 34 | + ids_ptr[i] = id; |
| 35 | + dist_ptr[i] = dist; |
| 36 | + } |
| 37 | + return py::make_tuple(ids, distances); |
| 38 | + }, py::arg("q"), py::arg("k") = 5) |
| 39 | + .def("save", &feather::DB::save) |
| 40 | + .def("dim", &feather::DB::dim); |
| 41 | +} |
0 commit comments