summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-10-02 12:12:41 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-10-02 12:12:41 -0600
commitdf0afda4c69e9e1fd47afcb7d79236bc101c502f (patch)
tree2ee08b902072418502119cbd6a0f5c9f21bddc0f /numpy
parent21367df4ba6dbf4a01e5e4634b2e20ddb1f4c401 (diff)
parent0862e89fb51b2e6fc2dfe74e6166a218b67ff06d (diff)
downloadnumpy-df0afda4c69e9e1fd47afcb7d79236bc101c502f.tar.gz
Merge pull request #6312 from behzadnouri/object-lexsort
adds lexsort for arrays with object dtype
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/item_selection.c31
-rw-r--r--numpy/core/tests/test_multiarray.py16
2 files changed, 42 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index d3b9a036d..ec0717bd6 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1427,9 +1427,10 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
goto fail;
}
}
- if (!PyArray_DESCR(mps[i])->f->argsort[NPY_MERGESORT]) {
+ if (!PyArray_DESCR(mps[i])->f->argsort[NPY_MERGESORT]
+ && !PyArray_DESCR(mps[i])->f->compare) {
PyErr_Format(PyExc_TypeError,
- "merge sort not available for item %zd", i);
+ "item %zd type does not have compare function", i);
goto fail;
}
if (!object
@@ -1520,15 +1521,25 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
*iptr++ = i;
}
for (j = 0; j < n; j++) {
+ int rcode;
elsize = PyArray_DESCR(mps[j])->elsize;
astride = PyArray_STRIDES(mps[j])[axis];
argsort = PyArray_DESCR(mps[j])->f->argsort[NPY_MERGESORT];
+ if(argsort == NULL) {
+ argsort = npy_amergesort;
+ }
_unaligned_strided_byte_copy(valbuffer, (npy_intp) elsize,
its[j]->dataptr, astride, N, elsize);
if (swaps[j]) {
_strided_byte_swap(valbuffer, (npy_intp) elsize, N, elsize);
}
- if (argsort(valbuffer, (npy_intp *)indbuffer, N, mps[j]) < 0) {
+ rcode = argsort(valbuffer, (npy_intp *)indbuffer, N, mps[j]);
+#if defined(NPY_PY3K)
+ if (rcode < 0 || (PyDataType_REFCHK(PyArray_DESCR(mps[j]))
+ && PyErr_Occurred())) {
+#else
+ if (rcode < 0) {
+#endif
PyDataMem_FREE(valbuffer);
PyDataMem_FREE(indbuffer);
free(swaps);
@@ -1551,9 +1562,19 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
*iptr++ = i;
}
for (j = 0; j < n; j++) {
+ int rcode;
argsort = PyArray_DESCR(mps[j])->f->argsort[NPY_MERGESORT];
- if (argsort(its[j]->dataptr, (npy_intp *)rit->dataptr,
- N, mps[j]) < 0) {
+ if(argsort == NULL) {
+ argsort = npy_amergesort;
+ }
+ rcode = argsort(its[j]->dataptr,
+ (npy_intp *)rit->dataptr, N, mps[j]);
+#if defined(NPY_PY3K)
+ if (rcode < 0 || (PyDataType_REFCHK(PyArray_DESCR(mps[j]))
+ && PyErr_Occurred())) {
+#else
+ if (rcode < 0) {
+#endif
goto fail;
}
PyArray_ITER_NEXT(its[j]);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 872f9bde4..9fd08e023 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -3207,6 +3207,22 @@ class TestLexsort(TestCase):
expected_idx = np.array([2, 1, 0])
assert_array_equal(idx, expected_idx)
+ def test_object(self): # gh-6312
+ a = np.random.choice(10, 1000)
+ b = np.random.choice(['abc', 'xy', 'wz', 'efghi', 'qwst', 'x'], 1000)
+
+ for u in a, b:
+ left = np.lexsort((u.astype('O'),))
+ right = np.argsort(u, kind='mergesort')
+ assert_array_equal(left, right)
+
+ for u, v in (a, b), (b, a):
+ idx = np.lexsort((u, v))
+ assert_array_equal(idx, np.lexsort((u.astype('O'), v)))
+ assert_array_equal(idx, np.lexsort((u, v.astype('O'))))
+ u, v = np.array(u, dtype='object'), np.array(v, dtype='object')
+ assert_array_equal(idx, np.lexsort((u, v)))
+
class TestIO(object):
"""Test tofile, fromfile, tobytes, and fromstring"""