summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-10-07 02:59:33 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-10-07 02:59:33 +0000
commite5181859b022407090ce43bfca833881c8488b09 (patch)
tree458a2da6a369be1fe3ba256a1e8036eb6324d935 /numpy/lib/shape_base.py
parentba6b099b3d4b97fe933e0ebf5f18ca29f8111855 (diff)
downloadnumpy-e5181859b022407090ce43bfca833881c8488b09.tar.gz
Fix kron to be N-dimensional.
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 3d9b02bc3..5f21f9b34 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,7 +1,7 @@
__all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
'column_stack','row_stack', 'dstack','array_split','split','hsplit',
'vsplit','dsplit','apply_over_axes','expand_dims',
- 'apply_along_axis', 'tile', 'kron']
+ 'apply_along_axis', 'kron', 'tile']
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outer, \
@@ -542,27 +542,30 @@ def _getwrapper(*args):
def kron(a,b):
"""kronecker product of a and b
- Kronecker product of two matrices is block matrix
+ Kronecker product of two arrays is block array
[[ a[ 0 ,0]*b, a[ 0 ,1]*b, ... , a[ 0 ,n-1]*b ],
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
wrapper = _getwrapper(a, b)
- a = asanyarray(a)
- b = asanyarray(b)
- if not (len(a.shape) == len(b.shape) == 2):
- raise ValueError("a and b must both be two dimensional")
+ b = asanyarray(b)
+ a = array(a,copy=False,subok=True,ndmin=b.ndim)
+ as = a.shape
+ bs = b.shape
if not a.flags.contiguous:
- a = reshape(a, a.shape)
+ a = reshape(a, as)
if not b.flags.contiguous:
- b = reshape(b, b.shape)
+ b = reshape(b, bs)
o = outer(a,b)
- o=o.reshape(a.shape + b.shape)
- result = concatenate(concatenate(o, axis=1), axis=1)
+ result = o.reshape(as + bs)
+ axis = a.ndim-1
+ for k in xrange(b.ndim):
+ result = concatenate(result, axis=axis)
if wrapper is not None:
result = wrapper(result)
return result
+
def tile(A, reps):
"""Repeat an array the number of times given in the integer tuple, reps.