summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c181
1 files changed, 169 insertions, 12 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index b37201ee8..4fec8a726 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -3404,6 +3404,150 @@ array_str(PyArrayObject *self)
return s;
}
+/*OBJECT_API
+*/
+static int
+PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len)
+{
+ register PyArray_UCS4 c1, c2;
+ while(len-- > 0) {
+ c1 = *s1++;
+ c2 = *s2++;
+ if (c1 != c2) {
+ return (c1 < c2) ? -1 : 1;
+ }
+ }
+ return 0;
+}
+
+
+static int
+_compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
+ int cmp_op, int len, void *func)
+{
+ PyArrayIterObject *iself, *iother;
+ Bool *dptr;
+ intp size;
+ int val;
+ int (*cmpfunc)(void *, void *, size_t);
+
+ cmpfunc = func;
+ dptr = (Bool *)PyArray_DATA(result);
+ iself = multi->iters[0];
+ iother = multi->iters[1];
+ size = multi->size;
+ while(size--) {
+ val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr,
+ len);
+ switch (cmp_op) {
+ case Py_EQ:
+ *dptr = (val == 0);
+ break;
+ case Py_NE:
+ *dptr = (val != 0);
+ break;
+ case Py_LT:
+ *dptr = (val < 0);
+ break;
+ case Py_LE:
+ *dptr = (val <= 0);
+ break;
+ case Py_GT:
+ *dptr = (val > 0);
+ break;
+ case Py_GE:
+ *dptr = (val >= 0);
+ break;
+ default:
+ PyErr_SetString(PyExc_RuntimeError,
+ "bad comparison operator");
+ return -1;
+ }
+ PyArray_ITER_NEXT(iself);
+ PyArray_ITER_NEXT(iother);
+ dptr += 1;
+ }
+ return 0;
+}
+
+static PyObject *
+_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+{
+ PyObject *result;
+ PyArrayMultiIterObject *mit;
+ double prior1, prior2;
+ int N, val;
+
+ /* Cast arrays to a common type */
+ if (self->descr->type != other->descr->type) {
+ PyObject *new;
+ if (self->descr->type_num == PyArray_STRING && \
+ other->descr->type_num == PyArray_UNICODE) {
+ Py_INCREF(other);
+ Py_INCREF(other->descr);
+ new = PyArray_FromAny((PyObject *)self, other->descr,
+ 0, 0, 0, NULL);
+ if (new == NULL) return NULL;
+ self = (PyArrayObject *)new;
+ }
+ else if (self->descr->type_num == PyArray_UNICODE && \
+ other->descr->type_num == PyArray_STRING) {
+ Py_INCREF(self);
+ Py_INCREF(self->descr);
+ new = PyArray_FromAny((PyObject *)other, self->descr,
+ 0, 0, 0, NULL);
+ if (new == NULL) return NULL;
+ other = (PyArrayObject *)new;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid string data-types"
+ "in comparison");
+ return NULL;
+ }
+ }
+ else {
+ Py_INCREF(self);
+ Py_INCREF(other);
+ }
+
+ /* Broad-cast the arrays to a common shape */
+ mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other);
+ Py_DECREF(self);
+ Py_DECREF(other);
+ if (mit == NULL) return NULL;
+
+ /* Choose largest size for result array */
+ if (self->descr->elsize > other->descr->elsize)
+ N = self->descr->elsize;
+ else
+ N = other->descr->elsize;
+
+ result = PyArray_NewFromDescr(&PyArray_Type,
+ PyArray_DescrFromType(PyArray_BOOL),
+ mit->nd,
+ mit->dimensions,
+ NULL, NULL, 0,
+ (PyObject *)
+ (prior2 > prior1 ? self : other));
+ if (result == NULL) goto finish;
+
+ if (self->descr->type_num == PyArray_STRING) {
+ val = _compare_strings(result, mit, cmp_op, N, strncmp);
+ }
+ else {
+ val = _compare_strings(result, mit, cmp_op,
+ N/sizeof(PyArray_UCS4),
+ PyArray_CompareUCS4);
+ }
+
+ if (val < 0) {Py_DECREF(result); result = NULL;}
+
+ finish:
+ Py_DECREF(mit);
+ return result;
+}
+
static PyObject *
array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
@@ -3413,11 +3557,13 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
switch (cmp_op)
{
case Py_LT:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.less);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.less);
+ break;
case Py_LE:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.less_equal);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.less_equal);
+ break;
case Py_EQ:
if (other == Py_None) {
Py_INCREF(Py_False);
@@ -3462,7 +3608,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
- return result;
+ break;
case Py_NE:
if (other == Py_None) {
Py_INCREF(Py_True);
@@ -3501,16 +3647,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
- return result;
+ break;
case Py_GT:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.greater);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.greater);
+ break;
case Py_GE:
- return PyArray_GenericBinaryFunction(self,
- other,
- n_ops.greater_equal);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.greater_equal);
+ break;
}
- return NULL;
+ if (result == Py_NotImplemented) {
+ /* Try to handle string comparisons */
+ if (self->descr->type_num == PyArray_OBJECT) return result;
+ array_other = PyArray_FromObject(other,PyArray_NOTYPE, 0, 0);
+ if (PyArray_ISSTRING(self) || PyArray_ISSTRING(array_other)) {
+ result = _strings_richcompare(self, (PyArrayObject *)
+ array_other, cmp_op);
+ }
+ Py_DECREF(array_other);
+ }
+ return result;
}
static PyObject *