diff options
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r-- | numpy/core/_methods.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 75fd32ec8..e475b94df 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -4,14 +4,15 @@ and the Python code for the NumPy-namespace function """ import warnings +from contextlib import nullcontext from numpy.core import multiarray as mu from numpy.core import umath as um -from numpy.core._asarray import asanyarray +from numpy.core.multiarray import asanyarray from numpy.core import numerictypes as nt from numpy.core import _exceptions from numpy._globals import _NoValue -from numpy.compat import pickle, os_fspath, contextlib_nullcontext +from numpy.compat import pickle, os_fspath # save those O(100) nanoseconds! umr_maximum = um.maximum.reduce @@ -51,9 +52,15 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False, return umr_prod(a, axis, dtype, out, keepdims, initial, where) def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + # Parsing keyword arguments is currently fairly slow, so avoid it for now + if where is True: + return umr_any(a, axis, dtype, out, keepdims) return umr_any(a, axis, dtype, out, keepdims, where=where) def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + # Parsing keyword arguments is currently fairly slow, so avoid it for now + if where is True: + return umr_all(a, axis, dtype, out, keepdims) return umr_all(a, axis, dtype, out, keepdims, where=where) def _count_reduce_items(arr, axis, keepdims=False, where=True): @@ -158,7 +165,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): is_float16_result = False rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) - if umr_any(rcount == 0, axis=None): + if rcount == 0 if where is True else umr_any(rcount == 0, axis=None): warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default @@ -191,7 +198,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) # Make this warning show up on top. - if umr_any(ddof >= rcount, axis=None): + if ddof >= rcount if where is True else umr_any(ddof >= rcount, axis=None): warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) @@ -273,7 +280,7 @@ def _ptp(a, axis=None, out=None, keepdims=False): def _dump(self, file, protocol=2): if hasattr(file, 'write'): - ctx = contextlib_nullcontext(file) + ctx = nullcontext(file) else: ctx = open(os_fspath(file), "wb") with ctx as f: |