diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-01-24 22:54:47 +0100 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2013-01-25 13:16:21 +0100 |
commit | 49bf2045e1cc7d1fb7e2ce771fbb1636c9d28e5a (patch) | |
tree | 2fbfd16f0f0d72c053ecf272058c6594db5db980 /numpy | |
parent | 4600b2fe1d7ebafef98858572b603f2a8d19db4d (diff) | |
download | numpy-49bf2045e1cc7d1fb7e2ce771fbb1636c9d28e5a.tar.gz |
BUG: Fix strides of trailing 1s when reshaping
When adding ones to the shape for non contiguous arrays reshapes
(i.e. either the first array is not contiguous or the the reshape
order does not match it). The strides of trailing ones were not
set. For example reshape of (6,) to (6,1,1). Previously this occured
rarely because of removed special handleing when only ones were
added or removed.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/shape.c | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 9 |
2 files changed, 37 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c index 97ddb2e61..4223e49f6 100644 --- a/numpy/core/src/multiarray/shape.c +++ b/numpy/core/src/multiarray/shape.c @@ -355,10 +355,14 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims, int oldnd; npy_intp olddims[NPY_MAXDIMS]; npy_intp oldstrides[NPY_MAXDIMS]; - npy_intp np, op; + npy_intp np, op, last_stride; int oi, oj, ok, ni, nj, nk; oldnd = 0; + /* + * Remove axes with dimension 1 from the old array. They have no effect + * but would need special cases since their strides do not matter. + */ for (oi = 0; oi < PyArray_NDIM(self); oi++) { if (PyArray_DIMS(self)[oi]!= 1) { olddims[oldnd] = PyArray_DIMS(self)[oi]; @@ -390,27 +394,31 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims, /* different total sizes; no hope */ return 0; } - /* the current code does not handle 0-sized arrays, so give up */ + if (np == 0) { + /* the current code does not handle 0-sized arrays, so give up */ return 0; } + /* oi to oj and ni to nj give the axis ranges currently worked with */ oi = 0; oj = 1; ni = 0; nj = 1; - while(ni < newnd && oi < oldnd) { + while (ni < newnd && oi < oldnd) { np = newdims[ni]; op = olddims[oi]; while (np != op) { if (np < op) { + /* Misses trailing 1s, these are handled later */ np *= newdims[nj++]; } else { op *= olddims[oj++]; } } + /* Check whether the original axes can be combined */ for (ok = oi; ok < oj - 1; ok++) { if (is_f_order) { if (oldstrides[ok+1] != olddims[ok]*oldstrides[ok]) { @@ -427,6 +435,7 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims, } } + /* Calculate new strides for all axes currently worked with */ if (is_f_order) { newstrides[ni] = oldstrides[oi]; for (nk = ni + 1; nk < nj; nk++) { @@ -445,6 +454,22 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims, } /* + * Set strides corresponding to trailing 1s of the new shape. + */ + if (ni >= 1) { + last_stride = newstrides[ni - 1]; + } + else { + last_stride = PyArray_ITEMSIZE(self); + } + if (is_f_order) { + last_stride *= newdims[ni - 1]; + } + for (nk = ni; nk < newnd; nk++) { + newstrides[nk] = last_stride; + } + + /* fprintf(stderr, "success: _attempt_nocopy_reshape ("); for (oi=0; oi<oldnd; oi++) fprintf(stderr, "(%d,%d), ", olddims[oi], oldstrides[oi]); diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 1cf2e6e85..1f71f480e 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -545,6 +545,15 @@ class TestRegression(TestCase): a = np.ones((0,2)) a.shape = (-1,2) + def test_reshape_trailing_ones_strides(self): + # Github issue gh-2949, bad strides for trailing ones of new shape + a = np.zeros(12, dtype=np.int32)[::2] # not contiguous + strides_c = (16, 8, 8, 8) + strides_f = (8, 24, 48, 48) + assert_equal(a.reshape(3, 2, 1, 1).strides, strides_c) + assert_equal(a.reshape(3, 2, 1, 1, order='F').strides, strides_f) + assert_equal(np.array(0, dtype=np.int32).reshape(1,1).strides, (4,4)) + def test_repeat_discont(self, level=rlevel): """Ticket #352""" a = np.arange(12).reshape(4,3)[:,2] |