diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-02-24 18:00:13 -0800 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-02-24 18:00:13 -0800 |
commit | 6de7a4be70c894e3d63ac952bd20a74c171e6413 (patch) | |
tree | 9dd54e08e31d80437eebdc27836673735c8fcf43 | |
parent | 17774a6d58b889b6b7a25d6af5b66f2148d47f41 (diff) | |
parent | 230ee3aa201552a8a9fa13c4b319f68cbd504d85 (diff) | |
download | numpy-6de7a4be70c894e3d63ac952bd20a74c171e6413.tar.gz |
Merge pull request #3002 from seberg/issue-3001
BUG: Incref items in np.take on error as they are decrefed later
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 65 | ||||
-rw-r--r-- | numpy/core/tests/test_indexerrors.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_item_selection.py | 92 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 11 |
4 files changed, 126 insertions, 46 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 7c8041cf1..4adecf193 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -31,10 +31,11 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, PyArray_Descr *dtype; PyArray_FastTakeFunc *func; PyArrayObject *obj = NULL, *self, *indices; - npy_intp nd, i, j, n, m, max_item, tmp, chunk, nelem; + npy_intp nd, i, j, n, m, k, max_item, tmp, chunk, itemsize, nelem; npy_intp shape[NPY_MAXDIMS]; - char *src, *dest; + char *src, *dest, *tmp_src; int err; + npy_bool needs_refcounting; indices = NULL; self = (PyArrayObject *)PyArray_CheckAxis(self0, &axis, @@ -110,9 +111,18 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, max_item = PyArray_DIMS(self)[axis]; nelem = chunk; - chunk = chunk * PyArray_DESCR(obj)->elsize; + itemsize = PyArray_ITEMSIZE(obj); + chunk = chunk * itemsize; src = PyArray_DATA(self); dest = PyArray_DATA(obj); + needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(self)); + + if ((max_item == 0) && (PyArray_SIZE(obj) != 0)) { + /* Index error, since that is the usual error for raise mode */ + PyErr_SetString(PyExc_IndexError, + "cannot do a non-empty take from an empty axes."); + goto fail; + } func = PyArray_DESCR(self)->f->fasttake; if (func == NULL) { @@ -124,8 +134,20 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, if (check_and_adjust_index(&tmp, max_item, axis) < 0) { goto fail; } - memmove(dest, src + tmp*chunk, chunk); - dest += chunk; + tmp_src = src + tmp * chunk; + if (needs_refcounting) { + for (k=0; k < nelem; k++) { + PyArray_Item_INCREF(tmp_src, PyArray_DESCR(self)); + PyArray_Item_XDECREF(dest, PyArray_DESCR(self)); + memmove(dest, tmp_src, itemsize); + dest += itemsize; + tmp_src += itemsize; + } + } + else { + memmove(dest, tmp_src, chunk); + dest += chunk; + } } src += chunk*max_item; } @@ -144,8 +166,20 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, tmp -= max_item; } } - memmove(dest, src + tmp*chunk, chunk); - dest += chunk; + tmp_src = src + tmp * chunk; + if (needs_refcounting) { + for (k=0; k < nelem; k++) { + PyArray_Item_INCREF(tmp_src, PyArray_DESCR(self)); + PyArray_Item_XDECREF(dest, PyArray_DESCR(self)); + memmove(dest, tmp_src, itemsize); + dest += itemsize; + tmp_src += itemsize; + } + } + else { + memmove(dest, tmp_src, chunk); + dest += chunk; + } } src += chunk*max_item; } @@ -160,8 +194,20 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, else if (tmp >= max_item) { tmp = max_item - 1; } - memmove(dest, src+tmp*chunk, chunk); - dest += chunk; + tmp_src = src + tmp * chunk; + if (needs_refcounting) { + for (k=0; k < nelem; k++) { + PyArray_Item_INCREF(tmp_src, PyArray_DESCR(self)); + PyArray_Item_XDECREF(dest, PyArray_DESCR(self)); + memmove(dest, tmp_src, itemsize); + dest += itemsize; + tmp_src += itemsize; + } + } + else { + memmove(dest, tmp_src, chunk); + dest += chunk; + } } src += chunk*max_item; } @@ -176,7 +222,6 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, } } - PyArray_INCREF(obj); Py_XDECREF(indices); Py_XDECREF(self); if (out != NULL && out != obj) { diff --git a/numpy/core/tests/test_indexerrors.py b/numpy/core/tests/test_indexerrors.py index af10df9b3..2f9c5d4d2 100644 --- a/numpy/core/tests/test_indexerrors.py +++ b/numpy/core/tests/test_indexerrors.py @@ -10,6 +10,8 @@ class TestIndexErrors(TestCase): x = np.empty((2, 3, 0, 4)) assert_raises(IndexError, x.take, [0], axis=2) assert_raises(IndexError, x.take, [1], axis=2) + assert_raises(IndexError, x.take, [0], axis=2, mode='wrap') + assert_raises(IndexError, x.take, [0], axis=2, mode='clip') def test_take_from_object(self): # Check exception taking from object array @@ -21,6 +23,8 @@ class TestIndexErrors(TestCase): assert_raises(IndexError, d.take, [1], axis=1) assert_raises(IndexError, d.take, [0], axis=1) assert_raises(IndexError, d.take, [0]) + assert_raises(IndexError, d.take, [0], mode='wrap') + assert_raises(IndexError, d.take, [0], mode='clip') def test_multiindex_exceptions(self): a = np.empty(5, dtype=object) diff --git a/numpy/core/tests/test_item_selection.py b/numpy/core/tests/test_item_selection.py index f35e04c4f..47de43012 100644 --- a/numpy/core/tests/test_item_selection.py +++ b/numpy/core/tests/test_item_selection.py @@ -2,42 +2,62 @@ import numpy as np from numpy.testing import * import sys, warnings -def test_take(): - a = [[1, 2], [3, 4]] - a_str = [['1','2'],['3','4']] - modes = ['raise', 'wrap', 'clip'] - indices = [-1, 4] - index_arrays = [np.empty(0, dtype=np.intp), - np.empty(tuple(), dtype=np.intp), - np.empty((1,1), dtype=np.intp)] - real_indices = {} - real_indices['raise'] = {-1:1, 4:IndexError} - real_indices['wrap'] = {-1:1, 4:0} - real_indices['clip'] = {-1:0, 4:1} - # Currently all types but object, use the same function generation. - # So it should not be necessary to test all, but the code does support it. - types = np.int, np.object - for t in types: - ta = np.array(a if issubclass(t, np.number) else a_str, dtype=t) - tresult = list(ta.T.copy()) - for index_array in index_arrays: - if index_array.size != 0: - tresult[0].shape = (2,) + index_array.shape - tresult[1].shape = (2,) + index_array.shape - for mode in modes: - for index in indices: - real_index = real_indices[mode][index] - if real_index is IndexError and index_array.size != 0: - index_array.put(0, index) - assert_raises(IndexError, ta.take, index_array, - mode=mode, axis=1) - elif index_array.size != 0: - index_array.put(0, index) - res = ta.take(index_array, mode=mode, axis=1) - assert_array_equal(res, tresult[real_index]) - else: - res = ta.take(index_array, mode=mode, axis=1) - assert_(res.shape == (2,) + index_array.shape) + +class TestTake(TestCase): + def test_simple(self): + a = [[1, 2], [3, 4]] + a_str = [[b'1', b'2'],[b'3', b'4']] + modes = ['raise', 'wrap', 'clip'] + indices = [-1, 4] + index_arrays = [np.empty(0, dtype=np.intp), + np.empty(tuple(), dtype=np.intp), + np.empty((1,1), dtype=np.intp)] + real_indices = {} + real_indices['raise'] = {-1:1, 4:IndexError} + real_indices['wrap'] = {-1:1, 4:0} + real_indices['clip'] = {-1:0, 4:1} + # Currently all types but object, use the same function generation. + # So it should not be necessary to test all. However test also a non + # refcounted struct on top of object. + types = np.int, np.object, np.dtype([('', 'i', 2)]) + for t in types: + # ta works, even if the array may be odd if buffer interface is used + ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t) + tresult = list(ta.T.copy()) + for index_array in index_arrays: + if index_array.size != 0: + tresult[0].shape = (2,) + index_array.shape + tresult[1].shape = (2,) + index_array.shape + for mode in modes: + for index in indices: + real_index = real_indices[mode][index] + if real_index is IndexError and index_array.size != 0: + index_array.put(0, index) + assert_raises(IndexError, ta.take, index_array, + mode=mode, axis=1) + elif index_array.size != 0: + index_array.put(0, index) + res = ta.take(index_array, mode=mode, axis=1) + assert_array_equal(res, tresult[real_index]) + else: + res = ta.take(index_array, mode=mode, axis=1) + assert_(res.shape == (2,) + index_array.shape) + + + def test_refcounting(self): + objects = [object() for i in xrange(10)] + for mode in ('raise', 'clip', 'wrap'): + a = np.array(objects) + b = np.array([2, 2, 4, 5, 3, 5]) + a.take(b, out=a[:6]) + del a + assert_(all(sys.getrefcount(o) == 3 for o in objects)) + # not contiguous, example: + a = np.array(objects * 2)[::2] + a.take(b, out=a[:6]) + del a + assert_(all(sys.getrefcount(o) == 3 for o in objects)) + if __name__ == "__main__": run_module_suite() diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 9a193b3c1..58d3f2819 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -626,6 +626,17 @@ class TestRegression(TestCase): np.take(x,[0,2],axis=1,out=b) assert_array_equal(a,b) + def test_take_object_fail(self): + # Issue gh-3001 + d = 123. + a = np.array([d, 1], dtype=object) + ref_d = sys.getrefcount(d) + try: + a.take([0, 100]) + except IndexError: + pass + assert_(ref_d == sys.getrefcount(d)) + def test_array_str_64bit(self, level=rlevel): """Ticket #501""" s = np.array([1, np.nan],dtype=np.float64) |