diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index deb3c969b..6a589a299 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -531,16 +531,23 @@ class ndindex(object): (2, 1, 0) """ + # This is a hack to handle 0-d arrays correctly. + # Fixing nditer would be more work but should be done eventually. + def __new__(cls, *shape): + if len(shape) == 0: + def zerodim_gen(): + yield () + return zerodim_gen() + else: + return object.__new__(cls, *shape) + def __init__(self, *shape): x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) self._it = _nx.nditer(x, flags=['multi_index'], order='C') - # This is a patch to handle 0-d arrays correctly on the Python side. - # We might want to revisit nditer in the future to handle this - self._zerod = (len(shape)==0) def __iter__(self): return self - + def ndincr(self): """ Increment the multi-dimensional index by one. @@ -560,13 +567,6 @@ class ndindex(object): """ self._it.next() - # This is a hack with an un-necessary check in every next call - # But, it's much simpler than writing another iterator for 0-d arrays - # because the Python iterator protocol does not respect monkey-patching - # the next method on an instance. - # Given that we should fix nditer eventually, we do this for now. - if self._zerod: - return () return self._it.multi_index |