diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2012-04-04 17:10:05 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2012-04-05 09:18:54 -0600 |
commit | 11859c0a2f113e66fdd8384eea93b688a8e08b22 (patch) | |
tree | b1d89e9a887440e3e541a5381a41299772eb3081 /numpy/core/src | |
parent | b35eaccfae6ec00310e7203a310e3de7b6b8018d (diff) | |
download | numpy-11859c0a2f113e66fdd8384eea93b688a8e08b22.tar.gz |
BUG: ticket #2097, fix bounds checking in searchsorted when sorter invoked.
The bounds are checked on the fly. This won't always raise an error if
there is an out of bounds index, but only when that index is used in the
binary search. Since we don't check that the sorter actually sorts the
array, this seems reasonable. Only safety is ensured, not correctness.
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 47 |
1 files changed, 27 insertions, 20 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index c7eb5f87b..f9b5ff66f 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1590,9 +1590,9 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) * @param arr contiguous sorted array to be searched. * @param key contiguous array of keys. * @param ret contiguous array of intp for returned indices. - * @return void + * @return int */ -static void +static int local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *sorter, PyArrayObject *ret) { @@ -1611,7 +1611,12 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, npy_intp imax = nelts; while (imin < imax) { npy_intp imid = imin + ((imax - imin) >> 1); - if (compare(parr + elsize*psorter[imid], pkey, key) < 0) { + npy_intp indx = psorter[imid]; + + if (indx < 0 || indx >= nelts) { + return -1; + } + if (compare(parr + elsize*indx, pkey, key) < 0) { imin = imid + 1; } else { @@ -1622,6 +1627,7 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, pret += 1; pkey += elsize; } + return 0; } @@ -1635,9 +1641,9 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, * @param arr contiguous sorted array to be searched. * @param key contiguous array of keys. * @param ret contiguous array of intp for returned indices. - * @return void + * @return int */ -static void +static int local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *sorter, PyArrayObject *ret) { @@ -1656,7 +1662,12 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, npy_intp imax = nelts; while (imin < imax) { npy_intp imid = imin + ((imax - imin) >> 1); - if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) { + npy_intp indx = psorter[imid]; + + if (indx < 0 || indx >= nelts) { + return -1; + } + if (compare(parr + elsize*indx, pkey, key) <= 0) { imin = imid + 1; } else { @@ -1667,6 +1678,7 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, pret += 1; pkey += elsize; } + return 0; } /*NUMPY_API @@ -1774,18 +1786,6 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, "sorter.size must equal a.size"); goto fail; } - max = PyArray_Max(sorter, 0, NULL); - if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) { - PyErr_SetString(PyExc_ValueError, - "sorter.max() must be less than a.size"); - goto fail; - } - min = PyArray_Min(sorter, 0, NULL); - if (PyLong_AsLong(min) < 0) { - PyErr_SetString(PyExc_ValueError, - "sorter elements must be non-negative"); - goto fail; - } } /* ret is a contiguous array of intp type to hold returned indices */ @@ -1809,16 +1809,23 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, } } else { + int err; + if (side == NPY_SEARCHLEFT) { NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_argsearch_left(ap1, ap2, sorter, ret); + err = local_argsearch_left(ap1, ap2, sorter, ret); NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); } else if (side == NPY_SEARCHRIGHT) { NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_argsearch_right(ap1, ap2, sorter, ret); + err = local_argsearch_right(ap1, ap2, sorter, ret); NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); } + if (err < 0) { + PyErr_SetString(PyExc_ValueError, + "Sorter index out of range."); + goto fail; + } Py_DECREF(ap3); Py_DECREF(sorter); } |