summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py53
-rw-r--r--numpy/core/tests/test_numeric.py4
2 files changed, 46 insertions, 11 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index ee14f2af6..e371797b6 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1429,6 +1429,11 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
outer : Outer product.
ix_ : Construct index arrays.
+ Notes
+ -----
+ .. versionadded:: 1.9.0
+ Supports full broadcasting of the inputs.
+
Examples
--------
Vector cross-product.
@@ -1500,28 +1505,54 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"(dimension must be 2 or 3)")
if a.shape[-1] not in [2, 3] or b.shape[-1] not in [2, 3]:
raise ValueError(msg)
+
+ # Create the output array
+ shape = broadcast(a[..., 0], b[..., 0]).shape
+ if a.shape[-1] == 3 or b.shape[-1] == 3:
+ shape += (3,)
+ dtype = promote_types(a.dtype, b.dtype)
+ cp = empty(shape, dtype)
+
if a.shape[-1] == 2:
if b.shape[-1] == 2:
- cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ # cp = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp)
+ cp -= a[..., 1]*b[..., 0]
if cp.ndim == 0:
return cp
else:
# This works because we are moving the last axis
return rollaxis(cp, -1, axisc)
else:
- x = a[..., 1]*b[..., 2]
- y = -a[..., 0]*b[..., 2]
- z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ # cp[..., 0] = a[..., 1]*b[..., 2]
+ multiply(a[..., 1], b[..., 2], out=cp[..., 0])
+ # cp[..., 1] = -a[..., 0]*b[..., 2]
+ multiply(a[..., 0], b[..., 2], out=cp[..., 1])
+ cp[..., 1] *= - 1
+ # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp[..., 2])
+ cp[..., 2] -= a[..., 1]*b[..., 0]
elif a.shape[-1] == 3:
if b.shape[-1] == 3:
- x = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1]
- y = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2]
- z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ # cp[..., 0] = a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1]
+ multiply(a[..., 1], b[..., 2], out=cp[..., 0])
+ cp[..., 0] -= a[..., 2]*b[..., 1]
+ # cp[..., 1] = a[..., 2]*b[..., 0] - a[..., 0]*b[..., 2]
+ multiply(a[..., 2], b[..., 0], out=cp[..., 1])
+ cp[..., 1] -= a[..., 0]*b[..., 2]
+ # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp[..., 2])
+ cp[..., 2] -= a[..., 1]*b[..., 0]
else:
- x = -a[..., 2]*b[..., 1]
- y = a[..., 2]*b[..., 0]
- z = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
- cp = concatenate((x[..., None], y[..., None], z[..., None]), axis=-1)
+ # cp[..., 0] = -a[..., 2]*b[..., 1]
+ multiply(a[..., 2], b[..., 1], out=cp[..., 0])
+ cp[..., 0] *= - 1
+ # cp[..., 1] = a[..., 2]*b[..., 0]
+ multiply(a[..., 2], b[..., 0], out=cp[..., 1])
+ # cp[..., 2] = a[..., 0]*b[..., 1] - a[..., 1]*b[..., 0]
+ multiply(a[..., 0], b[..., 1], out=cp[..., 2])
+ cp[..., 2] -= a[..., 1]*b[..., 0]
+
if cp.ndim == 1:
return cp
else:
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 47c56aa68..677deea8e 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1913,9 +1913,13 @@ class TestCross(TestCase):
u = np.ones((10, 3, 5))
v = np.ones((2, 5))
assert_equal(np.cross(u, v, axisa=1, axisb=0).shape, (10, 5, 3))
+ assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=2)
+ assert_raises(ValueError, np.cross, u, v, axisa=3, axisb=0)
u = np.ones((10, 3, 5, 7))
v = np.ones((5, 7, 2))
assert_equal(np.cross(u, v, axisa=1, axisc=2).shape, (10, 5, 3, 7))
+ assert_raises(ValueError, np.cross, u, v, axisa=-5, axisb=2)
+ assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=-4)
if __name__ == "__main__":