From 01a7e74c9a1de16ac652390f3ad08e70fe8cd69b Mon Sep 17 00:00:00 2001 From: Johan Mabille Date: Mon, 6 Nov 2017 16:32:18 +0100 Subject: [PATCH] overload fix --- include/xtensor-python/pyarray.hpp | 4 ++- include/xtensor-python/pycontainer.hpp | 39 +++++++++++++++++++- include/xtensor-python/pytensor.hpp | 2 +- test_python/main.cpp | 50 ++++++++++++++++++++++++++ test_python/test_pyarray.py | 5 +++ 5 files changed, 97 insertions(+), 3 deletions(-) diff --git a/include/xtensor-python/pyarray.hpp b/include/xtensor-python/pyarray.hpp index 1e33dfe..21a4d7b 100644 --- a/include/xtensor-python/pyarray.hpp +++ b/include/xtensor-python/pyarray.hpp @@ -21,6 +21,8 @@ #include "pystrides_adaptor.hpp" #include "xtensor_type_caster_base.hpp" +#include + namespace xt { template @@ -54,7 +56,7 @@ namespace pybind11 return false; } int type_num = xt::detail::numpy_traits::type_num; - if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + if(xt::detail::pyarray_type(reinterpret_cast(src.ptr())) != type_num) { return false; } diff --git a/include/xtensor-python/pycontainer.hpp b/include/xtensor-python/pycontainer.hpp index 962cc04..4d151c6 100644 --- a/include/xtensor-python/pycontainer.hpp +++ b/include/xtensor-python/pycontainer.hpp @@ -125,10 +125,16 @@ namespace xt { private: + // On Windows 64 bits, NPY_INT != NPY_INT32 and NPY_UINT != NPY_UINT32 + // We use the NPY_INT32 and NPY_UINT32 which are consistent with the values + // of NPY_LONG and NPY_ULONG + // On Linux x64, NPY_INT64 != NPY_LONGLONG and NPY_UINT64 != NPY_ULONGLONG, + // we use the values of NPY_INT64 and NPY_UINT64 which are consistent with the + // values of NPY_LONG and NPY_ULONG. constexpr static const int value_list[15] = { NPY_BOOL, NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT, - NPY_INT, NPY_UINT, NPY_LONGLONG, NPY_ULONGLONG, + NPY_INT32, NPY_UINT32, NPY_INT64, NPY_UINT64, NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE}; @@ -138,6 +144,37 @@ namespace xt static constexpr int type_num = value_list[pybind11::detail::is_fmt_numeric::index]; }; + + // On Linux x64, NPY_INT64 != NPY_LONGLONG and NPY_UINT64 != NPY_ULONGLONG + // NPY_LONGLONG and NPY_ULONGLONG must be adjusted so the right type is + // selected + template + struct numpy_enum_adjuster + { + static inline int pyarray_type(PyArrayObject* obj) + { + return PyArray_TYPE(obj); + } + }; + + template <> + struct numpy_enum_adjuster + { + static inline int pyarray_type(PyArrayObject* obj) + { + int res = PyArray_TYPE(obj); + if(res == NPY_LONGLONG || res == NPY_ULONGLONG) + { + res -= 2; + } + return res; + } + }; + + inline int pyarray_type(PyArrayObject* obj) + { + return numpy_enum_adjuster::pyarray_type(obj); + } } /****************************** diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index a0107de..bf3a442 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -55,7 +55,7 @@ namespace pybind11 return false; } int type_num = xt::detail::numpy_traits::type_num; - if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + if(xt::detail::pyarray_type(reinterpret_cast(src.ptr())) != type_num) { return false; } diff --git a/test_python/main.cpp b/test_python/main.cpp index fbfc170..58b7d69 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -68,6 +68,45 @@ int add(int i, int j) return i + j; } +template std::string typestring() { return "Unknown"; } +template <> std::string typestring() { return "uint8"; } +template <> std::string typestring() { return "int8"; } +template <> std::string typestring() { return "uint16"; } +template <> std::string typestring() { return "int16"; } +template <> std::string typestring() { return "uint32"; } +template <> std::string typestring() { return "int32"; } +template <> std::string typestring() { return "uint64"; } +template <> std::string typestring() { return "int64"; } + +template +inline std::string int_overload(xt::pyarray& m) +{ + return typestring(); +} + +void dump_numpy_constant() +{ + std::cout << "NPY_BOOL = " << NPY_BOOL << std::endl; + std::cout << "NPY_BYTE = " << NPY_BYTE << std::endl; + std::cout << "NPY_UBYTE = " << NPY_UBYTE << std::endl; + std::cout << "NPY_INT8 = " << NPY_INT8 << std::endl; + std::cout << "NPY_UINT8 = " << NPY_UINT8 << std::endl; + std::cout << "NPY_SHORT = " << NPY_SHORT << std::endl; + std::cout << "NPY_USHORT = " << NPY_USHORT << std::endl; + std::cout << "NPY_INT16 = " << NPY_INT16 << std::endl; + std::cout << "NPY_UINT16 = " << NPY_UINT16 << std::endl; + std::cout << "NPY_INT = " << NPY_INT << std::endl; + std::cout << "NPY_UINT = " << NPY_UINT << std::endl; + std::cout << "NPY_INT32 = " << NPY_INT32 << std::endl; + std::cout << "NPY_UINT32 = " << NPY_UINT32 << std::endl; + std::cout << "NPY_LONG = " << NPY_LONG << std::endl; + std::cout << "NPY_ULONG = " << NPY_ULONG << std::endl; + std::cout << "NPY_LONGLONG = " << NPY_LONGLONG << std::endl; + std::cout << "NPY_ULONGLONG = " << NPY_ULONGLONG << std::endl; + std::cout << "NPY_INT64 = " << NPY_INT64 << std::endl; + std::cout << "NPY_UINT64 = " << NPY_UINT64 << std::endl; +} + PYBIND11_PLUGIN(xtensor_python_test) { xt::import_numpy(); @@ -93,5 +132,16 @@ PYBIND11_PLUGIN(xtensor_python_test) return a.shape() == b.shape(); }); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + m.def("int_overload", int_overload); + + m.def("dump_numpy_constant", dump_numpy_constant); + return m.ptr(); } diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index 431f56e..7aa67a8 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -82,3 +82,8 @@ def test_shape_comparison(self): self.assertFalse(xt.compare_shapes(x, y)) self.assertTrue(xt.compare_shapes(x, z)) + def test_int_overload(self): + for dtype in [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64]: + b = xt.int_overload(np.ones((10), dtype)) + self.assertEqual(str(dtype.__name__), b) +