7 #ifndef MDS_UTILS_PYTHON_NUMPY_ARRAY_ITERATOR_HPP_INCLUDED 8 #define MDS_UTILS_PYTHON_NUMPY_ARRAY_ITERATOR_HPP_INCLUDED 10 #include <boost/iterator/iterator_facade.hpp> 11 #include <numpy/arrayobject.h> 14 #include <type_traits> 63 template<
class storage>
64 struct ndarray_iterator_traits;
67 struct ndarray_iterator_traits<c_storage> {
69 static const npy_uint32 index_flag = NPY_ITER_C_INDEX;
71 static const NPY_ORDER order_flag = NPY_CORDER;
76 struct ndarray_iterator_traits<fortran_storage> {
78 static const npy_uint32 index_flag = NPY_ITER_F_INDEX;
80 static const NPY_ORDER order_flag = NPY_FORTRANORDER;
84 template<
class T,
class Storage,npy_u
int32 rw_flag>
85 class ArrayElementProxy;
115 template<
class T,
class Storage,npy_u
int32 rw_flag = NPY_ITER_READWRITE>
117 public boost::iterator_facade<
118 NDArrayIterator<T,Storage,rw_flag>,
120 boost::random_access_traversal_tag,
121 detail::ArrayElementProxy<T,Storage,rw_flag>
124 friend class boost::iterator_core_access;
125 friend class detail::ArrayElementProxy<T,Storage,rw_flag>;
127 PyArrayObject *m_parray;
128 PyArray_Descr *dtype;
132 char **m_data_ptr_arr;
140 bool is_past_the_end;
147 typedef detail::ArrayElementProxy<T,Storage,rw_flag>
reference;
154 reference dereference()
const {
155 return reference(*
this);
159 return cur_ind == rhs.cur_ind && *m_data_ptr_arr == *rhs.m_data_ptr_arr;
168 void reposition_iterator(npy_intp ind) {
169 int retval(NpyIter_GotoIndex(m_it,ind));
170 if (retval == NPY_FAIL) {
171 throw std::runtime_error(
"Could not reposition the iterator");
177 reposition_iterator(cur_ind-1);
180 void advance(difference_type n) {
181 reposition_iterator(cur_ind+n);
185 return rhs.cur_ind - cur_ind;
193 iternext(NULL),nel(0) {}
214 detail::ndarray_iterator_traits<Storage>::index_flag |
215 NPY_ITER_UPDATEIFCOPY | rw_flag,
216 detail::ndarray_iterator_traits<Storage>::order_flag,
219 m_data_ptr_arr(NULL),
221 nel(0),cur_ind(0),is_past_the_end(false) {
224 throw std::runtime_error(
"Could not allocate NDArrayIterator object");
227 if (PyArray_INCREF(m_parray)) {
228 throw std::runtime_error(
"Could not incref the NumPy array object.");
231 if ((m_data_ptr_arr = NpyIter_GetDataPtrArray(m_it)) == NULL) {
232 throw std::runtime_error(
"Could not get the data array from the NDArrayIterator");
235 if ((iternext = NpyIter_GetIterNext(m_it, NULL)) == NULL) {
236 throw std::runtime_error(
"Could not get the iteration function");
238 nel = NpyIter_GetIterSize(m_it);
243 if (NpyIter_Deallocate(m_it) == NPY_FAIL) {
244 throw std::runtime_error(
"Could not deallocate Numpy iterator.");
247 if (PyArray_XDECREF(m_parray)) {
248 throw std::runtime_error(
"Could not properly decref the NumPy array object.");
265 m_parray(rhs.m_parray),
267 m_it(NpyIter_Copy(rhs.m_it)),m_data_ptr_arr(NULL),
268 iternext(NULL),nel(0),cur_ind(rhs.cur_ind),
269 is_past_the_end(end) {
272 throw std::runtime_error(
"Could not duplicate NumPy array iterator");
274 if (PyArray_INCREF(m_parray)) {
275 throw std::runtime_error(
"Could not incref the NumPy array object.");
279 if ((m_data_ptr_arr = NpyIter_GetDataPtrArray(m_it)) == NULL) {
280 throw std::runtime_error(
"Could not get the data array from the NDArrayIterator");
283 if ((iternext = NpyIter_GetIterNext(m_it, NULL)) == NULL) {
284 throw std::runtime_error(
"Could not get the iteration function");
286 nel = NpyIter_GetIterSize(m_it);
288 if (is_past_the_end) {
289 reposition_iterator(nel-1);
304 dtype = rhs.dtype; Py_XINCREF(dtype);
305 cur_ind = rhs.cur_ind;
306 is_past_the_end = rhs.is_past_the_end;
308 if ((m_it = NpyIter_Copy(rhs.m_it)) == NULL) {
309 throw std::runtime_error(
"Could not duplicate NumPy array iterator");
311 if ((m_data_ptr_arr = NpyIter_GetDataPtrArray(m_it)) == NULL) {
312 throw std::runtime_error(
"Could not get the data array from the NDArrayIterator");
315 if ((iternext = NpyIter_GetIterNext(m_it, NULL)) == NULL) {
316 throw std::runtime_error(
"Could not get the iteration function");
318 nel = NpyIter_GetIterSize(m_it);
320 if (is_past_the_end) {
321 reposition_iterator(nel-1);
343 template<
class T,
class Storage,npy_u
int32 rw_flag>
344 class ArrayElementProxy {
362 ArrayElementProxy(
const ArrIt& it) :
363 m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
366 ArrayElementProxy& operator =(
const T& val) {
367 *
reinterpret_cast<T*
>(*(m_arr_it.m_data_ptr_arr)) = val;
372 operator T ()
const {
373 return *
reinterpret_cast<T*
>(*(m_arr_it.m_data_ptr_arr));
378 template<
class T,
class Storage>
379 class ArrayElementProxy<T,Storage,NPY_ITER_READONLY> {
397 ArrayElementProxy(
const ArrIt& it) :
398 m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
401 operator T ()
const {
402 return *
reinterpret_cast<T*
>(*(m_arr_it.m_data_ptr_arr));
407 template<
class T,
class Storage>
408 class ArrayElementProxy<T,Storage,NPY_ITER_WRITEONLY> {
426 ArrayElementProxy(
const ArrIt& it) :
427 m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
430 ArrayElementProxy& operator =(
const T& val) {
431 *
reinterpret_cast<T*
>(*(m_arr_it.m_data_ptr_arr)) = val;
438 template<
class T,
class Storage>
439 class ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_READWRITE> {
441 static_assert(std::is_floating_point<T>::value,
"Non-floating-point complex numbers are not allowed.");
454 ArrayElementProxy(
const ArrIt& it) :
455 m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
457 ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_READWRITE>&
458 operator =(
const std::complex<T>& val) {
460 npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
462 pel->real = std::real(val);
463 pel->imag = std::imag(val);
469 operator std::complex<T> ()
const {
471 npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
473 return std::complex<T>(pel->real,pel->imag);
478 template<
class T,
class Storage>
479 class ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_READONLY> {
481 static_assert(std::is_floating_point<T>::value,
"Non-floating-point complex numbers are not allowed.");
494 ArrayElementProxy(
const ArrIt& it) :
495 m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
497 operator std::complex<T> ()
const {
499 npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
501 return std::complex<T>(pel->real,pel->imag);
506 template<
class T,
class Storage>
507 class ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_WRITEONLY> {
509 static_assert(std::is_floating_point<T>::value,
"Non-floating-point complex numbers are not allowed.");
522 ArrayElementProxy(
const ArrIt& it) :
523 m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
525 ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_WRITEONLY>&
526 operator =(
const std::complex<T>& val) {
528 npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
530 pel->real = std::real(val);
531 pel->imag = std::imag(val);
Tag for the C storage ordering.
detail::ArrayElementProxy< T, Storage, rw_flag > reference
The type of a reference to an element.
Iterator on a NumPy ndarray.
NDArrayIterator(PyArrayObject *parray)
Builds the start iterator for a particular ndarray object.
Tag for the FORTRAN storage ordering.
Main namespace of all Michele De Stefano's C++ utilities.
NDArrayIterator(const NDArrayIterator &rhs, bool end=false)
Copy constructor, or constructor for the past-the-end iterator.
ptrdiff_t difference_type
The type of the difference between two memory locations.
NDArrayIterator()
Default constructor.
~NDArrayIterator()
Destructor.
Contains type utilities for building NumPy-based extensions.
Provides traits for a specific C/C++ datatype.