summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2012-04-04 17:10:05 -0600
committerCharles Harris <charlesr.harris@gmail.com>2012-04-05 09:18:54 -0600
commit11859c0a2f113e66fdd8384eea93b688a8e08b22 (patch)
treeb1d89e9a887440e3e541a5381a41299772eb3081 /numpy/core/src
parentb35eaccfae6ec00310e7203a310e3de7b6b8018d (diff)
downloadnumpy-11859c0a2f113e66fdd8384eea93b688a8e08b22.tar.gz
BUG: ticket #2097, fix bounds checking in searchsorted when sorter invoked.
The bounds are checked on the fly. This won't always raise an error if there is an out of bounds index, but only when that index is used in the binary search. Since we don't check that the sorter actually sorts the array, this seems reasonable. Only safety is ensured, not correctness.
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/multiarray/item_selection.c47
1 files changed, 27 insertions, 20 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index c7eb5f87b..f9b5ff66f 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1590,9 +1590,9 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
* @param arr contiguous sorted array to be searched.
* @param key contiguous array of keys.
* @param ret contiguous array of intp for returned indices.
- * @return void
+ * @return int
*/
-static void
+static int
local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
PyArrayObject *sorter, PyArrayObject *ret)
{
@@ -1611,7 +1611,12 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
npy_intp imax = nelts;
while (imin < imax) {
npy_intp imid = imin + ((imax - imin) >> 1);
- if (compare(parr + elsize*psorter[imid], pkey, key) < 0) {
+ npy_intp indx = psorter[imid];
+
+ if (indx < 0 || indx >= nelts) {
+ return -1;
+ }
+ if (compare(parr + elsize*indx, pkey, key) < 0) {
imin = imid + 1;
}
else {
@@ -1622,6 +1627,7 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
pret += 1;
pkey += elsize;
}
+ return 0;
}
@@ -1635,9 +1641,9 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
* @param arr contiguous sorted array to be searched.
* @param key contiguous array of keys.
* @param ret contiguous array of intp for returned indices.
- * @return void
+ * @return int
*/
-static void
+static int
local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
PyArrayObject *sorter, PyArrayObject *ret)
{
@@ -1656,7 +1662,12 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
npy_intp imax = nelts;
while (imin < imax) {
npy_intp imid = imin + ((imax - imin) >> 1);
- if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) {
+ npy_intp indx = psorter[imid];
+
+ if (indx < 0 || indx >= nelts) {
+ return -1;
+ }
+ if (compare(parr + elsize*indx, pkey, key) <= 0) {
imin = imid + 1;
}
else {
@@ -1667,6 +1678,7 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
pret += 1;
pkey += elsize;
}
+ return 0;
}
/*NUMPY_API
@@ -1774,18 +1786,6 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
"sorter.size must equal a.size");
goto fail;
}
- max = PyArray_Max(sorter, 0, NULL);
- if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) {
- PyErr_SetString(PyExc_ValueError,
- "sorter.max() must be less than a.size");
- goto fail;
- }
- min = PyArray_Min(sorter, 0, NULL);
- if (PyLong_AsLong(min) < 0) {
- PyErr_SetString(PyExc_ValueError,
- "sorter elements must be non-negative");
- goto fail;
- }
}
/* ret is a contiguous array of intp type to hold returned indices */
@@ -1809,16 +1809,23 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
}
}
else {
+ int err;
+
if (side == NPY_SEARCHLEFT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
- local_argsearch_left(ap1, ap2, sorter, ret);
+ err = local_argsearch_left(ap1, ap2, sorter, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
else if (side == NPY_SEARCHRIGHT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
- local_argsearch_right(ap1, ap2, sorter, ret);
+ err = local_argsearch_right(ap1, ap2, sorter, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
+ if (err < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "Sorter index out of range.");
+ goto fail;
+ }
Py_DECREF(ap3);
Py_DECREF(sorter);
}