summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2013-01-24 22:54:47 +0100
committerSebastian Berg <sebastian@sipsolutions.net>2013-01-25 13:16:21 +0100
commit49bf2045e1cc7d1fb7e2ce771fbb1636c9d28e5a (patch)
tree2fbfd16f0f0d72c053ecf272058c6594db5db980 /numpy
parent4600b2fe1d7ebafef98858572b603f2a8d19db4d (diff)
downloadnumpy-49bf2045e1cc7d1fb7e2ce771fbb1636c9d28e5a.tar.gz
BUG: Fix strides of trailing 1s when reshaping
When adding ones to the shape for non contiguous arrays reshapes (i.e. either the first array is not contiguous or the the reshape order does not match it). The strides of trailing ones were not set. For example reshape of (6,) to (6,1,1). Previously this occured rarely because of removed special handleing when only ones were added or removed.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/shape.c31
-rw-r--r--numpy/core/tests/test_regression.py9
2 files changed, 37 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c
index 97ddb2e61..4223e49f6 100644
--- a/numpy/core/src/multiarray/shape.c
+++ b/numpy/core/src/multiarray/shape.c
@@ -355,10 +355,14 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims,
int oldnd;
npy_intp olddims[NPY_MAXDIMS];
npy_intp oldstrides[NPY_MAXDIMS];
- npy_intp np, op;
+ npy_intp np, op, last_stride;
int oi, oj, ok, ni, nj, nk;
oldnd = 0;
+ /*
+ * Remove axes with dimension 1 from the old array. They have no effect
+ * but would need special cases since their strides do not matter.
+ */
for (oi = 0; oi < PyArray_NDIM(self); oi++) {
if (PyArray_DIMS(self)[oi]!= 1) {
olddims[oldnd] = PyArray_DIMS(self)[oi];
@@ -390,27 +394,31 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims,
/* different total sizes; no hope */
return 0;
}
- /* the current code does not handle 0-sized arrays, so give up */
+
if (np == 0) {
+ /* the current code does not handle 0-sized arrays, so give up */
return 0;
}
+ /* oi to oj and ni to nj give the axis ranges currently worked with */
oi = 0;
oj = 1;
ni = 0;
nj = 1;
- while(ni < newnd && oi < oldnd) {
+ while (ni < newnd && oi < oldnd) {
np = newdims[ni];
op = olddims[oi];
while (np != op) {
if (np < op) {
+ /* Misses trailing 1s, these are handled later */
np *= newdims[nj++];
} else {
op *= olddims[oj++];
}
}
+ /* Check whether the original axes can be combined */
for (ok = oi; ok < oj - 1; ok++) {
if (is_f_order) {
if (oldstrides[ok+1] != olddims[ok]*oldstrides[ok]) {
@@ -427,6 +435,7 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims,
}
}
+ /* Calculate new strides for all axes currently worked with */
if (is_f_order) {
newstrides[ni] = oldstrides[oi];
for (nk = ni + 1; nk < nj; nk++) {
@@ -445,6 +454,22 @@ _attempt_nocopy_reshape(PyArrayObject *self, int newnd, npy_intp* newdims,
}
/*
+ * Set strides corresponding to trailing 1s of the new shape.
+ */
+ if (ni >= 1) {
+ last_stride = newstrides[ni - 1];
+ }
+ else {
+ last_stride = PyArray_ITEMSIZE(self);
+ }
+ if (is_f_order) {
+ last_stride *= newdims[ni - 1];
+ }
+ for (nk = ni; nk < newnd; nk++) {
+ newstrides[nk] = last_stride;
+ }
+
+ /*
fprintf(stderr, "success: _attempt_nocopy_reshape (");
for (oi=0; oi<oldnd; oi++)
fprintf(stderr, "(%d,%d), ", olddims[oi], oldstrides[oi]);
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 1cf2e6e85..1f71f480e 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -545,6 +545,15 @@ class TestRegression(TestCase):
a = np.ones((0,2))
a.shape = (-1,2)
+ def test_reshape_trailing_ones_strides(self):
+ # Github issue gh-2949, bad strides for trailing ones of new shape
+ a = np.zeros(12, dtype=np.int32)[::2] # not contiguous
+ strides_c = (16, 8, 8, 8)
+ strides_f = (8, 24, 48, 48)
+ assert_equal(a.reshape(3, 2, 1, 1).strides, strides_c)
+ assert_equal(a.reshape(3, 2, 1, 1, order='F').strides, strides_f)
+ assert_equal(np.array(0, dtype=np.int32).reshape(1,1).strides, (4,4))
+
def test_repeat_discont(self, level=rlevel):
"""Ticket #352"""
a = np.arange(12).reshape(4,3)[:,2]