summaryrefslogtreecommitdiff
path: root/numpy/core/_methods.py
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-03-10 16:03:48 -0700
committerGitHub <noreply@github.com>2020-03-10 16:03:48 -0700
commit2e9169601aff252a661b845399ec61c3e575407f (patch)
tree9fdd9d6cd1678d2f7acb3d47a9e1831df06739fa /numpy/core/_methods.py
parentc9bfd4eb68e61c67aa27ed0cb2788f60d11cf354 (diff)
parentffe1f46121cd11b2b876d20ba1758a09cb4e5be7 (diff)
downloadnumpy-2e9169601aff252a661b845399ec61c3e575407f.tar.gz
Merge pull request #15696 from rossbar/enh/var_complex_fastpath
MAINT: Add a fast path to var for complex input
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r--numpy/core/_methods.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 694523b20..8a90731e9 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce
umr_any = um.logical_or.reduce
umr_all = um.logical_and.reduce
+# Complex types to -> (2,)float view for fast-path computation in _var()
+_complex_to_float = {
+ nt.dtype(nt.csingle) : nt.dtype(nt.single),
+ nt.dtype(nt.cdouble) : nt.dtype(nt.double),
+}
+# Special case for windows: ensure double takes precedence
+if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
+ _complex_to_float.update({
+ nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble),
+ })
+# Add reverse-endian types
+_complex_to_float.update({
+ k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items()
+})
+
# avoid keyword arguments to speed up parsing, saves about 15%-20% for very
# small reductions
def _amax(a, axis=None, out=None, keepdims=False,
@@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
# Note that x may not be inexact and that we need it to be an array,
# not a scalar.
x = asanyarray(arr - arrmean)
+
if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
x = um.multiply(x, x, out=x)
+ # Fast-paths for built-in complex types
+ elif x.dtype in _complex_to_float:
+ xv = x.view(dtype=(_complex_to_float[x.dtype], (2,)))
+ um.multiply(xv, xv, out=xv)
+ x = um.add(xv[..., 0], xv[..., 1], out=x.real).real
+ # Most general case; includes handling object arrays containing imaginary
+ # numbers and complex types with non-native byteorder
else:
x = um.multiply(x, um.conjugate(x), out=x).real