diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-10-07 02:59:33 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-10-07 02:59:33 +0000 |
commit | e5181859b022407090ce43bfca833881c8488b09 (patch) | |
tree | 458a2da6a369be1fe3ba256a1e8036eb6324d935 /numpy/lib/shape_base.py | |
parent | ba6b099b3d4b97fe933e0ebf5f18ca29f8111855 (diff) | |
download | numpy-e5181859b022407090ce43bfca833881c8488b09.tar.gz |
Fix kron to be N-dimensional.
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 3d9b02bc3..5f21f9b34 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,7 +1,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack', 'column_stack','row_stack', 'dstack','array_split','split','hsplit', 'vsplit','dsplit','apply_over_axes','expand_dims', - 'apply_along_axis', 'tile', 'kron'] + 'apply_along_axis', 'kron', 'tile'] import numpy.core.numeric as _nx from numpy.core.numeric import asarray, zeros, newaxis, outer, \ @@ -542,27 +542,30 @@ def _getwrapper(*args): def kron(a,b): """kronecker product of a and b - Kronecker product of two matrices is block matrix + Kronecker product of two arrays is block array [[ a[ 0 ,0]*b, a[ 0 ,1]*b, ... , a[ 0 ,n-1]*b ], [ ... ... ], [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]] """ wrapper = _getwrapper(a, b) - a = asanyarray(a) - b = asanyarray(b) - if not (len(a.shape) == len(b.shape) == 2): - raise ValueError("a and b must both be two dimensional") + b = asanyarray(b) + a = array(a,copy=False,subok=True,ndmin=b.ndim) + as = a.shape + bs = b.shape if not a.flags.contiguous: - a = reshape(a, a.shape) + a = reshape(a, as) if not b.flags.contiguous: - b = reshape(b, b.shape) + b = reshape(b, bs) o = outer(a,b) - o=o.reshape(a.shape + b.shape) - result = concatenate(concatenate(o, axis=1), axis=1) + result = o.reshape(as + bs) + axis = a.ndim-1 + for k in xrange(b.ndim): + result = concatenate(result, axis=axis) if wrapper is not None: result = wrapper(result) return result + def tile(A, reps): """Repeat an array the number of times given in the integer tuple, reps. |