summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py26
-rw-r--r--numpy/lib/tests/test_shape_base.py32
2 files changed, 55 insertions, 3 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 2d2e6f337..c4f519d30 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -5,7 +5,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outerproduct, \
- concatenate, isscalar, array
+ concatenate, isscalar, array, asanyarray
from numpy.core.oldnumeric import product, reshape
def apply_along_axis(func1d,axis,arr,*args):
@@ -544,7 +544,19 @@ def repmat(a, m, n):
return c.reshape(rows, cols)
-# TODO: figure out how to keep arrays the same
+def _getwrapper(*args):
+ """Find the wrapper for the array with the highest priority.
+
+ In case of ties, leftmost wins. If no wrapper is found, return None
+ """
+ wrappers = [(getattr(x, '__array_priority__', 0), -i,
+ x.__array_wrap__) for i, x in enumerate(args)
+ if hasattr(x, '__array_wrap__')]
+ wrappers.sort()
+ if wrappers:
+ return wrappers[-1][-1]
+ return None
+
def kron(a,b):
"""kronecker product of a and b
@@ -553,10 +565,18 @@ def kron(a,b):
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
+ wrapper = _getwrapper(a, b)
+ a = asanyarray(a)
+ b = asanyarray(b)
+ if not (len(a.shape) == len(b.shape) == 2):
+ raise ValueError("a and b must both be two dimensional")
if not a.flags.contiguous:
a = reshape(a, a.shape)
if not b.flags.contiguous:
b = reshape(b, b.shape)
o = outerproduct(a,b)
o=o.reshape(a.shape + b.shape)
- return concatenate(concatenate(o, axis=1), axis=1)
+ result = concatenate(concatenate(o, axis=1), axis=1)
+ if wrapper is not None:
+ result = wrapper(result)
+ return result
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 2989df9fd..984c63484 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -354,6 +354,38 @@ class test_squeeze(ScipyTestCase):
assert_array_equal(squeeze(b),reshape(b,(20,10,20)))
assert_array_equal(squeeze(c),reshape(c,(20,10)))
+class test_kron(ScipyTestCase):
+ def check_return_type(self):
+ a = ones([2,2])
+ m = asmatrix(a)
+ assert_equal(type(kron(a,a)), ndarray)
+ assert_equal(type(kron(m,m)), matrix)
+ assert_equal(type(kron(a,m)), matrix)
+ assert_equal(type(kron(m,a)), matrix)
+ class myarray(ndarray):
+ __array_priority__ = 0.0
+ ma = myarray(a.shape, a.dtype, a.data)
+ assert_equal(type(kron(a,a)), ndarray)
+ assert_equal(type(kron(ma,ma)), myarray)
+ assert_equal(type(kron(a,ma)), ndarray)
+ assert_equal(type(kron(ma,a)), myarray)
+ def check_rank_checking(self):
+ one = ones([2])
+ two = ones([2,2])
+ three = ones([2,2,2])
+ for a in [one, two, three]:
+ for b in [one, two, three]:
+ if a is b is two:
+ continue
+ try:
+ kron(a, b)
+ except ValueError:
+ continue
+ except:
+ pass
+ assert False, "ValueError expected"
+
+
# Utility
def compare_results(res,desired):