|
42 | 42 |
|
43 | 43 | #include "stir/find_STIR_config.h" |
44 | 44 | #include "stir/Succeeded.h" |
| 45 | + #include "stir/NumericType.h" |
| 46 | + #include "stir/NumericInfo.h" |
45 | 47 | #include "stir/DetectionPosition.h" |
46 | 48 | #include "stir/Scanner.h" |
47 | 49 | #include "stir/Bin.h" |
|
180 | 182 | // helper code below. It is used to convert a Python object to a float. |
181 | 183 | SWIGINTERN int |
182 | 184 | SWIG_AsVal_double (PyObject * obj, double *val); |
| 185 | + // TODO THIS NEEDS TO BE THE SAME as numpy.i |
| 186 | + // We need it here because we need to include arrayobject in the preamble for swigstir functions |
| 187 | + #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION |
| 188 | + #include <numpy/arrayobject.h> |
| 189 | + #include <numpy/ndarraytypes.h> |
183 | 190 | #endif |
184 | 191 |
|
185 | 192 | // local helper functions for conversions etc. These are not "exposed" to the target language |
|
233 | 240 | return p; |
234 | 241 | } |
235 | 242 |
|
236 | | - // fill an array from a Python sequence |
237 | | - // (could be trivially modified to just write to a C++ iterator) |
238 | | - template <int num_dimensions, typename elemT> |
239 | | - void fill_Array_from_Python_iterator(stir::Array<num_dimensions, elemT> * array_ptr, PyObject* const arg) |
| 243 | + template <typename elemT> |
| 244 | + int get_np_typenum() |
| 245 | + { |
| 246 | + const stir::NumericType type_id = stir::NumericInfo<elemT>().type_id(); |
| 247 | + switch (type_id.id) |
| 248 | + { |
| 249 | + case stir::NumericType::SCHAR: return NPY_BYTE; |
| 250 | + case stir::NumericType::UCHAR: return NPY_UBYTE; |
| 251 | + case stir::NumericType::SHORT: return NPY_SHORT; |
| 252 | + case stir::NumericType::USHORT: return NPY_USHORT; |
| 253 | + case stir::NumericType::LONG: return NPY_LONG; |
| 254 | + case stir::NumericType::ULONG: return NPY_ULONG; |
| 255 | + case stir::NumericType::FLOAT: return NPY_FLOAT; |
| 256 | + case stir::NumericType::DOUBLE: return NPY_DOUBLE; |
| 257 | + default: throw std::runtime_error("Unknown dtype of STIR array"); |
| 258 | + } |
| 259 | + } |
| 260 | + |
| 261 | + // fill an iterator from a Python sequence |
| 262 | + template <typename elemT, typename IterT> |
| 263 | + void fill_iterator_from_Python_iterator(IterT cpp_iter, IterT cpp_iter_end, PyObject* const arg) |
240 | 264 | { |
241 | 265 | if (!PyIter_Check(arg)) |
242 | 266 | throw std::runtime_error("STIR-Python internal error: fill_Array_from_Python_iterators called but input is not an iterator"); |
|
245 | 269 | PyObject *iterator = PyObject_GetIter(arg); |
246 | 270 |
|
247 | 271 | PyObject *item; |
248 | | - typename stir::Array<num_dimensions, elemT>::full_iterator array_iter = array_ptr->begin_all(); |
249 | | - while ((item = PyIter_Next(iterator)) && array_iter != array_ptr->end_all()) |
| 272 | + while ((item = PyIter_Next(iterator)) && (cpp_iter != cpp_iter_end)) |
250 | 273 | { |
251 | 274 | double val; |
252 | 275 | // TODO currently hard-wired as double which might imply extra conversions |
253 | 276 | int ecode = SWIG_AsVal_double(item, &val); |
254 | 277 | if (SWIG_IsOK(ecode)) |
255 | 278 | { |
256 | | - *array_iter++ = static_cast<elemT>(val); |
| 279 | + *cpp_iter++ = static_cast<elemT>(val); |
257 | 280 | } |
258 | 281 | else |
259 | 282 | { |
|
266 | 289 | } |
267 | 290 | Py_DECREF(item); |
268 | 291 | } |
269 | | - |
270 | | - if (PyIter_Next(iterator) != NULL || array_iter != array_ptr->end_all()) |
| 292 | + if (PyIter_Next(iterator) != NULL || cpp_iter != cpp_iter_end) |
271 | 293 | { |
272 | 294 | throw std::runtime_error("fill() called with incorrect range of iterators, array needs to have the same number of elements"); |
273 | 295 | } |
|
281 | 303 |
|
282 | 304 | } |
283 | 305 |
|
284 | | -#if 0 |
285 | | - |
286 | | - // TODO does not work yet. |
287 | | - // it doesn't compile as includes are in init section, which follows after this in the wrapper |
288 | | - // Even if it did compile, it might not work anyway as I haven't tested it. |
289 | | - template <typename IterT> |
290 | | - void fill_nparray_from_iterator(PyObject * np, IterT iterator) |
| 306 | + // fill a STIR array from a Python sequence |
| 307 | + template <int num_dimensions, typename elemT> |
| 308 | + void fill_Array_from_Python_iterator(stir::Array<num_dimensions, elemT> * array_ptr, PyObject* const arg) |
| 309 | + { |
| 310 | + fill_iterator_from_Python_iterator<elemT>(array_ptr->begin_all(), array_ptr->end_all(), arg); |
| 311 | + } |
| 312 | + |
| 313 | + template <typename elemT, typename IterT> |
| 314 | + void fill_nparray_from_iterator(PyArrayObject * np, IterT cpp_iter) |
| 315 | + { |
| 316 | + if (!PyArray_EquivTypenums(PyArray_TYPE(np), get_np_typenum<elemT>())) |
| 317 | + { |
| 318 | + throw std::runtime_error("stir_object.fill needs to be called with numpy array of correct dtype"); |
| 319 | + } |
| 320 | + |
| 321 | +#if 1 |
| 322 | + auto iter = NpyIter_New(np, NPY_ITER_READONLY, NPY_KEEPORDER, NPY_NO_CASTING, NULL); |
| 323 | + if (iter==NULL) { |
| 324 | + return; |
| 325 | + } |
| 326 | + auto iternext = NpyIter_GetIterNext(iter, NULL); |
| 327 | + auto dataptr = (elemT **) NpyIter_GetDataPtrArray(iter); |
| 328 | + do { |
| 329 | + **dataptr = *cpp_iter++; |
| 330 | + } |
| 331 | + while (iternext(iter)); |
| 332 | +#else |
| 333 | + // generic alternative, but doesn't compile and might be slower |
| 334 | + auto iterator = PyObject_GetIter(np_array); |
| 335 | + PyObject *item; |
| 336 | + while (item = PyIter_Next(iterator)) |
| 337 | + { |
| 338 | + *item = *cpp_iter++; // this does not compile. not sure how to assign |
| 339 | + Py_DECREF(item); |
| 340 | + } |
| 341 | +#endif |
| 342 | + } |
| 343 | + |
| 344 | + template <typename elemT, typename IterT> |
| 345 | + void fill_iterator_from_nparray(IterT iterator, PyArrayObject * np) |
291 | 346 | { |
| 347 | + if (!PyArray_EquivTypenums(PyArray_TYPE(np), get_np_typenum<elemT>())) |
| 348 | + { |
| 349 | + throw std::runtime_error("stir_object.fill needs to be called with numpy array of correct dtype"); |
| 350 | + } |
| 351 | + |
292 | 352 | // This code is more or less a copy of the "simple iterator example" (!) in the Numpy doc |
293 | | - // see e.g. http://students.mimuw.edu.pl/~pbechler/numpy_doc/reference/c-api.iterator.html |
294 | | - typedef float elemT; |
295 | | - NpyIter* iter; |
296 | | - NpyIter_IterNextFunc *iternext; |
297 | | - char** dataptr; |
298 | | - npy_intp* strideptr,* innersizeptr; |
| 353 | + // see e.g. https://numpy.org/doc/stable/reference/c-api/iterator.html |
299 | 354 |
|
300 | 355 | /* Handle zero-sized arrays specially */ |
301 | 356 | if (PyArray_SIZE(np) == 0) { |
|
317 | 372 | * casting NPY_NO_CASTING |
318 | 373 | * - No casting is required for this operation. |
319 | 374 | */ |
320 | | - iter = NpyIter_New(np, NPY_ITER_WRITEONLY| |
| 375 | +#if 0 |
| 376 | + // code for simpler loop, but it is likely slower |
| 377 | + auto iter = NpyIter_New(np, NPY_ITER_READONLY, |
| 378 | + NPY_KEEPORDER, NPY_NO_CASTING, |
| 379 | + NULL); |
| 380 | +#else |
| 381 | + auto iter = NpyIter_New(np, NPY_ITER_READONLY| |
321 | 382 | NPY_ITER_EXTERNAL_LOOP, |
322 | 383 | NPY_KEEPORDER, NPY_NO_CASTING, |
323 | 384 | NULL); |
| 385 | +#endif |
324 | 386 | if (iter == NULL) { |
325 | 387 | throw std::runtime_error("Error creating numpy iterator"); |
326 | 388 | } |
|
329 | 391 | * The iternext function gets stored in a local variable |
330 | 392 | * so it can be called repeatedly in an efficient manner. |
331 | 393 | */ |
332 | | - iternext = NpyIter_GetIterNext(iter, NULL); |
| 394 | + auto iternext = NpyIter_GetIterNext(iter, NULL); |
333 | 395 | if (iternext == NULL) { |
334 | 396 | NpyIter_Deallocate(iter); |
335 | 397 | throw std::runtime_error("Error creating numpy iterator function"); |
336 | 398 | } |
| 399 | +#if 0 |
| 400 | + // code for simpler loop, but it is likely slower |
| 401 | + auto dataptr = (elemT **) NpyIter_GetDataPtrArray(iter); |
| 402 | + do { |
| 403 | + *iterator++ = **dataptr; |
| 404 | + /* Increment the iterator to the next inner loop */ |
| 405 | + } while(iternext(iter)); |
| 406 | +#else |
337 | 407 | /* The location of the data pointer which the iterator may update */ |
338 | | - dataptr = NpyIter_GetDataPtrArray(iter); |
| 408 | + auto dataptr = NpyIter_GetDataPtrArray(iter); |
339 | 409 | /* The location of the stride which the iterator may update */ |
340 | | - strideptr = NpyIter_GetInnerStrideArray(iter); |
| 410 | + auto strideptr = NpyIter_GetInnerStrideArray(iter); |
341 | 411 | /* The location of the inner loop size which the iterator may update */ |
342 | | - innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); |
| 412 | + auto innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); |
343 | 413 |
|
344 | 414 | /* The iteration loop */ |
345 | 415 | do { |
346 | 416 | /* Get the inner loop data/stride/count values */ |
347 | | - char* data = *dataptr; |
| 417 | + auto data = *dataptr; |
348 | 418 | npy_intp stride = *strideptr; |
349 | 419 | npy_intp count = *innersizeptr; |
350 | 420 |
|
351 | 421 | /* This is a typical inner loop for NPY_ITER_EXTERNAL_LOOP */ |
352 | 422 | while (count--) { |
353 | | - *(reinterpret_cast<elemT *>(data)) = static_cast<elemT>(*iterator++); |
354 | | - data += stride; |
| 423 | + *iterator++ = *(reinterpret_cast<elemT *>(data)); |
| 424 | + data += stride; |
355 | 425 | } |
356 | | - |
357 | 426 | /* Increment the iterator to the next inner loop */ |
358 | 427 | } while(iternext(iter)); |
| 428 | +#endif |
359 | 429 |
|
360 | 430 | NpyIter_Deallocate(iter); |
361 | 431 | } |
362 | | -#endif |
363 | 432 |
|
364 | 433 | #elif defined(SWIGMATLAB) |
365 | 434 | // convert stir::Array to matlab (currently always converting to double) |
@@ -713,13 +782,32 @@ namespace std { |
713 | 782 | return new SwigPyForwardIteratorClosed_T<OutIter>(current, begin, end, seq); |
714 | 783 | } |
715 | 784 |
|
716 | | - |
717 | 785 | #endif |
718 | | - static Array<4,float> create_array_for_proj_data(const ProjData& proj_data) |
| 786 | + |
| 787 | + // helper function that allocates a stir::Array of appropriate size |
| 788 | + static stir::BasicCoordinate<4, int> array_for_proj_data_size(const ProjData& proj_data) |
719 | 789 | { |
720 | 790 | const int num_non_tof_sinos = proj_data.get_num_non_tof_sinograms(); |
721 | | - Array<4,float> array(IndexRange4D(proj_data.get_num_tof_poss(),num_non_tof_sinos, proj_data.get_num_views(), proj_data.get_num_tangential_poss())); |
722 | | - return array; |
| 791 | + return stir::make_coordinate(proj_data.get_num_tof_poss(),num_non_tof_sinos, proj_data.get_num_views(), proj_data.get_num_tangential_poss()); |
| 792 | + } |
| 793 | + |
| 794 | + static Array<4,float> create_array_for_proj_data(const ProjData& proj_data) |
| 795 | + { |
| 796 | + Array<4,float> array(IndexRange4D(array_for_proj_data_size(proj_data))); |
| 797 | + return array; |
| 798 | + } |
| 799 | + |
| 800 | + // helper function that allocates a numpy.ndarray of appropriate size |
| 801 | + static PyArrayObject* create_nparray_for_proj_data(const ProjData& proj_data) |
| 802 | + { |
| 803 | + const auto stir_sizes = swigstir::array_for_proj_data_size(proj_data); |
| 804 | + constexpr int num_dimensions = 4; |
| 805 | + npy_intp dims[num_dimensions]; |
| 806 | + for (int d=0; d<num_dimensions; ++d) |
| 807 | + dims[d] = stir_sizes[d + 1]; |
| 808 | + auto np_array = |
| 809 | + (PyArrayObject *)PyArray_SimpleNew(num_dimensions, dims, NPY_FLOAT); |
| 810 | + return np_array; |
723 | 811 | } |
724 | 812 |
|
725 | 813 | // a function for converting ProjData to a 4D array as that's what is easy to use |
|
0 commit comments