diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index b07fde27d..deb3c969b 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -502,7 +502,6 @@ class ndenumerate(object): def __iter__(self): return self - class ndindex(object): """ An N-dimensional iterator object to index arrays. @@ -535,6 +534,9 @@ class ndindex(object): def __init__(self, *shape): x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) self._it = _nx.nditer(x, flags=['multi_index'], order='C') + # This is a patch to handle 0-d arrays correctly on the Python side. + # We might want to revisit nditer in the future to handle this + self._zerod = (len(shape)==0) def __iter__(self): return self @@ -558,6 +560,13 @@ class ndindex(object): """ self._it.next() + # This is a hack with an un-necessary check in every next call + # But, it's much simpler than writing another iterator for 0-d arrays + # because the Python iterator protocol does not respect monkey-patching + # the next method on an instance. + # Given that we should fix nditer eventually, we do this for now. + if self._zerod: + return () return self._it.multi_index |