summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-02-24 18:00:13 -0800
committerCharles Harris <charlesr.harris@gmail.com>2013-02-24 18:00:13 -0800
commit6de7a4be70c894e3d63ac952bd20a74c171e6413 (patch)
tree9dd54e08e31d80437eebdc27836673735c8fcf43
parent17774a6d58b889b6b7a25d6af5b66f2148d47f41 (diff)
parent230ee3aa201552a8a9fa13c4b319f68cbd504d85 (diff)
downloadnumpy-6de7a4be70c894e3d63ac952bd20a74c171e6413.tar.gz
Merge pull request #3002 from seberg/issue-3001
BUG: Incref items in np.take on error as they are decrefed later
-rw-r--r--numpy/core/src/multiarray/item_selection.c65
-rw-r--r--numpy/core/tests/test_indexerrors.py4
-rw-r--r--numpy/core/tests/test_item_selection.py92
-rw-r--r--numpy/core/tests/test_regression.py11
4 files changed, 126 insertions, 46 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 7c8041cf1..4adecf193 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -31,10 +31,11 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
PyArray_Descr *dtype;
PyArray_FastTakeFunc *func;
PyArrayObject *obj = NULL, *self, *indices;
- npy_intp nd, i, j, n, m, max_item, tmp, chunk, nelem;
+ npy_intp nd, i, j, n, m, k, max_item, tmp, chunk, itemsize, nelem;
npy_intp shape[NPY_MAXDIMS];
- char *src, *dest;
+ char *src, *dest, *tmp_src;
int err;
+ npy_bool needs_refcounting;
indices = NULL;
self = (PyArrayObject *)PyArray_CheckAxis(self0, &axis,
@@ -110,9 +111,18 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
max_item = PyArray_DIMS(self)[axis];
nelem = chunk;
- chunk = chunk * PyArray_DESCR(obj)->elsize;
+ itemsize = PyArray_ITEMSIZE(obj);
+ chunk = chunk * itemsize;
src = PyArray_DATA(self);
dest = PyArray_DATA(obj);
+ needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(self));
+
+ if ((max_item == 0) && (PyArray_SIZE(obj) != 0)) {
+ /* Index error, since that is the usual error for raise mode */
+ PyErr_SetString(PyExc_IndexError,
+ "cannot do a non-empty take from an empty axes.");
+ goto fail;
+ }
func = PyArray_DESCR(self)->f->fasttake;
if (func == NULL) {
@@ -124,8 +134,20 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
if (check_and_adjust_index(&tmp, max_item, axis) < 0) {
goto fail;
}
- memmove(dest, src + tmp*chunk, chunk);
- dest += chunk;
+ tmp_src = src + tmp * chunk;
+ if (needs_refcounting) {
+ for (k=0; k < nelem; k++) {
+ PyArray_Item_INCREF(tmp_src, PyArray_DESCR(self));
+ PyArray_Item_XDECREF(dest, PyArray_DESCR(self));
+ memmove(dest, tmp_src, itemsize);
+ dest += itemsize;
+ tmp_src += itemsize;
+ }
+ }
+ else {
+ memmove(dest, tmp_src, chunk);
+ dest += chunk;
+ }
}
src += chunk*max_item;
}
@@ -144,8 +166,20 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
tmp -= max_item;
}
}
- memmove(dest, src + tmp*chunk, chunk);
- dest += chunk;
+ tmp_src = src + tmp * chunk;
+ if (needs_refcounting) {
+ for (k=0; k < nelem; k++) {
+ PyArray_Item_INCREF(tmp_src, PyArray_DESCR(self));
+ PyArray_Item_XDECREF(dest, PyArray_DESCR(self));
+ memmove(dest, tmp_src, itemsize);
+ dest += itemsize;
+ tmp_src += itemsize;
+ }
+ }
+ else {
+ memmove(dest, tmp_src, chunk);
+ dest += chunk;
+ }
}
src += chunk*max_item;
}
@@ -160,8 +194,20 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
else if (tmp >= max_item) {
tmp = max_item - 1;
}
- memmove(dest, src+tmp*chunk, chunk);
- dest += chunk;
+ tmp_src = src + tmp * chunk;
+ if (needs_refcounting) {
+ for (k=0; k < nelem; k++) {
+ PyArray_Item_INCREF(tmp_src, PyArray_DESCR(self));
+ PyArray_Item_XDECREF(dest, PyArray_DESCR(self));
+ memmove(dest, tmp_src, itemsize);
+ dest += itemsize;
+ tmp_src += itemsize;
+ }
+ }
+ else {
+ memmove(dest, tmp_src, chunk);
+ dest += chunk;
+ }
}
src += chunk*max_item;
}
@@ -176,7 +222,6 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
}
}
- PyArray_INCREF(obj);
Py_XDECREF(indices);
Py_XDECREF(self);
if (out != NULL && out != obj) {
diff --git a/numpy/core/tests/test_indexerrors.py b/numpy/core/tests/test_indexerrors.py
index af10df9b3..2f9c5d4d2 100644
--- a/numpy/core/tests/test_indexerrors.py
+++ b/numpy/core/tests/test_indexerrors.py
@@ -10,6 +10,8 @@ class TestIndexErrors(TestCase):
x = np.empty((2, 3, 0, 4))
assert_raises(IndexError, x.take, [0], axis=2)
assert_raises(IndexError, x.take, [1], axis=2)
+ assert_raises(IndexError, x.take, [0], axis=2, mode='wrap')
+ assert_raises(IndexError, x.take, [0], axis=2, mode='clip')
def test_take_from_object(self):
# Check exception taking from object array
@@ -21,6 +23,8 @@ class TestIndexErrors(TestCase):
assert_raises(IndexError, d.take, [1], axis=1)
assert_raises(IndexError, d.take, [0], axis=1)
assert_raises(IndexError, d.take, [0])
+ assert_raises(IndexError, d.take, [0], mode='wrap')
+ assert_raises(IndexError, d.take, [0], mode='clip')
def test_multiindex_exceptions(self):
a = np.empty(5, dtype=object)
diff --git a/numpy/core/tests/test_item_selection.py b/numpy/core/tests/test_item_selection.py
index f35e04c4f..47de43012 100644
--- a/numpy/core/tests/test_item_selection.py
+++ b/numpy/core/tests/test_item_selection.py
@@ -2,42 +2,62 @@ import numpy as np
from numpy.testing import *
import sys, warnings
-def test_take():
- a = [[1, 2], [3, 4]]
- a_str = [['1','2'],['3','4']]
- modes = ['raise', 'wrap', 'clip']
- indices = [-1, 4]
- index_arrays = [np.empty(0, dtype=np.intp),
- np.empty(tuple(), dtype=np.intp),
- np.empty((1,1), dtype=np.intp)]
- real_indices = {}
- real_indices['raise'] = {-1:1, 4:IndexError}
- real_indices['wrap'] = {-1:1, 4:0}
- real_indices['clip'] = {-1:0, 4:1}
- # Currently all types but object, use the same function generation.
- # So it should not be necessary to test all, but the code does support it.
- types = np.int, np.object
- for t in types:
- ta = np.array(a if issubclass(t, np.number) else a_str, dtype=t)
- tresult = list(ta.T.copy())
- for index_array in index_arrays:
- if index_array.size != 0:
- tresult[0].shape = (2,) + index_array.shape
- tresult[1].shape = (2,) + index_array.shape
- for mode in modes:
- for index in indices:
- real_index = real_indices[mode][index]
- if real_index is IndexError and index_array.size != 0:
- index_array.put(0, index)
- assert_raises(IndexError, ta.take, index_array,
- mode=mode, axis=1)
- elif index_array.size != 0:
- index_array.put(0, index)
- res = ta.take(index_array, mode=mode, axis=1)
- assert_array_equal(res, tresult[real_index])
- else:
- res = ta.take(index_array, mode=mode, axis=1)
- assert_(res.shape == (2,) + index_array.shape)
+
+class TestTake(TestCase):
+ def test_simple(self):
+ a = [[1, 2], [3, 4]]
+ a_str = [[b'1', b'2'],[b'3', b'4']]
+ modes = ['raise', 'wrap', 'clip']
+ indices = [-1, 4]
+ index_arrays = [np.empty(0, dtype=np.intp),
+ np.empty(tuple(), dtype=np.intp),
+ np.empty((1,1), dtype=np.intp)]
+ real_indices = {}
+ real_indices['raise'] = {-1:1, 4:IndexError}
+ real_indices['wrap'] = {-1:1, 4:0}
+ real_indices['clip'] = {-1:0, 4:1}
+ # Currently all types but object, use the same function generation.
+ # So it should not be necessary to test all. However test also a non
+ # refcounted struct on top of object.
+ types = np.int, np.object, np.dtype([('', 'i', 2)])
+ for t in types:
+ # ta works, even if the array may be odd if buffer interface is used
+ ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t)
+ tresult = list(ta.T.copy())
+ for index_array in index_arrays:
+ if index_array.size != 0:
+ tresult[0].shape = (2,) + index_array.shape
+ tresult[1].shape = (2,) + index_array.shape
+ for mode in modes:
+ for index in indices:
+ real_index = real_indices[mode][index]
+ if real_index is IndexError and index_array.size != 0:
+ index_array.put(0, index)
+ assert_raises(IndexError, ta.take, index_array,
+ mode=mode, axis=1)
+ elif index_array.size != 0:
+ index_array.put(0, index)
+ res = ta.take(index_array, mode=mode, axis=1)
+ assert_array_equal(res, tresult[real_index])
+ else:
+ res = ta.take(index_array, mode=mode, axis=1)
+ assert_(res.shape == (2,) + index_array.shape)
+
+
+ def test_refcounting(self):
+ objects = [object() for i in xrange(10)]
+ for mode in ('raise', 'clip', 'wrap'):
+ a = np.array(objects)
+ b = np.array([2, 2, 4, 5, 3, 5])
+ a.take(b, out=a[:6])
+ del a
+ assert_(all(sys.getrefcount(o) == 3 for o in objects))
+ # not contiguous, example:
+ a = np.array(objects * 2)[::2]
+ a.take(b, out=a[:6])
+ del a
+ assert_(all(sys.getrefcount(o) == 3 for o in objects))
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 9a193b3c1..58d3f2819 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -626,6 +626,17 @@ class TestRegression(TestCase):
np.take(x,[0,2],axis=1,out=b)
assert_array_equal(a,b)
+ def test_take_object_fail(self):
+ # Issue gh-3001
+ d = 123.
+ a = np.array([d, 1], dtype=object)
+ ref_d = sys.getrefcount(d)
+ try:
+ a.take([0, 100])
+ except IndexError:
+ pass
+ assert_(ref_d == sys.getrefcount(d))
+
def test_array_str_64bit(self, level=rlevel):
"""Ticket #501"""
s = np.array([1, np.nan],dtype=np.float64)