diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 10 | ||||
-rw-r--r-- | numpy/lib/utils.py | 14 |
2 files changed, 16 insertions, 8 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 829691b1c..5f27ea655 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -3432,6 +3432,16 @@ class TestMedian: a = MySubClass([1, 2, 3]) assert_equal(np.median(a), -7) + @pytest.mark.parametrize('arr', + ([1., 2., 3.], [1., np.nan, 3.], np.nan, 0.)) + def test_subclass2(self, arr): + """Check that we return subclasses, even if a NaN scalar.""" + class MySubclass(np.ndarray): + pass + + m = np.median(np.array(arr).view(MySubclass)) + assert isinstance(m, MySubclass) + def test_out(self): o = np.zeros((4,)) d = np.ones((3, 4)) diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 1f2cb66fa..931669fc1 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -1029,14 +1029,12 @@ def _median_nancheck(data, result, axis, out): # masked NaN values are ok if np.ma.isMaskedArray(n): n = n.filled(False) - if result.ndim == 0: - if n == True: - if out is not None: - out[...] = data.dtype.type(np.nan) - result = out - else: - result = data.dtype.type(np.nan) - elif np.count_nonzero(n.ravel()) > 0: + if np.count_nonzero(n.ravel()) > 0: + # Without given output, it is possible that the current result is a + # numpy scalar, which is not writeable. If so, just return nan. + if isinstance(result, np.generic): + return data.dtype.type(np.nan) + result[n] = np.nan return result |