Michele De Stefano's C++ Utilities
array_iterator.hpp
Go to the documentation of this file.
1 // mds_utils/python/numpy/array_iterator.hpp
2 //
3 // Copyright (c) 2014 - Michele De Stefano (micdestefano@users.sourceforge.net)
4 //
5 // Distributed under the MIT License (See accompanying file LICENSE)
6 
7 #ifndef MDS_UTILS_PYTHON_NUMPY_ARRAY_ITERATOR_HPP_INCLUDED
8 #define MDS_UTILS_PYTHON_NUMPY_ARRAY_ITERATOR_HPP_INCLUDED
9 
#include <boost/iterator/iterator_facade.hpp>
#include <numpy/arrayobject.h>
#include <complex>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>
15 
34 namespace mds_utils {
35  namespace python {
36 
43  namespace numpy {
44 
/// Tag type selecting C storage ordering for NDArrayIterator
/// (detail::ndarray_iterator_traits maps it to NPY_ITER_C_INDEX / NPY_CORDER).
struct c_storage {
};
51 
/// Tag type selecting FORTRAN storage ordering for NDArrayIterator
/// (detail::ndarray_iterator_traits maps it to NPY_ITER_F_INDEX / NPY_FORTRANORDER).
struct fortran_storage {
};
58 
59 
61  namespace detail {
62 
63 template<class storage>
64 struct ndarray_iterator_traits;
65 
66 template<>
67 struct ndarray_iterator_traits<c_storage> {
68 
69  static const npy_uint32 index_flag = NPY_ITER_C_INDEX;
70 
71  static const NPY_ORDER order_flag = NPY_CORDER;
72 
73 };
74 
75 template<>
76 struct ndarray_iterator_traits<fortran_storage> {
77 
78  static const npy_uint32 index_flag = NPY_ITER_F_INDEX;
79 
80  static const NPY_ORDER order_flag = NPY_FORTRANORDER;
81 
82 };
83 
84 template<class T,class Storage,npy_uint32 rw_flag>
85 class ArrayElementProxy;
86 
87  }
89 
// Random-access iterator over the elements of a NumPy array, built on
// boost::iterator_facade and NumPy's NpyIter.
//  - T:       element type seen by the caller; the array constructor asks
//             NumPy for the dtype numpy_dtype_traits<T>::typenum.
//  - Storage: c_storage or fortran_storage; selects the traversal order and
//             the flat index tracked by the NumPy iterator.
//  - rw_flag: NPY_ITER_READWRITE (default), NPY_ITER_READONLY or
//             NPY_ITER_WRITEONLY; selects which operations the reference
//             proxy (detail::ArrayElementProxy) offers.
// Dereferencing yields a detail::ArrayElementProxy, not a T&.
115 template<class T,class Storage,npy_uint32 rw_flag = NPY_ITER_READWRITE>
117  public boost::iterator_facade<
118  NDArrayIterator<T,Storage,rw_flag>,
119  T,
120  boost::random_access_traversal_tag,
121  detail::ArrayElementProxy<T,Storage,rw_flag>
122  > {
123 
// iterator_core_access calls the private dereference/equal/increment/...
// hooks; the proxy reads m_data_ptr_arr and cur_ind directly.
124  friend class boost::iterator_core_access;
125  friend class detail::ArrayElementProxy<T,Storage,rw_flag>;
126 
// Array passed to the array constructor (NULL when default-constructed).
127  PyArrayObject *m_parray;
// Descriptor passed to NpyIter_New as the required dtype; this object holds
// one reference to it (released with Py_XDECREF in the destructor).
128  PyArray_Descr *dtype;
129 
// Underlying NumPy iterator, owned by this object.
130  NpyIter *m_it;
131 
// NpyIter_GetDataPtrArray(m_it): *m_data_ptr_arr is the address of the
// current element (the proxies reinterpret it as T*).
132  char **m_data_ptr_arr;
133 
// NpyIter_GetIterNext(m_it): moves m_it forward by one element.
134  NpyIter_IterNextFunc
135  *iternext;
136 
// nel is NpyIter_GetIterSize(m_it). cur_ind is the flat index of the current
// position; it equals nel after stepping past the last element.
137  npy_intp nel, // Total number of elements
138  cur_ind; // Current 1D index
139 
// Set when this iterator was built as the past-the-end iterator
// (copy-constructor with end == true); carried over by assignment.
140  bool is_past_the_end;
141 
142 
143 
144 public:
145 
// Dereferencing returns this proxy, which reads/writes the current element.
147  typedef detail::ArrayElementProxy<T,Storage,rw_flag> reference;
148 
// Signed distance between two iterators (see distance_to / advance).
150  typedef ptrdiff_t difference_type;
151 
152 private:
153 
154  reference dereference() const {
155  return reference(*this);
156  }
157 
158  bool equal(const NDArrayIterator& rhs) const {
159  return cur_ind == rhs.cur_ind && *m_data_ptr_arr == *rhs.m_data_ptr_arr;
160  }
161 
162  void increment() {
163  iternext(m_it);
164  ++cur_ind;
165  }
166 
167 
168  void reposition_iterator(npy_intp ind) {
169  int retval(NpyIter_GotoIndex(m_it,ind));
170  if (retval == NPY_FAIL) {
171  throw std::runtime_error("Could not reposition the iterator");
172  }
173  cur_ind = ind;
174  }
175 
176  void decrement() {
177  reposition_iterator(cur_ind-1);
178  }
179 
180  void advance(difference_type n) {
181  reposition_iterator(cur_ind+n);
182  }
183 
184  difference_type distance_to(const NDArrayIterator& rhs) const {
185  return rhs.cur_ind - cur_ind;
186  }
187 
188 
189 public:
190 
192  NDArrayIterator() : m_parray(NULL),dtype(NULL),m_it(NULL),m_data_ptr_arr(NULL),
193  iternext(NULL),nel(0) {}
194 
209  NDArrayIterator(PyArrayObject *parray) :
210  m_parray(parray),
211  dtype(PyArray_DescrFromType(numpy_dtype_traits<T>::typenum)),
212  m_it(NpyIter_New(
213  parray,
214  detail::ndarray_iterator_traits<Storage>::index_flag |
215  NPY_ITER_UPDATEIFCOPY | rw_flag,
216  detail::ndarray_iterator_traits<Storage>::order_flag,
217  NPY_UNSAFE_CASTING,
218  dtype)),
219  m_data_ptr_arr(NULL),
220  iternext(NULL),
221  nel(0),cur_ind(0),is_past_the_end(false) {
222 
223  if (m_it == NULL) {
224  throw std::runtime_error("Could not allocate NDArrayIterator object");
225  }
226 
227  if (PyArray_INCREF(m_parray)) {
228  throw std::runtime_error("Could not incref the NumPy array object.");
229  }
230 
231  if ((m_data_ptr_arr = NpyIter_GetDataPtrArray(m_it)) == NULL) {
232  throw std::runtime_error("Could not get the data array from the NDArrayIterator");
233  }
234 
235  if ((iternext = NpyIter_GetIterNext(m_it, NULL)) == NULL) {
236  throw std::runtime_error("Could not get the iteration function");
237  }
238  nel = NpyIter_GetIterSize(m_it);
239  }
240 
243  if (NpyIter_Deallocate(m_it) == NPY_FAIL) {
244  throw std::runtime_error("Could not deallocate Numpy iterator.");
245  }
246  m_it = NULL;
247  if (PyArray_XDECREF(m_parray)) {
248  throw std::runtime_error("Could not properly decref the NumPy array object.");
249  }
250  Py_XDECREF(dtype);
251  }
252 
264  NDArrayIterator(const NDArrayIterator& rhs,bool end = false) :
265  m_parray(rhs.m_parray),
266  dtype(rhs.dtype),
267  m_it(NpyIter_Copy(rhs.m_it)),m_data_ptr_arr(NULL),
268  iternext(NULL),nel(0),cur_ind(rhs.cur_ind),
269  is_past_the_end(end) {
270 
271  if (m_it == NULL) {
272  throw std::runtime_error("Could not duplicate NumPy array iterator");
273  }
274  if (PyArray_INCREF(m_parray)) {
275  throw std::runtime_error("Could not incref the NumPy array object.");
276  }
277  Py_XINCREF(dtype);
278 
279  if ((m_data_ptr_arr = NpyIter_GetDataPtrArray(m_it)) == NULL) {
280  throw std::runtime_error("Could not get the data array from the NDArrayIterator");
281  }
282 
283  if ((iternext = NpyIter_GetIterNext(m_it, NULL)) == NULL) {
284  throw std::runtime_error("Could not get the iteration function");
285  }
286  nel = NpyIter_GetIterSize(m_it);
287 
288  if (is_past_the_end) {
289  reposition_iterator(nel-1);
290  increment();
291  }
292  }
293 
294 
303  NDArrayIterator& operator =(const NDArrayIterator& rhs) {
304  dtype = rhs.dtype; Py_XINCREF(dtype);
305  cur_ind = rhs.cur_ind;
306  is_past_the_end = rhs.is_past_the_end;
307 
308  if ((m_it = NpyIter_Copy(rhs.m_it)) == NULL) {
309  throw std::runtime_error("Could not duplicate NumPy array iterator");
310  }
311  if ((m_data_ptr_arr = NpyIter_GetDataPtrArray(m_it)) == NULL) {
312  throw std::runtime_error("Could not get the data array from the NDArrayIterator");
313  }
314 
315  if ((iternext = NpyIter_GetIterNext(m_it, NULL)) == NULL) {
316  throw std::runtime_error("Could not get the iteration function");
317  }
318  nel = NpyIter_GetIterSize(m_it);
319 
320  if (is_past_the_end) {
321  reposition_iterator(nel-1);
322  increment();
323  }
324  }
325 
326 
327 };
328 
329 
331  namespace detail {
332 
// Primary template: the reference type NDArrayIterator returns from
// dereference() for access modes without a specialization below (e.g. the
// default NPY_ITER_READWRITE). It both reads and writes the element at the
// iterator's current data pointer.
343 template<class T,class Storage,npy_uint32 rw_flag>
344 class ArrayElementProxy {
345 
347 
// The iterator this proxy works through, held by reference: the proxy must
// not outlive the iterator it was created from.
348  ArrIt& m_arr_it;
349 
// NOTE(review): recorded at construction but not read by any member of this
// class; element accesses use m_arr_it's current data pointer instead.
350  npy_intp m_ind; // Index of the element to be retrieved or set
351 
352 public:
353 
362  ArrayElementProxy(const ArrIt& it) :
363  m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
364 
366  ArrayElementProxy& operator =(const T& val) {
367  *reinterpret_cast<T*>(*(m_arr_it.m_data_ptr_arr)) = val;
368  return *this;
369  }
370 
372  operator T () const {
373  return *reinterpret_cast<T*>(*(m_arr_it.m_data_ptr_arr));
374  }
375 };
376 
377 
// Read-only proxy (NPY_ITER_READONLY): it only offers conversion to T;
// there is no assignment from T.
378 template<class T,class Storage>
379 class ArrayElementProxy<T,Storage,NPY_ITER_READONLY> {
380 
382 
// Iterator whose current element is read; the proxy must not outlive it.
383  ArrIt& m_arr_it;
384 
// NOTE(review): recorded but not read by any member of this class.
385  npy_intp m_ind; // Index of the element to be retrieved or set
386 
387 public:
388 
// Binds the proxy to the iterator being dereferenced (constness is cast
// away because m_arr_it is stored as a non-const reference).
397  ArrayElementProxy(const ArrIt& it) :
398  m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
399 
// Reads the element at the iterator's current data pointer.
401  operator T () const {
402  return *reinterpret_cast<T*>(*(m_arr_it.m_data_ptr_arr));
403  }
404 };
405 
406 
// Write-only proxy (NPY_ITER_WRITEONLY): it only offers assignment from T;
// there is no conversion back to T.
407 template<class T,class Storage>
408 class ArrayElementProxy<T,Storage,NPY_ITER_WRITEONLY> {
409 
411 
// Iterator whose current element is written; the proxy must not outlive it.
412  ArrIt& m_arr_it;
413 
// NOTE(review): recorded but not read by any member of this class.
414  npy_intp m_ind; // Index of the element to be retrieved or set
415 
416 public:
417 
// Binds the proxy to the iterator being dereferenced (constness is cast
// away because m_arr_it is stored as a non-const reference).
426  ArrayElementProxy(const ArrIt& it) :
427  m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
428 
// Writes val into the element at the iterator's current data pointer.
430  ArrayElementProxy& operator =(const T& val) {
431  *reinterpret_cast<T*>(*(m_arr_it.m_data_ptr_arr)) = val;
432  return *this;
433  }
434 };
435 
436 // Partial specializations for complex numbers
437 
438 template<class T,class Storage>
439 class ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_READWRITE> {
440 
441  static_assert(std::is_floating_point<T>::value,"Non-floating-point complex numbers are not allowed.");
442 
443  typedef NDArrayIterator<std::complex<T>,Storage,NPY_ITER_READWRITE> ArrIt;
444 
445  typedef typename numpy_dtype_traits< std::complex<T> >::type
446  npy_complex_T;
447 
448  ArrIt& m_arr_it;
449 
450  npy_intp m_ind; // Index of the element to be retrieved or set
451 
452 public:
453 
454  ArrayElementProxy(const ArrIt& it) :
455  m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
456 
457  ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_READWRITE>&
458  operator =(const std::complex<T>& val) {
459 
460  npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
461 
462  pel->real = std::real(val);
463  pel->imag = std::imag(val);
464 
465  return *this;
466  }
467 
468 
469  operator std::complex<T> () const {
470 
471  npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
472 
473  return std::complex<T>(pel->real,pel->imag);
474  }
475 };
476 
477 
478 template<class T,class Storage>
479 class ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_READONLY> {
480 
481  static_assert(std::is_floating_point<T>::value,"Non-floating-point complex numbers are not allowed.");
482 
483  typedef NDArrayIterator<std::complex<T>,Storage,NPY_ITER_READONLY> ArrIt;
484 
485  typedef typename numpy_dtype_traits< std::complex<T> >::type
486  npy_complex_T;
487 
488  ArrIt& m_arr_it;
489 
490  npy_intp m_ind; // Index of the element to be retrieved or set
491 
492 public:
493 
494  ArrayElementProxy(const ArrIt& it) :
495  m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
496 
497  operator std::complex<T> () const {
498 
499  npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
500 
501  return std::complex<T>(pel->real,pel->imag);
502  }
503 };
504 
505 
506 template<class T,class Storage>
507 class ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_WRITEONLY> {
508 
509  static_assert(std::is_floating_point<T>::value,"Non-floating-point complex numbers are not allowed.");
510 
511  typedef NDArrayIterator<std::complex<T>,Storage,NPY_ITER_WRITEONLY> ArrIt;
512 
513  typedef typename numpy_dtype_traits< std::complex<T> >::type
514  npy_complex_T;
515 
516  ArrIt& m_arr_it;
517 
518  npy_intp m_ind; // Index of the element to be retrieved or set
519 
520 public:
521 
522  ArrayElementProxy(const ArrIt& it) :
523  m_arr_it(*const_cast<ArrIt*>(&it)),m_ind(it.cur_ind) {}
524 
525  ArrayElementProxy<std::complex<T>,Storage,NPY_ITER_WRITEONLY>&
526  operator =(const std::complex<T>& val) {
527 
528  npy_complex_T *pel(reinterpret_cast<npy_complex_T*>(*(m_arr_it.m_data_ptr_arr)));
529 
530  pel->real = std::real(val);
531  pel->imag = std::imag(val);
532 
533  return *this;
534  }
535 };
536 
537 
538  } // namespace detail
540 
541 
542  }
543  }
544 }
545 
546 
547 
548 #endif /* MDS_UTILS_PYTHON_NUMPY_ARRAY_ITERATOR_HPP_INCLUDED */
Tag for the C storage ordering.
detail::ArrayElementProxy< T, Storage, rw_flag > reference
The type of a reference to an element.
NDArrayIterator(PyArrayObject *parray)
Builds the start iterator for a particular ndarray object.
Tag for the FORTRAN storage ordering.
Main namespace of all Michele De Stefano's C++ utilities.
Definition: endian.hpp:30
NDArrayIterator(const NDArrayIterator &rhs, bool end=false)
Copy constructor, or constructor of the past-the-end iterator.
ptrdiff_t difference_type
The type of the signed distance between two iterators.
Contains type utilities for building NumPy-based extensions.
Provides traits for a specific C/C++ datatype.
Definition: traits.hpp:48