diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-10-02 12:12:41 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-10-02 12:12:41 -0600 |
commit | df0afda4c69e9e1fd47afcb7d79236bc101c502f (patch) | |
tree | 2ee08b902072418502119cbd6a0f5c9f21bddc0f /numpy | |
parent | 21367df4ba6dbf4a01e5e4634b2e20ddb1f4c401 (diff) | |
parent | 0862e89fb51b2e6fc2dfe74e6166a218b67ff06d (diff) | |
download | numpy-df0afda4c69e9e1fd47afcb7d79236bc101c502f.tar.gz |
Merge pull request #6312 from behzadnouri/object-lexsort
adds lexsort for arrays with object dtype
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 16 |
2 files changed, 42 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index d3b9a036d..ec0717bd6 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1427,9 +1427,10 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
                 goto fail;
             }
         }
-        if (!PyArray_DESCR(mps[i])->f->argsort[NPY_MERGESORT]) {
+        if (!PyArray_DESCR(mps[i])->f->argsort[NPY_MERGESORT]
+                && !PyArray_DESCR(mps[i])->f->compare) {
             PyErr_Format(PyExc_TypeError,
-                         "merge sort not available for item %zd", i);
+                         "item %zd type does not have compare function", i);
             goto fail;
         }
         if (!object
@@ -1520,15 +1521,25 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
                 *iptr++ = i;
             }
             for (j = 0; j < n; j++) {
+                int rcode;
                 elsize = PyArray_DESCR(mps[j])->elsize;
                 astride = PyArray_STRIDES(mps[j])[axis];
                 argsort = PyArray_DESCR(mps[j])->f->argsort[NPY_MERGESORT];
+                if(argsort == NULL) {
+                    argsort = npy_amergesort;
+                }
                 _unaligned_strided_byte_copy(valbuffer, (npy_intp) elsize,
                                              its[j]->dataptr, astride, N, elsize);
                 if (swaps[j]) {
                     _strided_byte_swap(valbuffer, (npy_intp) elsize, N, elsize);
                 }
-                if (argsort(valbuffer, (npy_intp *)indbuffer, N, mps[j]) < 0) {
+                rcode = argsort(valbuffer, (npy_intp *)indbuffer, N, mps[j]);
+#if defined(NPY_PY3K)
+                if (rcode < 0 || (PyDataType_REFCHK(PyArray_DESCR(mps[j]))
+                            && PyErr_Occurred())) {
+#else
+                if (rcode < 0) {
+#endif
                     PyDataMem_FREE(valbuffer);
                     PyDataMem_FREE(indbuffer);
                     free(swaps);
@@ -1551,9 +1562,19 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
                 *iptr++ = i;
             }
             for (j = 0; j < n; j++) {
+                int rcode;
                 argsort = PyArray_DESCR(mps[j])->f->argsort[NPY_MERGESORT];
-                if (argsort(its[j]->dataptr, (npy_intp *)rit->dataptr,
-                            N, mps[j]) < 0) {
+                if(argsort == NULL) {
+                    argsort = npy_amergesort;
+                }
+                rcode = argsort(its[j]->dataptr,
+                        (npy_intp *)rit->dataptr, N, mps[j]);
+#if defined(NPY_PY3K)
+                if (rcode < 0 || (PyDataType_REFCHK(PyArray_DESCR(mps[j]))
+                            && PyErr_Occurred())) {
+#else
+                if (rcode < 0) {
+#endif
                     goto fail;
                 }
                 PyArray_ITER_NEXT(its[j]);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 872f9bde4..9fd08e023 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -3207,6 +3207,22 @@ class TestLexsort(TestCase):
         expected_idx = np.array([2, 1, 0])
         assert_array_equal(idx, expected_idx)
 
+    def test_object(self):  # gh-6312
+        a = np.random.choice(10, 1000)
+        b = np.random.choice(['abc', 'xy', 'wz', 'efghi', 'qwst', 'x'], 1000)
+
+        for u in a, b:
+            left = np.lexsort((u.astype('O'),))
+            right = np.argsort(u, kind='mergesort')
+            assert_array_equal(left, right)
+
+        for u, v in (a, b), (b, a):
+            idx = np.lexsort((u, v))
+            assert_array_equal(idx, np.lexsort((u.astype('O'), v)))
+            assert_array_equal(idx, np.lexsort((u, v.astype('O'))))
+            u, v = np.array(u, dtype='object'), np.array(v, dtype='object')
+            assert_array_equal(idx, np.lexsort((u, v)))
+
 
 class TestIO(object):
     """Test tofile, fromfile, tobytes, and fromstring"""