diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-06 19:35:56 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-07 13:03:12 +0200 |
commit | 22df0769eeb180326a657d850faa98e27b70eea5 (patch) | |
tree | 96bd605e8e430d5f824c907aea795eef240b8e61 /numpy/core/numeric.py | |
parent | 4c854c2633894387988b43306ff72333cb00613a (diff) | |
download | numpy-22df0769eeb180326a657d850faa98e27b70eea5.tar.gz |
MAINT: improve readablility of cross and improve test coverage
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 75 |
1 files changed, 46 insertions, 29 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 38c28da6d..a85e8514c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1522,7 +1522,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): b = rollaxis(b, axisb, b.ndim) msg = ("incompatible dimensions for cross product\n" "(dimension must be 2 or 3)") - if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]: + if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): raise ValueError(msg) # Create the output array @@ -1532,45 +1532,62 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): dtype = promote_types(a.dtype, b.dtype) cp = empty(shape, dtype) + # create local aliases for readability + a0 = a[..., 0] + a1 = a[..., 1] + if a.shape[-1] == 3: + a2 = a[..., 2] + b0 = b[..., 0] + b1 = b[..., 1] + if b.shape[-1] == 3: + b2 = b[..., 2] + if cp.ndim != 0 and cp.shape[-1] == 3: + cp0 = cp[..., 0] + cp1 = cp[..., 1] + cp2 = cp[..., 2] + if a.shape[-1] == 2: if b.shape[-1] == 2: - # cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] - multiply(a[..., 0], b[..., 1], out=cp) - cp -= a[..., 1]*b[..., 0] + # a0 * b1 - a1 * b0 + multiply(a0, b1, out=cp) + cp -= a1 * b0 if cp.ndim == 0: return cp else: # This works because we are moving the last axis return rollaxis(cp, -1, axisc) else: - # cp[..., 0] = a[..., 1]*b[..., 2] - multiply(a[..., 1], b[..., 2], out=cp[..., 0]) - # cp[..., 1] = -a[..., 0]*b[..., 2] - multiply(a[..., 0], b[..., 2], out=cp[..., 1]) - cp[..., 1] *= - 1 - # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] - multiply(a[..., 0], b[..., 1], out=cp[..., 2]) - cp[..., 2] -= a[..., 1]*b[..., 0] + # cp0 = a1 * b2 - 0 (a2 = 0) + # cp1 = 0 - a0 * b2 (a2 = 0) + # cp2 = a0 * b1 - a1 * b0 + multiply(a1, b2, out=cp0) + multiply(a0, b2, out=cp1) + negative(cp1, out=cp1) + multiply(a0, b1, out=cp2) + cp2 -= a1 * b0 elif a.shape[-1] == 3: if b.shape[-1] == 3: - # cp[..., 0] = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] - multiply(a[..., 1], b[..., 2], out=cp[..., 0]) - cp[..., 0] -= a[..., 2]*b[..., 1] - # cp[..., 1] = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2] - multiply(a[..., 2], b[..., 0], out=cp[..., 1]) - cp[..., 1] -= a[..., 0]*b[..., 2] - # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] - multiply(a[..., 0], b[..., 1], out=cp[..., 2]) - cp[..., 2] -= a[..., 1]*b[..., 0] + # cp0 = a1 * b2 - a2 * b1 + # cp1 = a2 * b0 - a0 * b2 + # cp2 = a0 * b1 - a1 * b0 + multiply(a1, b2, out=cp0) + tmp = array(a2 * b1) + cp0 -= tmp + multiply(a2, b0, out=cp1) + multiply(a0, b2, out=tmp) + cp1 -= tmp + multiply(a0, b1, out=cp2) + multiply(a1, b0, out=tmp) + cp2 -= tmp else: - # cp[..., 0] = -a[..., 2]*b[..., 1] - multiply(a[..., 2], b[..., 1], out=cp[..., 0]) - cp[..., 0] *= - 1 - # cp[..., 1] = a[..., 2]*b[..., 0] - multiply(a[..., 2], b[..., 0], out=cp[..., 1]) - # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0] - multiply(a[..., 0], b[..., 1], out=cp[..., 2]) - cp[..., 2] -= a[..., 1]*b[..., 0] + # cp0 = 0 - a2 * b1 (b2 = 0) + # cp1 = a2 * b0 - 0 (b2 = 0) + # cp2 = a0 * b1 - a1 * b0 + multiply(a2, b1, out=cp0) + negative(cp0, out=cp0) + multiply(a2, b0, out=cp1) + multiply(a0, b1, out=cp2) + cp2 -= a1 * b0 if cp.ndim == 1: return cp |