diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-03-10 16:03:48 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-10 16:03:48 -0700 |
commit | 2e9169601aff252a661b845399ec61c3e575407f (patch) | |
tree | 9fdd9d6cd1678d2f7acb3d47a9e1831df06739fa /numpy/core/_methods.py | |
parent | c9bfd4eb68e61c67aa27ed0cb2788f60d11cf354 (diff) | |
parent | ffe1f46121cd11b2b876d20ba1758a09cb4e5be7 (diff) | |
download | numpy-2e9169601aff252a661b845399ec61c3e575407f.tar.gz |
Merge pull request #15696 from rossbar/enh/var_complex_fastpath
MAINT: Add a fast path to var for complex input
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r-- | numpy/core/_methods.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 694523b20..8a90731e9 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce umr_any = um.logical_or.reduce umr_all = um.logical_and.reduce +# Complex types to -> (2,)float view for fast-path computation in _var() +_complex_to_float = { + nt.dtype(nt.csingle) : nt.dtype(nt.single), + nt.dtype(nt.cdouble) : nt.dtype(nt.double), +} +# Special case for windows: ensure double takes precedence +if nt.dtype(nt.longdouble) != nt.dtype(nt.double): + _complex_to_float.update({ + nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble), + }) +# Add reverse-endian types +_complex_to_float.update({ + k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items() +}) + # avoid keyword arguments to speed up parsing, saves about 15%-20% for very # small reductions def _amax(a, axis=None, out=None, keepdims=False, @@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # Note that x may not be inexact and that we need it to be an array, # not a scalar. x = asanyarray(arr - arrmean) + if issubclass(arr.dtype.type, (nt.floating, nt.integer)): x = um.multiply(x, x, out=x) + # Fast-paths for built-in complex types + elif x.dtype in _complex_to_float: + xv = x.view(dtype=(_complex_to_float[x.dtype], (2,))) + um.multiply(xv, xv, out=xv) + x = um.add(xv[..., 0], xv[..., 1], out=x.real).real + # Most general case; includes handling object arrays containing imaginary + # numbers and complex types with non-native byteorder else: x = um.multiply(x, um.conjugate(x), out=x).real |