Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/xtensor-python/pyarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "pystrides_adaptor.hpp"
#include "xtensor_type_caster_base.hpp"

#include <iostream>

namespace xt
{
template <class T>
Expand Down Expand Up @@ -54,7 +56,7 @@ namespace pybind11
return false;
}
int type_num = xt::detail::numpy_traits<T>::type_num;
if (PyArray_TYPE(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
if(xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
{
return false;
}
Expand Down
39 changes: 38 additions & 1 deletion include/xtensor-python/pycontainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,16 @@ namespace xt
{
private:

// On Windows 64 bits, NPY_INT != NPY_INT32 and NPY_UINT != NPY_UINT32
// We use the NPY_INT32 and NPY_UINT32 which are consistent with the values
// of NPY_LONG and NPY_ULONG
// On Linux x64, NPY_INT64 != NPY_LONGLONG and NPY_UINT64 != NPY_ULONGLONG,
// we use the values of NPY_INT64 and NPY_UINT64 which are consistent with the
// values of NPY_LONG and NPY_ULONG.
constexpr static const int value_list[15] = {
NPY_BOOL,
NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT,
NPY_INT, NPY_UINT, NPY_LONGLONG, NPY_ULONGLONG,
NPY_INT32, NPY_UINT32, NPY_INT64, NPY_UINT64,
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE};

Expand All @@ -138,6 +144,37 @@ namespace xt

static constexpr int type_num = value_list[pybind11::detail::is_fmt_numeric<value_type>::index];
};

// On Linux x64, NPY_INT64 != NPY_LONGLONG and NPY_UINT64 != NPY_ULONGLONG
// NPY_LONGLONG and NPY_ULONGLONG must be adjusted so the right type is
// selected
template <bool>
struct numpy_enum_adjuster
{
static inline int pyarray_type(PyArrayObject* obj)
{
return PyArray_TYPE(obj);
}
};

template <>
struct numpy_enum_adjuster<true>
{
static inline int pyarray_type(PyArrayObject* obj)
{
int res = PyArray_TYPE(obj);
if(res == NPY_LONGLONG || res == NPY_ULONGLONG)
{
res -= 2;
}
return res;
}
};

inline int pyarray_type(PyArrayObject* obj)
{
return numpy_enum_adjuster<NPY_LONGLONG != NPY_INT64>::pyarray_type(obj);
}
}

/******************************
Expand Down
2 changes: 1 addition & 1 deletion include/xtensor-python/pytensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace pybind11
return false;
}
int type_num = xt::detail::numpy_traits<T>::type_num;
if (PyArray_TYPE(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
if(xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
{
return false;
}
Expand Down
50 changes: 50 additions & 0 deletions test_python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,45 @@ int add(int i, int j)
return i + j;
}

template <class T> std::string typestring() { return "Unknown"; }
template <> std::string typestring<uint8_t>() { return "uint8"; }
template <> std::string typestring<int8_t>() { return "int8"; }
template <> std::string typestring<uint16_t>() { return "uint16"; }
template <> std::string typestring<int16_t>() { return "int16"; }
template <> std::string typestring<uint32_t>() { return "uint32"; }
template <> std::string typestring<int32_t>() { return "int32"; }
template <> std::string typestring<uint64_t>() { return "uint64"; }
template <> std::string typestring<int64_t>() { return "int64"; }

template <class T>
inline std::string int_overload(xt::pyarray<T>& m)
{
return typestring<T>();
}

void dump_numpy_constant()
{
std::cout << "NPY_BOOL = " << NPY_BOOL << std::endl;
std::cout << "NPY_BYTE = " << NPY_BYTE << std::endl;
std::cout << "NPY_UBYTE = " << NPY_UBYTE << std::endl;
std::cout << "NPY_INT8 = " << NPY_INT8 << std::endl;
std::cout << "NPY_UINT8 = " << NPY_UINT8 << std::endl;
std::cout << "NPY_SHORT = " << NPY_SHORT << std::endl;
std::cout << "NPY_USHORT = " << NPY_USHORT << std::endl;
std::cout << "NPY_INT16 = " << NPY_INT16 << std::endl;
std::cout << "NPY_UINT16 = " << NPY_UINT16 << std::endl;
std::cout << "NPY_INT = " << NPY_INT << std::endl;
std::cout << "NPY_UINT = " << NPY_UINT << std::endl;
std::cout << "NPY_INT32 = " << NPY_INT32 << std::endl;
std::cout << "NPY_UINT32 = " << NPY_UINT32 << std::endl;
std::cout << "NPY_LONG = " << NPY_LONG << std::endl;
std::cout << "NPY_ULONG = " << NPY_ULONG << std::endl;
std::cout << "NPY_LONGLONG = " << NPY_LONGLONG << std::endl;
std::cout << "NPY_ULONGLONG = " << NPY_ULONGLONG << std::endl;
std::cout << "NPY_INT64 = " << NPY_INT64 << std::endl;
std::cout << "NPY_UINT64 = " << NPY_UINT64 << std::endl;
}

PYBIND11_PLUGIN(xtensor_python_test)
{
xt::import_numpy();
Expand All @@ -93,5 +132,16 @@ PYBIND11_PLUGIN(xtensor_python_test)
return a.shape() == b.shape();
});

m.def("int_overload", int_overload<uint8_t>);
m.def("int_overload", int_overload<int8_t>);
m.def("int_overload", int_overload<uint16_t>);
m.def("int_overload", int_overload<int16_t>);
m.def("int_overload", int_overload<uint32_t>);
m.def("int_overload", int_overload<int32_t>);
m.def("int_overload", int_overload<uint64_t>);
m.def("int_overload", int_overload<int64_t>);

m.def("dump_numpy_constant", dump_numpy_constant);

return m.ptr();
}
5 changes: 5 additions & 0 deletions test_python/test_pyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,8 @@ def test_shape_comparison(self):
self.assertFalse(xt.compare_shapes(x, y))
self.assertTrue(xt.compare_shapes(x, z))

def test_int_overload(self):
for dtype in [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64]:
b = xt.int_overload(np.ones((10), dtype))
self.assertEqual(str(dtype.__name__), b)