diff options
-rw-r--r-- | doc/release/upcoming_changes/16911.deprecation.rst | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_deprecations.py | 12 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 8 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 3 |
4 files changed, 16 insertions, 9 deletions
diff --git a/doc/release/upcoming_changes/16911.deprecation.rst b/doc/release/upcoming_changes/16911.deprecation.rst index d4dcb629c..8f38ed989 100644 --- a/doc/release/upcoming_changes/16911.deprecation.rst +++ b/doc/release/upcoming_changes/16911.deprecation.rst @@ -4,4 +4,4 @@ The ``trim_zeros`` function will, in the future, require an array with the following two properties: * It must be 1D. -* It must be convertable into a boolean array. +* It must support elementwise comparisons with zero. diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index 9004bef30..f0eac24ee 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -710,15 +710,19 @@ class TestRaggedArray(_DeprecationTestCase): class TestTrimZeros(_DeprecationTestCase): # Numpy 1.20.0, 2020-07-31 - @pytest.mark.parametrize("arr", [np.random.rand(10, 10).tolist(), - np.random.rand(10).astype(str)]) - def test_deprecated(self, arr): + @pytest.mark.parametrize( + "arr,exc_type", + [(np.random.rand(10, 10).tolist(), ValueError), + (np.random.rand(10).astype(str), FutureWarning)] + ) + def test_deprecated(self, arr, exc_type): with warnings.catch_warnings(): warnings.simplefilter('error', DeprecationWarning) try: np.trim_zeros(arr) except DeprecationWarning as ex: - assert_(isinstance(ex.__cause__, ValueError)) + ex_cause = ex.__cause__ + assert_(isinstance(ex_cause, exc_type)) else: raise AssertionError("No error raised during function call") diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index cd8862c94..96e1c6de9 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1631,7 +1631,7 @@ def trim_zeros(filt, trim='fb'): # Numpy 1.20.0, 2020-07-31 warning = DeprecationWarning( "in the future trim_zeros will require a 1-D array as input " - "that is compatible with ndarray.astype(bool)" + "that supports elementwise comparisons with zero" ) warning.__cause__ = ex warnings.warn(warning, stacklevel=3) @@ -1643,7 +1643,11 @@ def trim_zeros(filt, trim='fb'): def _trim_zeros_new(filt, trim='fb'): """Newer optimized implementation of ``trim_zeros()``.""" - arr = np.asanyarray(filt).astype(bool, copy=False) + arr_any = np.asanyarray(filt) + with warnings.catch_warnings(): + # not all dtypes support elementwise comparisons with `0` (e.g. str) + warnings.simplefilter('error', FutureWarning) + arr = arr_any != 0 if arr_any.dtype != bool else arr_any if arr.ndim != 1: raise ValueError('trim_zeros requires an array of exactly one dimension') diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 89c1a2d9b..744034e01 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1169,10 +1169,9 @@ class TestTrimZeros: a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) b = a.astype(float) c = a.astype(complex) - d = np.array([None, [], 1, False, 'b', 3.0, range(4), b''], dtype=object) def values(self): - attr_names = ('a', 'b', 'c', 'd') + attr_names = ('a', 'b', 'c') return (getattr(self, name) for name in attr_names) def test_basic(self): |