diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/code_generators/array_api_order.txt | 1 | ||||
-rw-r--r-- | numpy/core/defchararray.py | 56 | ||||
-rw-r--r-- | numpy/core/src/_sortmodule.c.src | 17 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 181 | ||||
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 1 | ||||
-rw-r--r-- | numpy/core/src/ucsnarrow.c | 5 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 21 |
7 files changed, 225 insertions, 57 deletions
diff --git a/numpy/core/code_generators/array_api_order.txt b/numpy/core/code_generators/array_api_order.txt index f47ab8aa5..ac6341d6e 100644 --- a/numpy/core/code_generators/array_api_order.txt +++ b/numpy/core/code_generators/array_api_order.txt @@ -69,3 +69,4 @@ PyArray_ScalarKind PyArray_CanCoerceScalar PyArray_NewFlagsObject PyArray_CanCastScalar +PyArray_CompareUCS4 diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py index aeba32fc5..f79389967 100644 --- a/numpy/core/defchararray.py +++ b/numpy/core/defchararray.py @@ -41,34 +41,34 @@ class chararray(ndarray): raise ValueError, "Can only create a chararray from string data." - def _richcmpfunc(self, other, op): - b = broadcast(self, other) - result = empty(b.shape, dtype=bool) - res = result.flat - for k, val in enumerate(b): - r1 = val[0].rstrip('\x00') - r2 = val[1] - res[k] = eval("r1 %s r2" % op, {'r1':r1,'r2':r2}) - return result - - # these should probably be moved to C - def __eq__(self, other): - return self._richcmpfunc(other, '==') - - def __ne__(self, other): - return self._richcmpfunc(other, '!=') - - def __ge__(self, other): - return self._richcmpfunc(other, '>=') - - def __le__(self, other): - return self._richcmpfunc(other, '<=') - - def __gt__(self, other): - return self._richcmpfunc(other, '>') - - def __lt__(self, other): - return self._richcmpfunc(other, '<') +## def _richcmpfunc(self, other, op): +## b = broadcast(self, other) +## result = empty(b.shape, dtype=bool) +## res = result.flat +## for k, val in enumerate(b): +## r1 = val[0].rstrip('\x00') +## r2 = val[1] +## res[k] = eval("r1 %s r2" % op, {'r1':r1,'r2':r2}) +## return result + + # these have been moved to C +## def __eq__(self, other): +## return self._richcmpfunc(other, '==') + +## def __ne__(self, other): +## return self._richcmpfunc(other, '!=') + +## def __ge__(self, other): +## return self._richcmpfunc(other, '>=') + +## def __le__(self, other): +## return self._richcmpfunc(other, '<=') + +## def __gt__(self, other): +## return self._richcmpfunc(other, '>') + +## def __lt__(self, other): +## return self._richcmpfunc(other, '<') def __add__(self, other): b = broadcast(self, other) diff --git a/numpy/core/src/_sortmodule.c.src b/numpy/core/src/_sortmodule.c.src index 8f40e47ee..e152fe29f 100644 --- a/numpy/core/src/_sortmodule.c.src +++ b/numpy/core/src/_sortmodule.c.src @@ -355,23 +355,10 @@ static int } /**end repeat**/ -static int -unincmp(Py_UNICODE *s1, Py_UNICODE *s2, register int len) -{ - register Py_UNICODE c1, c2; - while(len-- > 0) { - c1 = *s1++; - c2 = *s2++; - if (c1 != c2) - return (c1 < c2) ? -1 : 1; - } - return 0; -} - /**begin repeat #TYPE=STRING,UNICODE# -#comp=strncmp,unincmp# -#type=char *, Py_UNICODE *# +#comp=strncmp,PyArray_CompareUCS4# +#type=char *, PyArray_UCS4 *# */ static void @TYPE@_amergesort0(intp *pl, intp *pr, @type@*v, intp *pw, int elsize) diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index b37201ee8..4fec8a726 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -3404,6 +3404,150 @@ array_str(PyArrayObject *self) return s; } +/*OBJECT_API +*/ +static int +PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len) +{ + register PyArray_UCS4 c1, c2; + while(len-- > 0) { + c1 = *s1++; + c2 = *s2++; + if (c1 != c2) { + return (c1 < c2) ? -1 : 1; + } + } + return 0; +} + + +static int +_compare_strings(PyObject *result, PyArrayMultiIterObject *multi, + int cmp_op, int len, void *func) +{ + PyArrayIterObject *iself, *iother; + Bool *dptr; + intp size; + int val; + int (*cmpfunc)(void *, void *, size_t); + + cmpfunc = func; + dptr = (Bool *)PyArray_DATA(result); + iself = multi->iters[0]; + iother = multi->iters[1]; + size = multi->size; + while(size--) { + val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, + len); + switch (cmp_op) { + case Py_EQ: + *dptr = (val == 0); + break; + case Py_NE: + *dptr = (val != 0); + break; + case Py_LT: + *dptr = (val < 0); + break; + case Py_LE: + *dptr = (val <= 0); + break; + case Py_GT: + *dptr = (val > 0); + break; + case Py_GE: + *dptr = (val >= 0); + break; + default: + PyErr_SetString(PyExc_RuntimeError, + "bad comparison operator"); + return -1; + } + PyArray_ITER_NEXT(iself); + PyArray_ITER_NEXT(iother); + dptr += 1; + } + return 0; +} + +static PyObject * +_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) +{ + PyObject *result; + PyArrayMultiIterObject *mit; + double prior1, prior2; + int N, val; + + /* Cast arrays to a common type */ + if (self->descr->type != other->descr->type) { + PyObject *new; + if (self->descr->type_num == PyArray_STRING && \ + other->descr->type_num == PyArray_UNICODE) { + Py_INCREF(other); + Py_INCREF(other->descr); + new = PyArray_FromAny((PyObject *)self, other->descr, + 0, 0, 0, NULL); + if (new == NULL) return NULL; + self = (PyArrayObject *)new; + } + else if (self->descr->type_num == PyArray_UNICODE && \ + other->descr->type_num == PyArray_STRING) { + Py_INCREF(self); + Py_INCREF(self->descr); + new = PyArray_FromAny((PyObject *)other, self->descr, + 0, 0, 0, NULL); + if (new == NULL) return NULL; + other = (PyArrayObject *)new; + } + else { + PyErr_SetString(PyExc_TypeError, + "invalid string data-types" + "in comparison"); + return NULL; + } + } + else { + Py_INCREF(self); + Py_INCREF(other); + } + + /* Broad-cast the arrays to a common shape */ + mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other); + Py_DECREF(self); + Py_DECREF(other); + if (mit == NULL) return NULL; + + /* Choose largest size for result array */ + if (self->descr->elsize > other->descr->elsize) + N = self->descr->elsize; + else + N = other->descr->elsize; + + result = PyArray_NewFromDescr(&PyArray_Type, + PyArray_DescrFromType(PyArray_BOOL), + mit->nd, + mit->dimensions, + NULL, NULL, 0, + (PyObject *) + (prior2 > prior1 ? self : other)); + if (result == NULL) goto finish; + + if (self->descr->type_num == PyArray_STRING) { + val = _compare_strings(result, mit, cmp_op, N, strncmp); + } + else { + val = _compare_strings(result, mit, cmp_op, + N/sizeof(PyArray_UCS4), + PyArray_CompareUCS4); + } + + if (val < 0) {Py_DECREF(result); result = NULL;} + + finish: + Py_DECREF(mit); + return result; +} + static PyObject * array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) { @@ -3413,11 +3557,13 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) switch (cmp_op) { case Py_LT: - return PyArray_GenericBinaryFunction(self, other, - n_ops.less); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.less); + break; case Py_LE: - return PyArray_GenericBinaryFunction(self, other, - n_ops.less_equal); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.less_equal); + break; case Py_EQ: if (other == Py_None) { Py_INCREF(Py_False); @@ -3462,7 +3608,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } - return result; + break; case Py_NE: if (other == Py_None) { Py_INCREF(Py_True); @@ -3501,16 +3647,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } - return result; + break; case Py_GT: - return PyArray_GenericBinaryFunction(self, other, - n_ops.greater); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.greater); + break; case Py_GE: - return PyArray_GenericBinaryFunction(self, - other, - n_ops.greater_equal); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.greater_equal); + break; } - return NULL; + if (result == Py_NotImplemented) { + /* Try to handle string comparisons */ + if (self->descr->type_num == PyArray_OBJECT) return result; + array_other = PyArray_FromObject(other,PyArray_NOTYPE, 0, 0); + if (PyArray_ISSTRING(self) || PyArray_ISSTRING(array_other)) { + result = _strings_richcompare(self, (PyArrayObject *) + array_other, cmp_op); + } + Py_DECREF(array_other); + } + return result; } static PyObject * diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 740d3fd24..7a6ad50b5 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -1729,6 +1729,7 @@ PyArray_ConvertToCommonType(PyObject *op, int *retn) } + /*MULTIARRAY_API Numeric.choose() */ diff --git a/numpy/core/src/ucsnarrow.c b/numpy/core/src/ucsnarrow.c index 9eb668b77..990d678c6 100644 --- a/numpy/core/src/ucsnarrow.c +++ b/numpy/core/src/ucsnarrow.c @@ -1,7 +1,8 @@ /* Functions only needed on narrow builds of Python - for converting back and forth between the NumPy Unicode data-type (always 4-byte) + for converting back and forth between the NumPy Unicode data-type + (always 4-byte) and the Python Unicode scalar (2-bytes on a narrow build). - */ +*/ /* the ucs2 buffer must be large enough to hold 2*ucs4length characters due to the use of surrogate pairs. diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 47e005b75..877865546 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -257,6 +257,27 @@ class test_fancy_indexing(ScipyTestCase): x[:,:,(0,)] = 2.0 assert_array_equal(x, array([[[2.0]]])) +class test_string_compare(ScipyTestCase): + def check_string(self): + g1 = array(["This","is","example"]) + g2 = array(["This","was","example"]) + assert_array_equal(g1 == g2, [True, False, True]) + assert_array_equal(g1 != g2, [False, True, False]) + assert_array_equal(g1 <= g2, [True, True, True]) + assert_array_equal(g1 >= g2, [True, False, True]) + assert_array_equal(g1 < g2, [False, True, False]) + assert_array_equal(g1 > g2, [False, False, False]) + + def check_unicode(self): + g1 = array([u"This",u"is",u"example"]) + g2 = array([u"This",u"was",u"example"]) + assert_array_equal(g1 == g2, [True, False, True]) + assert_array_equal(g1 != g2, [False, True, False]) + assert_array_equal(g1 <= g2, [True, True, True]) + assert_array_equal(g1 >= g2, [True, False, True]) + assert_array_equal(g1 < g2, [False, True, False]) + assert_array_equal(g1 > g2, [False, False, False]) + # Import tests from unicode set_local_path() |