summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py23
-rw-r--r--numpy/lib/tests/test_shape_base.py17
2 files changed, 30 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 5f21f9b34..ed84e9f5d 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,7 +1,7 @@
__all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
'column_stack','row_stack', 'dstack','array_split','split','hsplit',
'vsplit','dsplit','apply_over_axes','expand_dims',
- 'apply_along_axis', 'kron', 'tile']
+ 'apply_along_axis', 'kron', 'tile', 'get_array_wrap']
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outer, \
@@ -526,7 +526,7 @@ def dsplit(ary,indices_or_sections):
raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
return split(ary,indices_or_sections,2)
-def _getwrapper(*args):
+def get_array_wrap(*args):
"""Find the wrapper for the array with the highest priority.
In case of ties, leftmost wins. If no wrapper is found, return None
@@ -547,19 +547,28 @@ def kron(a,b):
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
- wrapper = _getwrapper(a, b)
+ wrapper = get_array_wrap(a, b)
b = asanyarray(b)
a = array(a,copy=False,subok=True,ndmin=b.ndim)
+ ndb, nda = b.ndim, a.ndim
+ if (nda == 0 or ndb == 0):
+ return a * b
as = a.shape
bs = b.shape
if not a.flags.contiguous:
a = reshape(a, as)
if not b.flags.contiguous:
b = reshape(b, bs)
- o = outer(a,b)
- result = o.reshape(as + bs)
- axis = a.ndim-1
- for k in xrange(b.ndim):
+ nd = ndb
+ if (ndb != nda):
+ if (ndb > nda):
+ as = (1,)*(ndb-nda) + as
+ else:
+ bs = (1,)*(nda-ndb) + bs
+ nd = nda
+ result = outer(a,b).reshape(as+bs)
+ axis = nd-1
+ for k in xrange(nd):
result = concatenate(result, axis=axis)
if wrapper is not None:
result = wrapper(result)
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 2d09e86c3..b43b08664 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -11,8 +11,6 @@ class test_apply_along_axis(NumpyTestCase):
a = ones((20,10),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
def check_simple101(self,level=11):
- # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
- # when enabled and shape(a)[1]>100. See Issue 202.
a = ones((10,101),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
@@ -370,6 +368,7 @@ class test_kron(NumpyTestCase):
assert_equal(type(kron(a,ma)), ndarray)
assert_equal(type(kron(ma,a)), myarray)
+
class test_tile(NumpyTestCase):
def check_basic(self):
a = array([0,1,2])
@@ -380,7 +379,19 @@ class test_tile(NumpyTestCase):
assert_equal(tile(b, 2), [[1,2,1,2],[3,4,3,4]])
assert_equal(tile(b,(2,1)),[[1,2],[3,4],[1,2],[3,4]])
assert_equal(tile(b,(2,2)),[[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]])
-
+
+ def check_kroncompare(self):
+ import numpy.random as nr
+ reps=[(2,),(1,2),(2,1),(2,2),(2,3,2),(3,2)]
+ shape=[(3,),(2,3),(3,4,3),(3,2,3),(4,3,2,4),(2,2)]
+ for s in shape:
+ b = nr.randint(0,10,size=s)
+ for r in reps:
+ a = ones(r, b.dtype)
+ large = tile(b, r)
+ klarge = kron(a, b)
+ assert_equal(large, klarge)
+
# Utility
def compare_results(res,desired):