diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/shape_base.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index b600b70f6..5812e102e 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1150,17 +1150,16 @@ def kron(a, b): a = reshape(a, as_) if not b.flags.contiguous: b = reshape(b, bs) - nd = ndb - if (ndb != nda): - if (ndb > nda): - as_ = (1,)*(ndb-nda) + as_ - else: - bs = (1,)*(nda-ndb) + bs - nd = nda - result = outer(a, b).reshape(as_+bs) - axis = nd-1 - for _ in range(nd): - result = concatenate(result, axis=axis) + as_ = (1,)*max(0, ndb-nda) + as_ + bs = (1,)*max(0, nda-ndb) + bs + nd = max(ndb, nda) + if 0 in as_ or 0 in bs: + result = zeros(_nx.multiply(as_, bs)) + else: + result = outer(a, b).reshape(as_+bs) + axis = nd-1 + for _ in range(nd): + result = concatenate(result, axis=axis) wrapper = get_array_prepare(a, b) if wrapper is not None: result = wrapper(result) |