summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/tests/test_function_base.py10
-rw-r--r--numpy/lib/utils.py14
2 files changed, 16 insertions, 8 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 829691b1c..5f27ea655 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -3432,6 +3432,16 @@ class TestMedian:
a = MySubClass([1, 2, 3])
assert_equal(np.median(a), -7)
+ @pytest.mark.parametrize('arr',
+ ([1., 2., 3.], [1., np.nan, 3.], np.nan, 0.))
+ def test_subclass2(self, arr):
+ """Check that we return subclasses, even if a NaN scalar."""
+ class MySubclass(np.ndarray):
+ pass
+
+ m = np.median(np.array(arr).view(MySubclass))
+ assert isinstance(m, MySubclass)
+
def test_out(self):
o = np.zeros((4,))
d = np.ones((3, 4))
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 1f2cb66fa..931669fc1 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -1029,14 +1029,12 @@ def _median_nancheck(data, result, axis, out):
# masked NaN values are ok
if np.ma.isMaskedArray(n):
n = n.filled(False)
- if result.ndim == 0:
- if n == True:
- if out is not None:
- out[...] = data.dtype.type(np.nan)
- result = out
- else:
- result = data.dtype.type(np.nan)
- elif np.count_nonzero(n.ravel()) > 0:
+ if np.count_nonzero(n.ravel()) > 0:
+ # Without given output, it is possible that the current result is a
+ # numpy scalar, which is not writeable. If so, just return nan.
+ if isinstance(result, np.generic):
+ return data.dtype.type(np.nan)
+
result[n] = np.nan
return result