diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-05-23 06:41:18 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-09-05 22:07:06 -0700 |
commit | b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch) | |
tree | e10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core | |
parent | e3f4c536a0014789dbd0321926b4f62c39d73719 (diff) | |
download | numpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz |
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d.
Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/numeric.py | 13 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 15 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 24 |
3 files changed, 43 insertions, 9 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 8ada87b9f..c395b1348 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -26,6 +26,7 @@ if sys.version_info[0] < 3: from . import overrides from . import umath +from . import shape_base from .overrides import set_module from .umath import (multiply, invert, sin, PINF, NAN) from . import numerictypes @@ -545,8 +546,10 @@ def argwhere(a): Returns ------- - index_array : ndarray + index_array : (N, a.ndim) ndarray Indices of elements that are non-zero. Indices are grouped by element. + This array will have shape ``(N, a.ndim)`` where ``N`` is the number of + non-zero items. See Also -------- @@ -554,7 +557,8 @@ def argwhere(a): Notes ----- - ``np.argwhere(a)`` is the same as ``np.transpose(np.nonzero(a))``. + ``np.argwhere(a)`` is almost the same as ``np.transpose(np.nonzero(a))``, + but produces a result of the correct shape for a 0D array. The output of ``argwhere`` is not suitable for indexing arrays. For this purpose use ``nonzero(a)`` instead. @@ -572,6 +576,11 @@ def argwhere(a): [1, 2]]) """ + # nonzero does not behave well on 0d, so promote to 1d + if np.ndim(a) == 0: + a = shape_base.atleast_1d(a) + # then remove the added dimension + return argwhere(a)[:,:0] return transpose(nonzero(a)) diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 710f64827..d7e769e62 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -9,8 +9,9 @@ import warnings from . import numeric as _nx from . import overrides -from .numeric import array, asanyarray, newaxis +from ._asarray import array, asanyarray from .multiarray import normalize_axis_index +from . import fromnumeric as _from_nx array_function_dispatch = functools.partial( @@ -123,7 +124,7 @@ def atleast_2d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1) elif ary.ndim == 1: - result = ary[newaxis, :] + result = ary[_nx.newaxis, :] else: result = ary res.append(result) @@ -193,9 +194,9 @@ def atleast_3d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1, 1) elif ary.ndim == 1: - result = ary[newaxis, :, newaxis] + result = ary[_nx.newaxis, :, _nx.newaxis] elif ary.ndim == 2: - result = ary[:, :, newaxis] + result = ary[:, :, _nx.newaxis] else: result = ary res.append(result) @@ -435,9 +436,9 @@ def stack(arrays, axis=0, out=None): # Internal functions to eliminate the overhead of repeated dispatch in one of # the two possible paths inside np.block. # Use getattr to protect against __array_function__ being disabled. -_size = getattr(_nx.size, '__wrapped__', _nx.size) -_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim) -_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate) +_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size) +_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim) +_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate) def _block_format_index(index): diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index c479a0f6d..1358b45e9 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2583,6 +2583,30 @@ class TestConvolve(object): class TestArgwhere(object): + + @pytest.mark.parametrize('nd', [0, 1, 2]) + def test_nd(self, nd): + # get an nd array with multiple elements in every dimension + x = np.empty((2,)*nd, bool) + + # none + x[...] = False + assert_equal(np.argwhere(x).shape, (0, nd)) + + # only one + x[...] = False + x.flat[0] = True + assert_equal(np.argwhere(x).shape, (1, nd)) + + # all but one + x[...] = True + x.flat[0] = False + assert_equal(np.argwhere(x).shape, (x.size - 1, nd)) + + # all + x[...] = True + assert_equal(np.argwhere(x).shape, (x.size, nd)) + def test_2D(self): x = np.arange(6).reshape((2, 3)) assert_array_equal(np.argwhere(x > 1), |