summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/upcoming_changes/16911.deprecation.rst2
-rw-r--r--numpy/core/tests/test_deprecations.py12
-rw-r--r--numpy/lib/function_base.py8
-rw-r--r--numpy/lib/tests/test_function_base.py3
4 files changed, 16 insertions, 9 deletions
diff --git a/doc/release/upcoming_changes/16911.deprecation.rst b/doc/release/upcoming_changes/16911.deprecation.rst
index d4dcb629c..8f38ed989 100644
--- a/doc/release/upcoming_changes/16911.deprecation.rst
+++ b/doc/release/upcoming_changes/16911.deprecation.rst
@@ -4,4 +4,4 @@ The ``trim_zeros`` function will, in the future, require an array with the
following two properties:
* It must be 1D.
-* It must be convertable into a boolean array.
+* It must support elementwise comparisons with zero.
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 9004bef30..f0eac24ee 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -710,15 +710,19 @@ class TestRaggedArray(_DeprecationTestCase):
class TestTrimZeros(_DeprecationTestCase):
# Numpy 1.20.0, 2020-07-31
- @pytest.mark.parametrize("arr", [np.random.rand(10, 10).tolist(),
- np.random.rand(10).astype(str)])
- def test_deprecated(self, arr):
+ @pytest.mark.parametrize(
+ "arr,exc_type",
+ [(np.random.rand(10, 10).tolist(), ValueError),
+ (np.random.rand(10).astype(str), FutureWarning)]
+ )
+ def test_deprecated(self, arr, exc_type):
with warnings.catch_warnings():
warnings.simplefilter('error', DeprecationWarning)
try:
np.trim_zeros(arr)
except DeprecationWarning as ex:
- assert_(isinstance(ex.__cause__, ValueError))
+ ex_cause = ex.__cause__
+ assert_(isinstance(ex_cause, exc_type))
else:
raise AssertionError("No error raised during function call")
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index cd8862c94..96e1c6de9 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1631,7 +1631,7 @@ def trim_zeros(filt, trim='fb'):
# Numpy 1.20.0, 2020-07-31
warning = DeprecationWarning(
"in the future trim_zeros will require a 1-D array as input "
- "that is compatible with ndarray.astype(bool)"
+ "that supports elementwise comparisons with zero"
)
warning.__cause__ = ex
warnings.warn(warning, stacklevel=3)
@@ -1643,7 +1643,11 @@ def trim_zeros(filt, trim='fb'):
def _trim_zeros_new(filt, trim='fb'):
"""Newer optimized implementation of ``trim_zeros()``."""
- arr = np.asanyarray(filt).astype(bool, copy=False)
+ arr_any = np.asanyarray(filt)
+ with warnings.catch_warnings():
+ # not all dtypes support elementwise comparisons with `0` (e.g. str)
+ warnings.simplefilter('error', FutureWarning)
+ arr = arr_any != 0 if arr_any.dtype != bool else arr_any
if arr.ndim != 1:
raise ValueError('trim_zeros requires an array of exactly one dimension')
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 89c1a2d9b..744034e01 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1169,10 +1169,9 @@ class TestTrimZeros:
a = np.array([0, 0, 1, 0, 2, 3, 4, 0])
b = a.astype(float)
c = a.astype(complex)
- d = np.array([None, [], 1, False, 'b', 3.0, range(4), b''], dtype=object)
def values(self):
- attr_names = ('a', 'b', 'c', 'd')
+ attr_names = ('a', 'b', 'c')
return (getattr(self, name) for name in attr_names)
def test_basic(self):