diff options
author | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 06:37:16 -0700 |
---|---|---|
committer | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 06:37:16 -0700 |
commit | 581927a4fcbc7d7c54b7b37e6edc121669863ea5 (patch) | |
tree | ca3f0ddaf9bb359cc64f6633dab55b397e698fb3 /numpy/core/numeric.py | |
parent | 0a02b82ed72f0268875bfc6c70d1e8a8dad6c644 (diff) | |
download | numpy-581927a4fcbc7d7c54b7b37e6edc121669863ea5.tar.gz |
BUG: Handling of axisc in np.cross
Fixes #5885 by ignoring `axisc` when both input vectors are 2D.
Also adds explicit checks for `axis?` parameters in bounds, to
provide more informative errors.
Also slightly simplified the calculation logic and documented the
assumptions in each branch with `assert`s.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 7a0fa4b62..ea2d4d0a2 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1478,8 +1478,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): axisb : int, optional Axis of `b` that defines the vector(s). By default, the last axis. axisc : int, optional - Axis of `c` containing the cross product vector(s). By default, the - last axis. + Axis of `c` containing the cross product vector(s). Ignored if + both input vectors have dimension 2, as the return is scalar. + By default, the last axis. axis : int, optional If defined, the axis of `a`, `b` and `c` that defines the vector(s) and cross product(s). Overrides `axisa`, `axisb` and `axisc`. @@ -1570,6 +1571,12 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): axisa, axisb, axisc = (axis,) * 3 a = asarray(a) b = asarray(b) + # Check axisa and axisb are within bounds + axis_msg = "'axis{0}' out of bounds" + if axisa < -a.ndim or axisa >= a.ndim: + raise ValueError(axis_msg.format('a')) + if axisb < -b.ndim or axisb >= b.ndim: + raise ValueError(axis_msg.format('b')) # Move working axis to the end of the shape a = rollaxis(a, axisa, a.ndim) b = rollaxis(b, axisb, b.ndim) @@ -1578,10 +1585,13 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): raise ValueError(msg) - # Create the output array + # Create the output array shape = broadcast(a[..., 0], b[..., 0]).shape if a.shape[-1] == 3 or b.shape[-1] == 3: shape += (3,) + # Check axisc is within bounds + if axisc < -len(shape) or axisc >= len(shape): + raise ValueError(axis_msg.format('c')) dtype = promote_types(a.dtype, b.dtype) cp = empty(shape, dtype) @@ -1604,12 +1614,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # a0 * b1 - a1 * b0 multiply(a0, b1, out=cp) cp -= a1 * b0 - if cp.ndim == 0: - return cp - else: - # This works because we are moving the last axis - return rollaxis(cp, -1, axisc) + return cp else: + assert b.shape[-1] == 3 # cp0 = a1 * b2 - 0 (a2 = 0) # cp1 = 0 - a0 * b2 (a2 = 0) # cp2 = a0 * b1 - a1 * b0 @@ -1618,7 +1625,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): negative(cp1, out=cp1) multiply(a0, b1, out=cp2) cp2 -= a1 * b0 - elif a.shape[-1] == 3: + else: + assert a.shape[-1] == 3 if b.shape[-1] == 3: # cp0 = a1 * b2 - a2 * b1 # cp1 = a2 * b0 - a0 * b2 @@ -1633,6 +1641,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): multiply(a1, b0, out=tmp) cp2 -= tmp else: + assert b.shape[-1] == 2 # cp0 = 0 - a2 * b1 (b2 = 0) # cp1 = a2 * b0 - 0 (b2 = 0) # cp2 = a0 * b1 - a1 * b0 @@ -1642,11 +1651,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): multiply(a0, b1, out=cp2) cp2 -= a1 * b0 - if cp.ndim == 1: - return cp - else: - # This works because we are moving the last axis - return rollaxis(cp, -1, axisc) + # This works because we are moving the last axis + return rollaxis(cp, -1, axisc) #Use numarray's printing function from .arrayprint import array2string, get_printoptions, set_printoptions |