summaryrefslogtreecommitdiff
path: root/numpy/core/_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r--numpy/core/_methods.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 75fd32ec8..e475b94df 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -4,14 +4,15 @@ and the Python code for the NumPy-namespace function
"""
import warnings
+from contextlib import nullcontext
from numpy.core import multiarray as mu
from numpy.core import umath as um
-from numpy.core._asarray import asanyarray
+from numpy.core.multiarray import asanyarray
from numpy.core import numerictypes as nt
from numpy.core import _exceptions
from numpy._globals import _NoValue
-from numpy.compat import pickle, os_fspath, contextlib_nullcontext
+from numpy.compat import pickle, os_fspath
# save those O(100) nanoseconds!
umr_maximum = um.maximum.reduce
@@ -51,9 +52,15 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
return umr_prod(a, axis, dtype, out, keepdims, initial, where)
def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
+ # Parsing keyword arguments is currently fairly slow, so avoid it for now
+ if where is True:
+ return umr_any(a, axis, dtype, out, keepdims)
return umr_any(a, axis, dtype, out, keepdims, where=where)
def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
+ # Parsing keyword arguments is currently fairly slow, so avoid it for now
+ if where is True:
+ return umr_all(a, axis, dtype, out, keepdims)
return umr_all(a, axis, dtype, out, keepdims, where=where)
def _count_reduce_items(arr, axis, keepdims=False, where=True):
@@ -158,7 +165,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
is_float16_result = False
rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
- if umr_any(rcount == 0, axis=None):
+ if rcount == 0 if where is True else umr_any(rcount == 0, axis=None):
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
# Cast bool, unsigned int, and int to float64 by default
@@ -191,7 +198,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
# Make this warning show up on top.
- if umr_any(ddof >= rcount, axis=None):
+ if ddof >= rcount if where is True else umr_any(ddof >= rcount, axis=None):
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
stacklevel=2)
@@ -273,7 +280,7 @@ def _ptp(a, axis=None, out=None, keepdims=False):
def _dump(self, file, protocol=2):
if hasattr(file, 'write'):
- ctx = contextlib_nullcontext(file)
+ ctx = nullcontext(file)
else:
ctx = open(os_fspath(file), "wb")
with ctx as f: