Skip to content

Commit 65fb3ab

Browse files
Merge pull request #1632 from KrisThielemans/numpy_as_array
Python: addition of as_array() and fill(numpy.ndarray)
2 parents 4210550 + 6319c16 commit 65fb3ab

File tree

6 files changed

+291
-73
lines changed

6 files changed

+291
-73
lines changed

documentation/release_6.3.htm

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,14 @@ <h4>Python</h4>
7777
<code>a = b + c</code> and <code>a -= 3</code>. Note that <code>a = 1 + b</code> is not
7878
yet available.<br>
7979
<a href=https://github.com/UCL/STIR/pull/1630>PR #1630</a>
80-
</li>
80+
</li>
81+
<li>
82+
The above "container" classes now have an extra member `as_array()` which returns a numpy <code>ndarray</code>. This
83+
is equivalent to `stirextra.to_numpy()` which will become deprecated later. In addition, the
84+
<code>fill()</code> method now directly accepts an <code>ndarray</code>, avoiding the need to go via an iterator.
85+
These additions also make it easier to prt SIRF python code to STIR.<br>
86+
<a href=https://github.com/UCL/STIR/pull/1632>PR #1632</a>
87+
</li>
8188
<li>
8289
Added a Python script to convert e7tools generated Siemens Biograph Vision 600 sinograms to STIR compatible format.<br>
8390
<a href=https://github.com/UCL/STIR/pull/1593>PR #1593</a>

src/swig/stir.i

Lines changed: 124 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
#include "stir/find_STIR_config.h"
4444
#include "stir/Succeeded.h"
45+
#include "stir/NumericType.h"
46+
#include "stir/NumericInfo.h"
4547
#include "stir/DetectionPosition.h"
4648
#include "stir/Scanner.h"
4749
#include "stir/Bin.h"
@@ -180,6 +182,11 @@
180182
// helper code below. It is used to convert a Python object to a float.
181183
SWIGINTERN int
182184
SWIG_AsVal_double (PyObject * obj, double *val);
185+
// TODO THIS NEEDS TO BE THE SAME as numpy.i
186+
// We need it here because we need to include arrayobject in the preamble for swigstir functions
187+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
188+
#include <numpy/arrayobject.h>
189+
#include <numpy/ndarraytypes.h>
183190
#endif
184191

185192
// local helper functions for conversions etc. These are not "exposed" to the target language
@@ -233,10 +240,27 @@
233240
return p;
234241
}
235242

236-
// fill an array from a Python sequence
237-
// (could be trivially modified to just write to a C++ iterator)
238-
template <int num_dimensions, typename elemT>
239-
void fill_Array_from_Python_iterator(stir::Array<num_dimensions, elemT> * array_ptr, PyObject* const arg)
243+
template <typename elemT>
244+
int get_np_typenum()
245+
{
246+
const stir::NumericType type_id = stir::NumericInfo<elemT>().type_id();
247+
switch (type_id.id)
248+
{
249+
case stir::NumericType::SCHAR: return NPY_BYTE;
250+
case stir::NumericType::UCHAR: return NPY_UBYTE;
251+
case stir::NumericType::SHORT: return NPY_SHORT;
252+
case stir::NumericType::USHORT: return NPY_USHORT;
253+
case stir::NumericType::LONG: return NPY_LONG;
254+
case stir::NumericType::ULONG: return NPY_ULONG;
255+
case stir::NumericType::FLOAT: return NPY_FLOAT;
256+
case stir::NumericType::DOUBLE: return NPY_DOUBLE;
257+
default: throw std::runtime_error("Unknown dtype of STIR array");
258+
}
259+
}
260+
261+
// fill an iterator from a Python sequence
262+
template <typename elemT, typename IterT>
263+
void fill_iterator_from_Python_iterator(IterT cpp_iter, IterT cpp_iter_end, PyObject* const arg)
240264
{
241265
if (!PyIter_Check(arg))
242266
throw std::runtime_error("STIR-Python internal error: fill_Array_from_Python_iterators called but input is not an iterator");
@@ -245,15 +269,14 @@
245269
PyObject *iterator = PyObject_GetIter(arg);
246270

247271
PyObject *item;
248-
typename stir::Array<num_dimensions, elemT>::full_iterator array_iter = array_ptr->begin_all();
249-
while ((item = PyIter_Next(iterator)) && array_iter != array_ptr->end_all())
272+
while ((item = PyIter_Next(iterator)) && (cpp_iter != cpp_iter_end))
250273
{
251274
double val;
252275
// TODO currently hard-wired as double which might imply extra conversions
253276
int ecode = SWIG_AsVal_double(item, &val);
254277
if (SWIG_IsOK(ecode))
255278
{
256-
*array_iter++ = static_cast<elemT>(val);
279+
*cpp_iter++ = static_cast<elemT>(val);
257280
}
258281
else
259282
{
@@ -266,8 +289,7 @@
266289
}
267290
Py_DECREF(item);
268291
}
269-
270-
if (PyIter_Next(iterator) != NULL || array_iter != array_ptr->end_all())
292+
if (PyIter_Next(iterator) != NULL || cpp_iter != cpp_iter_end)
271293
{
272294
throw std::runtime_error("fill() called with incorrect range of iterators, array needs to have the same number of elements");
273295
}
@@ -281,21 +303,54 @@
281303

282304
}
283305

