diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/shape_base.py | 23 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 17 |
2 files changed, 30 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 5f21f9b34..ed84e9f5d 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,7 +1,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack', 'column_stack','row_stack', 'dstack','array_split','split','hsplit', 'vsplit','dsplit','apply_over_axes','expand_dims', - 'apply_along_axis', 'kron', 'tile'] + 'apply_along_axis', 'kron', 'tile', 'get_array_wrap'] import numpy.core.numeric as _nx from numpy.core.numeric import asarray, zeros, newaxis, outer, \ @@ -526,7 +526,7 @@ def dsplit(ary,indices_or_sections): raise ValueError, 'vsplit only works on arrays of 3 or more dimensions' return split(ary,indices_or_sections,2) -def _getwrapper(*args): +def get_array_wrap(*args): """Find the wrapper for the array with the highest priority. In case of ties, leftmost wins. If no wrapper is found, return None @@ -547,19 +547,28 @@ def kron(a,b): [ ... ... ], [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]] """ - wrapper = _getwrapper(a, b) + wrapper = get_array_wrap(a, b) b = asanyarray(b) a = array(a,copy=False,subok=True,ndmin=b.ndim) + ndb, nda = b.ndim, a.ndim + if (nda == 0 or ndb == 0): + return a * b as = a.shape bs = b.shape if not a.flags.contiguous: a = reshape(a, as) if not b.flags.contiguous: b = reshape(b, bs) - o = outer(a,b) - result = o.reshape(as + bs) - axis = a.ndim-1 - for k in xrange(b.ndim): + nd = ndb + if (ndb != nda): + if (ndb > nda): + as = (1,)*(ndb-nda) + as + else: + bs = (1,)*(nda-ndb) + bs + nd = nda + result = outer(a,b).reshape(as+bs) + axis = nd-1 + for k in xrange(nd): result = concatenate(result, axis=axis) if wrapper is not None: result = wrapper(result) diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 2d09e86c3..b43b08664 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -11,8 +11,6 @@ class test_apply_along_axis(NumpyTestCase): a = ones((20,10),'d') assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1])) def check_simple101(self,level=11): - # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4) - # when enabled and shape(a)[1]>100. See Issue 202. a = ones((10,101),'d') assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1])) @@ -370,6 +368,7 @@ class test_kron(NumpyTestCase): assert_equal(type(kron(a,ma)), ndarray) assert_equal(type(kron(ma,a)), myarray) + class test_tile(NumpyTestCase): def check_basic(self): a = array([0,1,2]) @@ -380,7 +379,19 @@ class test_tile(NumpyTestCase): assert_equal(tile(b, 2), [[1,2,1,2],[3,4,3,4]]) assert_equal(tile(b,(2,1)),[[1,2],[3,4],[1,2],[3,4]]) assert_equal(tile(b,(2,2)),[[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]]) - + + def check_kroncompare(self): + import numpy.random as nr + reps=[(2,),(1,2),(2,1),(2,2),(2,3,2),(3,2)] + shape=[(3,),(2,3),(3,4,3),(3,2,3),(4,3,2,4),(2,2)] + for s in shape: + b = nr.randint(0,10,size=s) + for r in reps: + a = ones(r, b.dtype) + large = tile(b, r) + klarge = kron(a, b) + assert_equal(large, klarge) + # Utility def compare_results(res,desired): |