diff options
author | Dimas Abreu Dutra <dimasadutra@gmail.com> | 2015-07-16 21:41:23 -0300 |
---|---|---|
committer | Dimas Abreu Dutra <dimasadutra@gmail.com> | 2015-07-17 02:54:49 -0300 |
commit | 98f186f4ea336138c31c471d18dccc6d9663ced7 (patch) | |
tree | 0242110ccdbb76fe3b97351ba570d165ae6d0b8d /numpy/lib | |
parent | 85188530bffae563eb274b9c12b77981cfa4e1d2 (diff) | |
download | numpy-98f186f4ea336138c31c471d18dccc6d9663ced7.tar.gz |
BUG: Fix tiling of zero-sized arrays numpy/numpy#6089 and add test case.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/shape_base.py | 20 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 3 |
2 files changed, 13 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 280765df8..26c2aab04 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -857,16 +857,16 @@ def tile(A, reps): # numpy array and the repetitions are 1 in all dimensions return _nx.array(A, copy=True, subok=True, ndmin=d) else: + # Note that no copy of zero-sized arrays is made. However since they + # have no data there is no risk of an inadvertent overwrite. c = _nx.array(A, copy=False, subok=True, ndmin=d) - shape = list(c.shape) - n = max(c.size, 1) if (d < c.ndim): tup = (1,)*(c.ndim-d) + tup - for i, nrep in enumerate(tup): - if nrep != 1: - c = c.reshape(-1, n).repeat(nrep, 0) - dim_in = shape[i] - dim_out = dim_in*nrep - shape[i] = dim_out - n //= max(dim_in, 1) - return c.reshape(shape) + shape_out = tuple(s*t for s, t in zip(c.shape, tup)) + n = c.size + if n > 0: + for dim_in, nrep in zip(c.shape, tup): + if nrep != 1: + c = c.reshape(-1, n).repeat(nrep, 0) + n //= dim_in + return c.reshape(shape_out) diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index fb9d7f364..3f2d8d5b4 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -332,7 +332,10 @@ class TestTile(TestCase): def test_empty(self): a = np.array([[[]]]) + b = np.array([[], []]) + c = tile(b, 2).shape d = tile(a, (3, 2, 5)).shape + assert_equal(c, (2, 0)) assert_equal(d, (3, 2, 0)) def test_kroncompare(self): |