284-
#if 0
285-
286-
// TODO does not work yet.
287-
// it doesn't compile as includes are in init section, which follows after this in the wrapper
288-
// Even if it did compile, it might not work anyway as I haven't tested it.
289-
template <typename IterT>
290-
void fill_nparray_from_iterator(PyObject * np, IterT iterator)
306+
// fill a STIR array from a Python sequence
307+
template <int num_dimensions, typename elemT>
308+
void fill_Array_from_Python_iterator(stir::Array<num_dimensions, elemT> * array_ptr, PyObject* const arg)
309+
{
310+
fill_iterator_from_Python_iterator<elemT>(array_ptr->begin_all(), array_ptr->end_all(), arg);
311+
}
312+
313+
template <typename elemT, typename IterT>
314+
void fill_nparray_from_iterator(PyArrayObject * np, IterT cpp_iter)
315+
{
316+
if (!PyArray_EquivTypenums(PyArray_TYPE(np), get_np_typenum<elemT>()))
317+
{
318+
throw std::runtime_error("stir_object.fill needs to be called with numpy array of correct dtype");
319+
}
320+
321+
#if 1
322+
auto iter = NpyIter_New(np, NPY_ITER_READONLY, NPY_KEEPORDER, NPY_NO_CASTING, NULL);
323+
if (iter==NULL) {
324+
return;
325+
}
326+
auto iternext = NpyIter_GetIterNext(iter, NULL);
327+
auto dataptr = (elemT **) NpyIter_GetDataPtrArray(iter);
328+
do {
329+
**dataptr = *cpp_iter++;
330+
}
331+
while (iternext(iter));
332+
#else
333+
// generic alternative, but doesn't compile and might be slower
334+
auto iterator = PyObject_GetIter(np_array);
335+
PyObject *item;
336+
while (item = PyIter_Next(iterator))
337+
{
338+
*item = *cpp_iter++; // this does not compile. not sure how to assign
339+
Py_DECREF(item);
340+
}
341+
#endif
342+
}
343+
344+
template <typename elemT, typename IterT>
345+
void fill_iterator_from_nparray(IterT iterator, PyArrayObject * np)
291346
{
347+
if (!PyArray_EquivTypenums(PyArray_TYPE(np), get_np_typenum<elemT>()))
348+
{
349+
throw std::runtime_error("stir_object.fill needs to be called with numpy array of correct dtype");
350+
}
351+
292352
// This code is more or less a copy of the "simple iterator example" (!) in the Numpy doc
293-
// see e.g. http://students.mimuw.edu.pl/~pbechler/numpy_doc/reference/c-api.iterator.html
294-
typedef float elemT;
295-
NpyIter* iter;
296-
NpyIter_IterNextFunc *iternext;
297-
char** dataptr;
298-
npy_intp* strideptr,* innersizeptr;
353+
// see e.g. https://numpy.org/doc/stable/reference/c-api/iterator.html
299354

300355
/* Handle zero-sized arrays specially */
301356
if (PyArray_SIZE(np) == 0) {
@@ -317,10 +372,17 @@
317372
* casting NPY_NO_CASTING
318373
* - No casting is required for this operation.
319374
*/
320-
iter = NpyIter_New(np, NPY_ITER_WRITEONLY|
375+
#if 0
376+
// code for simpler loop, but it is likely slower
377+
auto iter = NpyIter_New(np, NPY_ITER_READONLY,
378+
NPY_KEEPORDER, NPY_NO_CASTING,
379+
NULL);
380+
#else
381+
auto iter = NpyIter_New(np, NPY_ITER_READONLY|
321382
NPY_ITER_EXTERNAL_LOOP,
322383
NPY_KEEPORDER, NPY_NO_CASTING,
323384
NULL);
385+
#endif
324386
if (iter == NULL) {
325387
throw std::runtime_error("Error creating numpy iterator");
326388
}
@@ -329,37 +391,44 @@
329391
* The iternext function gets stored in a local variable
330392
* so it can be called repeatedly in an efficient manner.
331393
*/
332-
iternext = NpyIter_GetIterNext(iter, NULL);
394+
auto iternext = NpyIter_GetIterNext(iter, NULL);
333395
if (iternext == NULL) {
334396
NpyIter_Deallocate(iter);
335397
throw std::runtime_error("Error creating numpy iterator function");
336398
}
399+
#if 0
400+
// code for simpler loop, but it is likely slower
401+
auto dataptr = (elemT **) NpyIter_GetDataPtrArray(iter);
402+
do {
403+
*iterator++ = **dataptr;
404+
/* Increment the iterator to the next inner loop */
405+
} while(iternext(iter));
406+
#else
337407
/* The location of the data pointer which the iterator may update */
338-
dataptr = NpyIter_GetDataPtrArray(iter);
408+
auto dataptr = NpyIter_GetDataPtrArray(iter);
339409
/* The location of the stride which the iterator may update */
340-
strideptr = NpyIter_GetInnerStrideArray(iter);
410+
auto strideptr = NpyIter_GetInnerStrideArray(iter);
341411
/* The location of the inner loop size which the iterator may update */
342-
innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
412+
auto innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
343413

344414
/* The iteration loop */
345415
do {
346416
/* Get the inner loop data/stride/count values */
347-
char* data = *dataptr;
417+
auto data = *dataptr;
348418
npy_intp stride = *strideptr;
349419
npy_intp count = *innersizeptr;
350420

351421
/* This is a typical inner loop for NPY_ITER_EXTERNAL_LOOP */
352422
while (count--) {
353-
*(reinterpret_cast<elemT *>(data)) = static_cast<elemT>(*iterator++);
354-
data += stride;
423+
*iterator++ = *(reinterpret_cast<elemT *>(data));
424+
data += stride;
355425
}
356-
357426
/* Increment the iterator to the next inner loop */
358427
} while(iternext(iter));
428+
#endif
359429

360430
NpyIter_Deallocate(iter);
361431
}
362-
#endif
363432

364433
#elif defined(SWIGMATLAB)
365434
// convert stir::Array to matlab (currently always converting to double)
@@ -713,13 +782,32 @@ namespace std {
713782
return new SwigPyForwardIteratorClosed_T<OutIter>(current, begin, end, seq);
714783
}
715784

716-
717785
#endif
718-
static Array<4,float> create_array_for_proj_data(const ProjData& proj_data)
786+
787+
// helper function that allocates a stir::Array of appropriate size
788+
static stir::BasicCoordinate<4, int> array_for_proj_data_size(const ProjData& proj_data)
719789
{
720790
const int num_non_tof_sinos = proj_data.get_num_non_tof_sinograms();
721-
Array<4,float> array(IndexRange4D(proj_data.get_num_tof_poss(),num_non_tof_sinos, proj_data.get_num_views(), proj_data.get_num_tangential_poss()));
722-
return array;
791+
return stir::make_coordinate(proj_data.get_num_tof_poss(),num_non_tof_sinos, proj_data.get_num_views(), proj_data.get_num_tangential_poss());
792+
}
793+
794+
static Array<4,float> create_array_for_proj_data(const ProjData& proj_data)
795+
{
796+
Array<4,float> array(IndexRange4D(array_for_proj_data_size(proj_data)));
797+
return array;
798+
}
799+
800+
// helper function that allocates a numpy.ndarray of appropriate size
801+
static PyArrayObject* create_nparray_for_proj_data(const ProjData& proj_data)
802+
{
803+
const auto stir_sizes = swigstir::array_for_proj_data_size(proj_data);
804+
constexpr int num_dimensions = 4;
805+
npy_intp dims[num_dimensions];
806+
for (int d=0; d<num_dimensions; ++d)
807+
dims[d] = stir_sizes[d + 1];
808+
auto np_array =
809+
(PyArrayObject *)PyArray_SimpleNew(num_dimensions, dims, NPY_FLOAT);
810+
return np_array;
723811
}
724812

725813
// a function for converting ProjData to a 4D array as that's what is easy to use

src/swig/stir_array.i

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,21 +170,50 @@ namespace stir
170170
return swigstir::tuple_from_coord(sizes);
171171
}
172172

173-
%feature("autodoc", "fill from a Python iterator, e.g. array.fill(numpyarray.flat)") fill;
173+
%feature("autodoc", "fill from a Python scalar, numpy array or iterator, e.g. array.fill(numpyarray.flat)") fill;
174174
void fill(PyObject* const arg)
175175
{
176176
if (PyIter_Check(arg))
177177
{
178178
swigstir::fill_Array_from_Python_iterator($self, arg);
179179
}
180+
else if (PyArray_Check(arg))
181+
{
182+
auto np_arr = (PyArrayObject*)arg;
183+
if (static_cast<size_t>(PyArray_SIZE(np_arr)) != $self->size_all())
184+
{
185+
throw std::runtime_error("Array.fill needs to be called with numpy array of correct size");
186+
}
187+
swigstir::fill_iterator_from_nparray<elemT>($self->begin_all(), (PyArrayObject*)arg);
188+
}
180189
else
181190
{
182191
char str[1000];
183-
snprintf(str, 1000, "Wrong argument-type used for fill(): should be a scalar or an iterator or so, but is of type %s",
192+
snprintf(str, 1000, "Wrong argument-type used for fill(): should be a scalar, numpy array or an iterator, but is of type %s",
184193
arg->ob_type->tp_name);
185194
throw std::invalid_argument(str);
186-
}
195+
}
187196
}
197+
198+
%newobject as_array();
199+
%feature("autodoc", "Create a new numpy array with same dimensions/data. Raises an error if the array is not rectangular.") as_array;
200+
PyObject* as_array() const
201+
{
202+
int np_typenum = swigstir::get_np_typenum<elemT>();
203+
204+
stir::BasicCoordinate<num_dimensions,int> minind,maxind;
205+
if (!$self->get_regular_range(minind, maxind))
206+
throw std::range_error("as_array() called on irregular array");
207+
stir::BasicCoordinate<num_dimensions, int> stir_sizes=maxind-minind+1;
208+
npy_intp dims[num_dimensions];
209+
for (int d=0; d<num_dimensions; ++d)
210+
dims[d]= stir_sizes[d + 1];
211+
auto np_array =
212+
(PyArrayObject *)PyArray_SimpleNew(num_dimensions, dims, np_typenum);
213+
swigstir::fill_nparray_from_iterator<elemT>(np_array, self->begin_all());
214+
return PyArray_Return(np_array);
215+
}
216+
188217
}
189218
#endif
190219

0 commit comments

Comments
 (0)