summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorJaime Fernandez <jaime.frio@gmail.com>2015-05-17 06:37:16 -0700
committerJaime Fernandez <jaime.frio@gmail.com>2015-05-17 06:37:16 -0700
commit581927a4fcbc7d7c54b7b37e6edc121669863ea5 (patch)
treeca3f0ddaf9bb359cc64f6633dab55b397e698fb3 /numpy/core/numeric.py
parent0a02b82ed72f0268875bfc6c70d1e8a8dad6c644 (diff)
downloadnumpy-581927a4fcbc7d7c54b7b37e6edc121669863ea5.tar.gz
BUG: Handling of axisc in np.cross
Fixes #5885 by ignoring `axisc` when both input vectors are 2D. Also adds explicit checks for `axis?` parameters in bounds, to provide more informative errors. Also slightly simplified the calculation logic and documented the assumptions in each branch with `assert`s.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 7a0fa4b62..ea2d4d0a2 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1478,8 +1478,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
axisb : int, optional
Axis of `b` that defines the vector(s). By default, the last axis.
axisc : int, optional
- Axis of `c` containing the cross product vector(s). By default, the
- last axis.
+ Axis of `c` containing the cross product vector(s). Ignored if
+ both input vectors have dimension 2, as the return is scalar.
+ By default, the last axis.
axis : int, optional
If defined, the axis of `a`, `b` and `c` that defines the vector(s)
and cross product(s). Overrides `axisa`, `axisb` and `axisc`.
@@ -1570,6 +1571,12 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
axisa, axisb, axisc = (axis,) * 3
a = asarray(a)
b = asarray(b)
+ # Check axisa and axisb are within bounds
+ axis_msg = "'axis{0}' out of bounds"
+ if axisa < -a.ndim or axisa >= a.ndim:
+ raise ValueError(axis_msg.format('a'))
+ if axisb < -b.ndim or axisb >= b.ndim:
+ raise ValueError(axis_msg.format('b'))
# Move working axis to the end of the shape
a = rollaxis(a, axisa, a.ndim)
b = rollaxis(b, axisb, b.ndim)
@@ -1578,10 +1585,13 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError(msg)
- # Create the output array
+ # Create the output array
shape = broadcast(a[..., 0], b[..., 0]).shape
if a.shape[-1] == 3 or b.shape[-1] == 3:
shape += (3,)
+ # Check axisc is within bounds
+ if axisc < -len(shape) or axisc >= len(shape):
+ raise ValueError(axis_msg.format('c'))
dtype = promote_types(a.dtype, b.dtype)
cp = empty(shape, dtype)
@@ -1604,12 +1614,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
# a0 * b1 - a1 * b0
multiply(a0, b1, out=cp)
cp -= a1 * b0
- if cp.ndim == 0:
- return cp
- else:
- # This works because we are moving the last axis
- return rollaxis(cp, -1, axisc)
+ return cp
else:
+ assert b.shape[-1] == 3
# cp0 = a1 * b2 - 0 (a2 = 0)
# cp1 = 0 - a0 * b2 (a2 = 0)
# cp2 = a0 * b1 - a1 * b0
@@ -1618,7 +1625,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
negative(cp1, out=cp1)
multiply(a0, b1, out=cp2)
cp2 -= a1 * b0
- elif a.shape[-1] == 3:
+ else:
+ assert a.shape[-1] == 3
if b.shape[-1] == 3:
# cp0 = a1 * b2 - a2 * b1
# cp1 = a2 * b0 - a0 * b2
@@ -1633,6 +1641,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
multiply(a1, b0, out=tmp)
cp2 -= tmp
else:
+ assert b.shape[-1] == 2
# cp0 = 0 - a2 * b1 (b2 = 0)
# cp1 = a2 * b0 - 0 (b2 = 0)
# cp2 = a0 * b1 - a1 * b0
@@ -1642,11 +1651,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
multiply(a0, b1, out=cp2)
cp2 -= a1 * b0
- if cp.ndim == 1:
- return cp
- else:
- # This works because we are moving the last axis
- return rollaxis(cp, -1, axisc)
+ # This works because we are moving the last axis
+ return rollaxis(cp, -1, axisc)
#Use numarray's printing function
from .arrayprint import array2string, get_printoptions, set_printoptions