diff options
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 28 |
1 files changed, 13 insertions, 15 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 1f2cb66fa..1df2ab09b 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -1002,7 +1002,7 @@ def safe_eval(source): return ast.literal_eval(source) -def _median_nancheck(data, result, axis, out): +def _median_nancheck(data, result, axis): """ Utility function to check median result from data for NaN values at the end and return NaN in that case. Input result can also be a MaskedArray. @@ -1010,18 +1010,18 @@ def _median_nancheck(data, result, axis, out): Parameters ---------- data : array - Input data to median function + Sorted input data to median function result : Array or MaskedArray - Result of median function + Result of median function. axis : int Axis along which the median was computed. - out : ndarray, optional - Output array in which to place the result. Returns ------- - median : scalar or ndarray - Median or NaN in axes which contained NaN in the input. + result : scalar or ndarray + Median or NaN in axes which contained NaN in the input. If the input + was an array, NaN will be inserted in-place. If a scalar, either the + input itself or a scalar NaN. """ if data.size == 0: return result @@ -1029,14 +1029,12 @@ def _median_nancheck(data, result, axis, out): # masked NaN values are ok if np.ma.isMaskedArray(n): n = n.filled(False) - if result.ndim == 0: - if n == True: - if out is not None: - out[...] = data.dtype.type(np.nan) - result = out - else: - result = data.dtype.type(np.nan) - elif np.count_nonzero(n.ravel()) > 0: + if np.count_nonzero(n.ravel()) > 0: + # Without given output, it is possible that the current result is a + # numpy scalar, which is not writeable. If so, just return nan. + if isinstance(result, np.generic): + return data.dtype.type(np.nan) + result[n] = np.nan return result |