Skip to content

Commit 66b7f89

Browse files
[Python] add ProjData(InMemory).as_array()
1 parent 6303dbb commit 66b7f89

4 files changed

Lines changed: 84 additions & 31 deletions

File tree

src/swig/stir.i

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,44 @@
303303

304304
}
305305

306-
// fill an array from a Python sequence
306+
// fill a STIR array from a Python sequence
307307
template <int num_dimensions, typename elemT>
308308
void fill_Array_from_Python_iterator(stir::Array<num_dimensions, elemT> * array_ptr, PyObject* const arg)
309309
{
310310
fill_iterator_from_Python_iterator<elemT>(array_ptr->begin_all(), array_ptr->end_all(), arg);
311311
}
312312

313+
template <typename elemT, typename IterT>
314+
void fill_nparray_from_iterator(PyArrayObject * np, IterT cpp_iter)
315+
{
316+
if (!PyArray_EquivTypenums(PyArray_TYPE(np), get_np_typenum<elemT>()))
317+
{
318+
throw std::runtime_error("stir_object.fill needs to be called with numpy array of correct dtype");
319+
}
320+
321+
#if 1
322+
auto iter = NpyIter_New(np, NPY_ITER_READONLY, NPY_KEEPORDER, NPY_NO_CASTING, NULL);
323+
if (iter==NULL) {
324+
return;
325+
}
326+
auto iternext = NpyIter_GetIterNext(iter, NULL);
327+
auto dataptr = (elemT **) NpyIter_GetDataPtrArray(iter);
328+
do {
329+
**dataptr = *cpp_iter++;
330+
}
331+
while (iternext(iter));
332+
#else
333+
// generic alternative, but doesn't compile and might be slower
334+
auto iterator = PyObject_GetIter(np_array);
335+
PyObject *item;
336+
while (item = PyIter_Next(iterator))
337+
{
338+
*item = *cpp_iter++; // this does not compile. not sure how to assign
339+
Py_DECREF(item);
340+
}
341+
#endif
342+
}
343+
313344
template <typename elemT, typename IterT>
314345
void fill_iterator_from_nparray(IterT iterator, PyArrayObject * np)
315346
{
@@ -751,13 +782,32 @@ namespace std {
751782
return new SwigPyForwardIteratorClosed_T<OutIter>(current, begin, end, seq);
752783
}
753784

754-
755785
#endif
756-
static Array<4,float> create_array_for_proj_data(const ProjData& proj_data)
786+
787+
// helper function that allocates a stir::Array of appropriate size
788+
static stir::BasicCoordinate<4, int> array_for_proj_data_size(const ProjData& proj_data)
757789
{
758790
const int num_non_tof_sinos = proj_data.get_num_non_tof_sinograms();
759-
Array<4,float> array(IndexRange4D(proj_data.get_num_tof_poss(),num_non_tof_sinos, proj_data.get_num_views(), proj_data.get_num_tangential_poss()));
760-
return array;
791+
return stir::make_coordinate(proj_data.get_num_tof_poss(),num_non_tof_sinos, proj_data.get_num_views(), proj_data.get_num_tangential_poss());
792+
}
793+
794+
static Array<4,float> create_array_for_proj_data(const ProjData& proj_data)
795+
{
796+
Array<4,float> array(IndexRange4D(array_for_proj_data_size(proj_data)));
797+
return array;
798+
}
799+
800+
// helper function that allocates a numpy.ndarray of appropriate size
801+
static PyArrayObject* create_nparray_for_proj_data(const ProjData& proj_data)
802+
{
803+
const auto stir_sizes = swigstir::array_for_proj_data_size(proj_data);
804+
constexpr int num_dimensions = 4;
805+
npy_intp dims[num_dimensions];
806+
for (int d=0; d<num_dimensions; ++d)
807+
dims[d] = stir_sizes[d + 1];
808+
auto np_array =
809+
(PyArrayObject *)PyArray_SimpleNew(num_dimensions, dims, NPY_FLOAT);
810+
return np_array;
761811
}
762812

763813
// a function for converting ProjData to a 4D array as that's what is easy to use

src/swig/stir_array.i

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ namespace stir
200200
PyObject* as_array() const
201201
{
202202
int np_typenum = swigstir::get_np_typenum<elemT>();
203-
const auto dtype = PyArray_DescrFromType(np_typenum);
204203

205204
stir::BasicCoordinate<num_dimensions,int> minind,maxind;
206205
if (!$self->get_regular_range(minind, maxind))
@@ -211,28 +210,7 @@ namespace stir
211210
dims[d]= stir_sizes[d + 1];
212211
auto np_array =
213212
(PyArrayObject *)PyArray_SimpleNew(num_dimensions, dims, np_typenum);
214-
auto stir_iter = self->begin_all();
215-
#if 1
216-
auto iter = NpyIter_New(np_array, NPY_ITER_READONLY, NPY_KEEPORDER, NPY_NO_CASTING, dtype);
217-
if (iter==NULL) {
218-
return NULL;
219-
}
220-
auto iternext = NpyIter_GetIterNext(iter, NULL);
221-
auto dataptr = (elemT **) NpyIter_GetDataPtrArray(iter);
222-
do {
223-
**dataptr = *stir_iter;
224-
++stir_iter; }
225-
while (iternext(iter));
226-
#else
227-
// generic alternative, but doesn't compile and might be slower
228-
auto iterator = PyObject_GetIter(np_array);
229-
PyObject *item;
230-
while (item = PyIter_Next(iterator))
231-
{
232-
*item = *stir_iter++; // this does not compile. not sure how to assign
233-
Py_DECREF(item);
234-
}
235-
#endif
213+
swigstir::fill_nparray_from_iterator<elemT>(np_array, self->begin_all());
236214
return PyArray_Return(np_array);
237215
}
238216

src/swig/stir_projdata.i

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,19 @@ namespace stir {
159159
return array;
160160
}
161161

162-
%feature("autodoc", "fill from a Python scalar, numpy array or iterator, e.g. array.fill(numpyarray.flat)") fill;
162+
%newobject as_array();
163+
%feature("autodoc", "Create a new numpy array with same dimensions as the return of to_array().") as_array;
164+
PyObject* as_array() const
165+
{
166+
auto np_array = swigstir::create_nparray_for_proj_data(*$self);
167+
// TODO avoid making an extra copy, but this way, there's less code
168+
// and we don't depend on knowing internal details
169+
const Array<4,float> stir_array = swigstir::projdata_to_4D(*$self);
170+
swigstir::fill_nparray_from_iterator<float>(np_array, stir_array.begin_all());
171+
return PyArray_Return(np_array);
172+
}
173+
174+
%feature("autodoc", "fill from a Python scalar, numpy array or iterator, e.g. array.fill(numpyarray.flat)") fill;
163175
void fill(PyObject* const arg)
164176
{
165177
if (PyIter_Check(arg))
@@ -237,6 +249,15 @@ namespace stir {
237249
}
238250
}
239251

252+
%newobject as_array();
253+
%feature("autodoc", "Create a new numpy array with same dimensions as the return of to_array().") as_array;
254+
PyObject* as_array() const
255+
{
256+
auto np_array = swigstir::create_nparray_for_proj_data(*$self);
257+
swigstir::fill_nparray_from_iterator<float>(np_array, $self->begin());
258+
return PyArray_Return(np_array);
259+
}
260+
240261
#elif defined(SWIGMATLAB)
241262
void fill(const mxArray *pm)
242263
{

src/swig/test/python/test_buildblock.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def test_ProjDataInMemory_numerics():
471471
c /= 3
472472
assert math.isclose(c.get_bin_value(_bin), a.get_bin_value(_bin) / 3, rel_tol=1e-4)
473473

474-
def test_ProjData_from_to_Array():
474+
def test_ProjDataInMemory_from_to_Array():
475475
# define a projection with some dummy data (filled with segment no.)
476476
s=Scanner.get_scanner_from_name("ECAT 962")
477477
projdatainfo=ProjDataInfo.construct_proj_data_info(s,3,9,8,6)
@@ -491,7 +491,11 @@ def test_ProjData_from_to_Array():
491491
# fill with iterator
492492
new_projdata.fill(stir_array.flat())
493493
# assert every data point is equal
494-
assert all(a==b for a, b in zip(projdata.to_array().flat(),new_projdata.to_array().flat()))
494+
assert all(a==b for a, b in zip(projdata.to_array().flat(), new_projdata.as_array().flat))
495+
# fill with numpy array
496+
new_projdata.fill(stir_array.as_array())
497+
# assert every data point is equal
498+
assert all(a==b for a, b in zip(projdata.to_array().flat(), new_projdata.as_array().flat))
495499

496500
def test_xapyb_and_sapyb():
497501
"""

0 commit comments

Comments
 (0)