summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorDimas Abreu Dutra <dimasadutra@gmail.com>2015-07-16 21:41:23 -0300
committerDimas Abreu Dutra <dimasadutra@gmail.com>2015-07-17 02:54:49 -0300
commit98f186f4ea336138c31c471d18dccc6d9663ced7 (patch)
tree0242110ccdbb76fe3b97351ba570d165ae6d0b8d /numpy/lib
parent85188530bffae563eb274b9c12b77981cfa4e1d2 (diff)
downloadnumpy-98f186f4ea336138c31c471d18dccc6d9663ced7.tar.gz
BUG: Fix tiling of zero-sized arrays numpy/numpy#6089 and add test case.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py20
-rw-r--r--numpy/lib/tests/test_shape_base.py3
2 files changed, 13 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 280765df8..26c2aab04 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -857,16 +857,16 @@ def tile(A, reps):
# numpy array and the repetitions are 1 in all dimensions
return _nx.array(A, copy=True, subok=True, ndmin=d)
else:
+ # Note that no copy of zero-sized arrays is made. However since they
+ # have no data there is no risk of an inadvertent overwrite.
c = _nx.array(A, copy=False, subok=True, ndmin=d)
- shape = list(c.shape)
- n = max(c.size, 1)
if (d < c.ndim):
tup = (1,)*(c.ndim-d) + tup
- for i, nrep in enumerate(tup):
- if nrep != 1:
- c = c.reshape(-1, n).repeat(nrep, 0)
- dim_in = shape[i]
- dim_out = dim_in*nrep
- shape[i] = dim_out
- n //= max(dim_in, 1)
- return c.reshape(shape)
+ shape_out = tuple(s*t for s, t in zip(c.shape, tup))
+ n = c.size
+ if n > 0:
+ for dim_in, nrep in zip(c.shape, tup):
+ if nrep != 1:
+ c = c.reshape(-1, n).repeat(nrep, 0)
+ n //= dim_in
+ return c.reshape(shape_out)
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index fb9d7f364..3f2d8d5b4 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -332,7 +332,10 @@ class TestTile(TestCase):
def test_empty(self):
a = np.array([[[]]])
+ b = np.array([[], []])
+ c = tile(b, 2).shape
d = tile(a, (3, 2, 5)).shape
+ assert_equal(c, (2, 0))
assert_equal(d, (3, 2, 0))
def test_kroncompare(self):