summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index b07fde27d..deb3c969b 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -502,7 +502,6 @@ class ndenumerate(object):
def __iter__(self):
return self
-
class ndindex(object):
"""
An N-dimensional iterator object to index arrays.
@@ -535,6 +534,9 @@ class ndindex(object):
def __init__(self, *shape):
x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
self._it = _nx.nditer(x, flags=['multi_index'], order='C')
+ # This is a patch to handle 0-d arrays correctly on the Python side.
+ # We might want to revisit nditer in the future to handle this
+ self._zerod = (len(shape)==0)
def __iter__(self):
return self
@@ -558,6 +560,13 @@ class ndindex(object):
"""
self._it.next()
+ # This is a hack with an un-necessary check in every next call
+ # But, it's much simpler than writing another iterator for 0-d arrays
+ # because the Python iterator protocol does not respect monkey-patching
+ # the next method on an instance.
+ # Given that we should fix nditer eventually, we do this for now.
+ if self._zerod:
+ return ()
return self._it.multi_index