diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-04-14 17:38:11 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-04-14 17:38:11 +0000 |
commit | 5567185f1bbe1f568c2bcbaeb34ea8df84e83610 (patch) | |
tree | 4ab9f997707c58a3a366d7431c8b4098ca429515 /numpy/core/src/arrayobject.c | |
parent | 036d7a3724a89d880339c89a1aff171d7e4c1e10 (diff) | |
download | numpy-5567185f1bbe1f568c2bcbaeb34ea8df84e83610.tar.gz |
Add string (and unicode) comparison to array objects.
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 181 |
1 files changed, 169 insertions, 12 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index b37201ee8..4fec8a726 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -3404,6 +3404,150 @@ array_str(PyArrayObject *self) return s; } +/*OBJECT_API +*/ +static int +PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len) +{ + register PyArray_UCS4 c1, c2; + while(len-- > 0) { + c1 = *s1++; + c2 = *s2++; + if (c1 != c2) { + return (c1 < c2) ? -1 : 1; + } + } + return 0; +} + + +static int +_compare_strings(PyObject *result, PyArrayMultiIterObject *multi, + int cmp_op, int len, void *func) +{ + PyArrayIterObject *iself, *iother; + Bool *dptr; + intp size; + int val; + int (*cmpfunc)(void *, void *, size_t); + + cmpfunc = func; + dptr = (Bool *)PyArray_DATA(result); + iself = multi->iters[0]; + iother = multi->iters[1]; + size = multi->size; + while(size--) { + val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, + len); + switch (cmp_op) { + case Py_EQ: + *dptr = (val == 0); + break; + case Py_NE: + *dptr = (val != 0); + break; + case Py_LT: + *dptr = (val < 0); + break; + case Py_LE: + *dptr = (val <= 0); + break; + case Py_GT: + *dptr = (val > 0); + break; + case Py_GE: + *dptr = (val >= 0); + break; + default: + PyErr_SetString(PyExc_RuntimeError, + "bad comparison operator"); + return -1; + } + PyArray_ITER_NEXT(iself); + PyArray_ITER_NEXT(iother); + dptr += 1; + } + return 0; +} + +static PyObject * +_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) +{ + PyObject *result; + PyArrayMultiIterObject *mit; + double prior1, prior2; + int N, val; + + /* Cast arrays to a common type */ + if (self->descr->type != other->descr->type) { + PyObject *new; + if (self->descr->type_num == PyArray_STRING && \ + other->descr->type_num == PyArray_UNICODE) { + Py_INCREF(other); + Py_INCREF(other->descr); + new = PyArray_FromAny((PyObject *)self, other->descr, + 0, 0, 0, NULL); + if (new == NULL) return NULL; + self = (PyArrayObject *)new; + } + else if (self->descr->type_num == PyArray_UNICODE && \ + other->descr->type_num == PyArray_STRING) { + Py_INCREF(self); + Py_INCREF(self->descr); + new = PyArray_FromAny((PyObject *)other, self->descr, + 0, 0, 0, NULL); + if (new == NULL) return NULL; + other = (PyArrayObject *)new; + } + else { + PyErr_SetString(PyExc_TypeError, + "invalid string data-types" + "in comparison"); + return NULL; + } + } + else { + Py_INCREF(self); + Py_INCREF(other); + } + + /* Broad-cast the arrays to a common shape */ + mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other); + Py_DECREF(self); + Py_DECREF(other); + if (mit == NULL) return NULL; + + /* Choose largest size for result array */ + if (self->descr->elsize > other->descr->elsize) + N = self->descr->elsize; + else + N = other->descr->elsize; + + result = PyArray_NewFromDescr(&PyArray_Type, + PyArray_DescrFromType(PyArray_BOOL), + mit->nd, + mit->dimensions, + NULL, NULL, 0, + (PyObject *) + (prior2 > prior1 ? self : other)); + if (result == NULL) goto finish; + + if (self->descr->type_num == PyArray_STRING) { + val = _compare_strings(result, mit, cmp_op, N, strncmp); + } + else { + val = _compare_strings(result, mit, cmp_op, + N/sizeof(PyArray_UCS4), + PyArray_CompareUCS4); + } + + if (val < 0) {Py_DECREF(result); result = NULL;} + + finish: + Py_DECREF(mit); + return result; +} + static PyObject * array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) { @@ -3413,11 +3557,13 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) switch (cmp_op) { case Py_LT: - return PyArray_GenericBinaryFunction(self, other, - n_ops.less); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.less); + break; case Py_LE: - return PyArray_GenericBinaryFunction(self, other, - n_ops.less_equal); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.less_equal); + break; case Py_EQ: if (other == Py_None) { Py_INCREF(Py_False); @@ -3462,7 +3608,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } - return result; + break; case Py_NE: if (other == Py_None) { Py_INCREF(Py_True); @@ -3501,16 +3647,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } - return result; + break; case Py_GT: - return PyArray_GenericBinaryFunction(self, other, - n_ops.greater); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.greater); + break; case Py_GE: - return PyArray_GenericBinaryFunction(self, - other, - n_ops.greater_equal); + result = PyArray_GenericBinaryFunction(self, other, + n_ops.greater_equal); + break; } - return NULL; + if (result == Py_NotImplemented) { + /* Try to handle string comparisons */ + if (self->descr->type_num == PyArray_OBJECT) return result; + array_other = PyArray_FromObject(other,PyArray_NOTYPE, 0, 0); + if (PyArray_ISSTRING(self) || PyArray_ISSTRING(array_other)) { + result = _strings_richcompare(self, (PyArrayObject *) + array_other, cmp_op); + } + Py_DECREF(array_other); + } + return result; } static PyObject * |