summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2017-07-17 20:37:32 +0300
committermattip <matti.picus@gmail.com>2017-07-17 20:37:32 +0300
commitf831bec2ed0fbc4ac7c5a15e3abf28fcc879b5ff (patch)
treef2b784bd9a374bee3383c19afad8c4e6ee23f99b /numpy
parentc0609e33884d4dbbbd539966a716e90514a6842b (diff)
downloadnumpy-f831bec2ed0fbc4ac7c5a15e3abf28fcc879b5ff.tar.gz
BUG: Check for exception in sort functions, add tests
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src16
-rw-r--r--numpy/core/src/multiarray/item_selection.c6
-rw-r--r--numpy/core/tests/test_multiarray.py17
3 files changed, 31 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index 336ec70c7..86cf435d9 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -2748,6 +2748,15 @@ OBJECT_compare(PyObject **ip1, PyObject **ip2, PyArrayObject *NPY_UNUSED(ap))
* the alignment of pointers, so it doesn't need to be handled
* here.
*/
+
+ int ret;
+ /*
+ * work around gh-3879, we cannot abort an in-progress quicksort
+ * so at least do not raise again
+ */
+ if (PyErr_Occurred()) {
+ return 0;
+ }
if ((*ip1 == NULL) || (*ip2 == NULL)) {
if (ip1 == ip2) {
return 1;
@@ -2758,7 +2767,12 @@ OBJECT_compare(PyObject **ip1, PyObject **ip2, PyArrayObject *NPY_UNUSED(ap))
return 1;
}
- if (PyObject_RichCompareBool(*ip1, *ip2, Py_LT) == 1) {
+ ret = PyObject_RichCompareBool(*ip1, *ip2, Py_LT);
+ if (ret < 0) {
+ /* error occurred, avoid the next call to PyObject_RichCompareBool */
+ return 0;
+ }
+ if (ret == 1) {
return -1;
}
else if (PyObject_RichCompareBool(*ip1, *ip2, Py_GT) == 1) {
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index e60145508..9a6ed4d2a 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -869,12 +869,9 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
if (part == NULL) {
ret = sort(bufptr, N, op);
-#if defined(NPY_PY3K)
- /* Object comparisons may raise an exception in Python 3 */
if (hasrefs && PyErr_Occurred()) {
ret = -1;
}
-#endif
if (ret < 0) {
goto fail;
}
@@ -885,12 +882,9 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
npy_intp i;
for (i = 0; i < nkth; ++i) {
ret = part(bufptr, N, kth[i], pivots, &npiv, op);
-#if defined(NPY_PY3K)
- /* Object comparisons may raise an exception in Python 3 */
if (hasrefs && PyErr_Occurred()) {
ret = -1;
}
-#endif
if (ret < 0) {
goto fail;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 897450e7f..c182a4d59 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -36,7 +36,7 @@ from numpy.testing import (
)
# Need to test an object that does not fully implement math interface
-from datetime import timedelta
+from datetime import timedelta, datetime
if sys.version_info[:2] > (3, 2):
@@ -1476,6 +1476,21 @@ class TestMethods(TestCase):
arr = np.empty(1000, dt)
arr[::-1].sort()
+ def test_sort_raises(self):
+ #gh-9404
+ arr = np.array([0, datetime.now(), 1], dtype=object)
+ for kind in ['q', 'm', 'h']:
+ assert_raises(TypeError, arr.sort, kind=kind)
+ #gh-3879
+ class Raiser(object):
+ def raises_anything(*args, **kwargs):
+ raise TypeError("SOMETHING ERRORED")
+ __eq__ = __ne__ = __lt__ = __gt__ = __ge__ = __le__ = raises_anything
+ arr = np.array([[Raiser(), n] for n in range(10)]).reshape(-1)
+ np.random.shuffle(arr)
+ for kind in ['q', 'm', 'h']:
+ assert_raises(TypeError, arr.sort, kind=kind)
+
def test_sort_degraded(self):
# test degraded dataset would take minutes to run with normal qsort
d = np.arange(1000000)