summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py17
-rw-r--r--numpy/ma/tests/test_core.py7
2 files changed, 11 insertions, 13 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 93eb74be3..405937046 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -2356,20 +2356,11 @@ def masked_invalid(a, copy=True):
fill_value=1e+20)
"""
- a = np.array(a, copy=copy, subok=True)
- mask = getattr(a, '_mask', None)
- if mask is not None:
- condition = ~(np.isfinite(getdata(a)))
- if mask is not nomask:
- condition |= mask
- cls = type(a)
- else:
- condition = ~(np.isfinite(a))
- cls = MaskedArray
- result = a.view(cls)
- result._mask = condition
- return result
+ try:
+ return masked_where(~(np.isfinite(getdata(a))), a, copy=copy)
+ except TypeError:
+ raise
###############################################################################
# Printing options #
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index b056d5169..5a6d642b4 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4496,6 +4496,13 @@ class TestMaskedArrayFunctions:
assert_equal(ma, expected)
assert_equal(ma.mask, expected.mask)
+ def test_masked_invalid_error(self):
+ a = np.arange(5, dtype=object)
+ a[3] = np.PINF
+ a[2] = np.NaN
+ with pytest.raises(TypeError, match="not supported for the input types"):
+ np.ma.masked_invalid(a)
+
def test_choose(self):
# Test choose
choices = [[0, 1, 2, 3], [10, 11, 12, 13],