diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/shape_base.py | 26 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 32 |
2 files changed, 55 insertions, 3 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 2d2e6f337..c4f519d30 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -5,7 +5,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack', import numpy.core.numeric as _nx from numpy.core.numeric import asarray, zeros, newaxis, outerproduct, \ - concatenate, isscalar, array + concatenate, isscalar, array, asanyarray from numpy.core.oldnumeric import product, reshape def apply_along_axis(func1d,axis,arr,*args): @@ -544,7 +544,19 @@ def repmat(a, m, n): return c.reshape(rows, cols) -# TODO: figure out how to keep arrays the same +def _getwrapper(*args): + """Find the wrapper for the array with the highest priority. + + In case of ties, leftmost wins. If no wrapper is found, return None + """ + wrappers = [(getattr(x, '__array_priority__', 0), -i, + x.__array_wrap__) for i, x in enumerate(args) + if hasattr(x, '__array_wrap__')] + wrappers.sort() + if wrappers: + return wrappers[-1][-1] + return None + def kron(a,b): """kronecker product of a and b @@ -553,10 +565,18 @@ def kron(a,b): [ ... ... ], [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]] """ + wrapper = _getwrapper(a, b) + a = asanyarray(a) + b = asanyarray(b) + if not (len(a.shape) == len(b.shape) == 2): + raise ValueError("a and b must both be two dimensional") if not a.flags.contiguous: a = reshape(a, a.shape) if not b.flags.contiguous: b = reshape(b, b.shape) o = outerproduct(a,b) o=o.reshape(a.shape + b.shape) - return concatenate(concatenate(o, axis=1), axis=1) + result = concatenate(concatenate(o, axis=1), axis=1) + if wrapper is not None: + result = wrapper(result) + return result diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 2989df9fd..984c63484 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -354,6 +354,38 @@ class test_squeeze(ScipyTestCase): assert_array_equal(squeeze(b),reshape(b,(20,10,20))) assert_array_equal(squeeze(c),reshape(c,(20,10))) +class test_kron(ScipyTestCase): + def check_return_type(self): + a = ones([2,2]) + m = asmatrix(a) + assert_equal(type(kron(a,a)), ndarray) + assert_equal(type(kron(m,m)), matrix) + assert_equal(type(kron(a,m)), matrix) + assert_equal(type(kron(m,a)), matrix) + class myarray(ndarray): + __array_priority__ = 0.0 + ma = myarray(a.shape, a.dtype, a.data) + assert_equal(type(kron(a,a)), ndarray) + assert_equal(type(kron(ma,ma)), myarray) + assert_equal(type(kron(a,ma)), ndarray) + assert_equal(type(kron(ma,a)), myarray) + def check_rank_checking(self): + one = ones([2]) + two = ones([2,2]) + three = ones([2,2,2]) + for a in [one, two, three]: + for b in [one, two, three]: + if a is b is two: + continue + try: + kron(a, b) + except ValueError: + continue + except: + pass + assert False, "ValueError expected" + + # Utility def compare_results(res,desired): |