summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index b600b70f6..5812e102e 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1150,17 +1150,16 @@ def kron(a, b):
a = reshape(a, as_)
if not b.flags.contiguous:
b = reshape(b, bs)
- nd = ndb
- if (ndb != nda):
- if (ndb > nda):
- as_ = (1,)*(ndb-nda) + as_
- else:
- bs = (1,)*(nda-ndb) + bs
- nd = nda
- result = outer(a, b).reshape(as_+bs)
- axis = nd-1
- for _ in range(nd):
- result = concatenate(result, axis=axis)
+ as_ = (1,)*max(0, ndb-nda) + as_
+ bs = (1,)*max(0, nda-ndb) + bs
+ nd = max(ndb, nda)
+ if 0 in as_ or 0 in bs:
+ result = zeros(_nx.multiply(as_, bs))
+ else:
+ result = outer(a, b).reshape(as_+bs)
+ axis = nd-1
+ for _ in range(nd):
+ result = concatenate(result, axis=axis)
wrapper = get_array_prepare(a, b)
if wrapper is not None:
result = wrapper(result)