summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/_sortmodule.c.src17
-rw-r--r--numpy/core/src/arrayobject.c181
-rw-r--r--numpy/core/src/multiarraymodule.c1
-rw-r--r--numpy/core/src/ucsnarrow.c5
4 files changed, 175 insertions, 29 deletions
diff --git a/numpy/core/src/_sortmodule.c.src b/numpy/core/src/_sortmodule.c.src
index 8f40e47ee..e152fe29f 100644
--- a/numpy/core/src/_sortmodule.c.src
+++ b/numpy/core/src/_sortmodule.c.src
@@ -355,23 +355,10 @@ static int
}
/**end repeat**/
-static int
-unincmp(Py_UNICODE *s1, Py_UNICODE *s2, register int len)
-{
- register Py_UNICODE c1, c2;
- while(len-- > 0) {
- c1 = *s1++;
- c2 = *s2++;
- if (c1 != c2)
- return (c1 < c2) ? -1 : 1;
- }
- return 0;
-}
-
/**begin repeat
#TYPE=STRING,UNICODE#
-#comp=strncmp,unincmp#
-#type=char *, Py_UNICODE *#
+#comp=strncmp,PyArray_CompareUCS4#
+#type=char *, PyArray_UCS4 *#
*/
static void
@TYPE@_amergesort0(intp *pl, intp *pr, @type@*v, intp *pw, int elsize)
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index b37201ee8..4fec8a726 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -3404,6 +3404,150 @@ array_str(PyArrayObject *self)
return s;
}
+/*OBJECT_API
+*/
+static int
+PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len)
+{
+ register PyArray_UCS4 c1, c2;
+ while(len-- > 0) {
+ c1 = *s1++;
+ c2 = *s2++;
+ if (c1 != c2) {
+ return (c1 < c2) ? -1 : 1;
+ }
+ }
+ return 0;
+}
+
+
+static int
+_compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
+ int cmp_op, int len, void *func)
+{
+ PyArrayIterObject *iself, *iother;
+ Bool *dptr;
+ intp size;
+ int val;
+ int (*cmpfunc)(void *, void *, size_t);
+
+ cmpfunc = func;
+ dptr = (Bool *)PyArray_DATA(result);
+ iself = multi->iters[0];
+ iother = multi->iters[1];
+ size = multi->size;
+ while(size--) {
+ val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr,
+ len);
+ switch (cmp_op) {
+ case Py_EQ:
+ *dptr = (val == 0);
+ break;
+ case Py_NE:
+ *dptr = (val != 0);
+ break;
+ case Py_LT:
+ *dptr = (val < 0);
+ break;
+ case Py_LE:
+ *dptr = (val <= 0);
+ break;
+ case Py_GT:
+ *dptr = (val > 0);
+ break;
+ case Py_GE:
+ *dptr = (val >= 0);
+ break;
+ default:
+ PyErr_SetString(PyExc_RuntimeError,
+ "bad comparison operator");
+ return -1;
+ }
+ PyArray_ITER_NEXT(iself);
+ PyArray_ITER_NEXT(iother);
+ dptr += 1;
+ }
+ return 0;
+}
+
+static PyObject *
+_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+{
+ PyObject *result;
+ PyArrayMultiIterObject *mit;
+ double prior1, prior2;
+ int N, val;
+
+ /* Cast arrays to a common type */
+ if (self->descr->type != other->descr->type) {
+ PyObject *new;
+ if (self->descr->type_num == PyArray_STRING && \
+ other->descr->type_num == PyArray_UNICODE) {
+ Py_INCREF(other);
+ Py_INCREF(other->descr);
+ new = PyArray_FromAny((PyObject *)self, other->descr,
+ 0, 0, 0, NULL);
+ if (new == NULL) return NULL;
+ self = (PyArrayObject *)new;
+ }
+ else if (self->descr->type_num == PyArray_UNICODE && \
+ other->descr->type_num == PyArray_STRING) {
+ Py_INCREF(self);
+ Py_INCREF(self->descr);
+ new = PyArray_FromAny((PyObject *)other, self->descr,
+ 0, 0, 0, NULL);
+ if (new == NULL) return NULL;
+ other = (PyArrayObject *)new;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid string data-types"
+ "in comparison");
+ return NULL;
+ }
+ }
+ else {
+ Py_INCREF(self);
+ Py_INCREF(other);
+ }
+
+ /* Broad-cast the arrays to a common shape */
+ mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other);
+ Py_DECREF(self);
+ Py_DECREF(other);
+ if (mit == NULL) return NULL;
+
+ /* Choose largest size for result array */
+ if (self->descr->elsize > other->descr->elsize)
+ N = self->descr->elsize;
+ else
+ N = other->descr->elsize;
+
+ result = PyArray_NewFromDescr(&PyArray_Type,
+ PyArray_DescrFromType(PyArray_BOOL),
+ mit->nd,
+ mit->dimensions,
+ NULL, NULL, 0,
+ (PyObject *)
+ (prior2 > prior1 ? self : other));
+ if (result == NULL) goto finish;
+
+ if (self->descr->type_num == PyArray_STRING) {
+ val = _compare_strings(result, mit, cmp_op, N, strncmp);
+ }
+ else {
+ val = _compare_strings(result, mit, cmp_op,
+ N/sizeof(PyArray_UCS4),
+ PyArray_CompareUCS4);
+ }
+
+ if (val < 0) {Py_DECREF(result); result = NULL;}
+
+ finish:
+ Py_DECREF(mit);
+ return result;
+}
+
static PyObject *
array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
@@ -3413,11 +3557,13 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
switch (cmp_op)
{
case Py_LT:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.less);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.less);
+ break;
case Py_LE:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.less_equal);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.less_equal);
+ break;
case Py_EQ:
if (other == Py_None) {
Py_INCREF(Py_False);
@@ -3462,7 +3608,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
- return result;
+ break;
case Py_NE:
if (other == Py_None) {
Py_INCREF(Py_True);
@@ -3501,16 +3647,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
- return result;
+ break;
case Py_GT:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.greater);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.greater);
+ break;
case Py_GE:
- return PyArray_GenericBinaryFunction(self,
- other,
- n_ops.greater_equal);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.greater_equal);
+ break;
}
- return NULL;
+ if (result == Py_NotImplemented) {
+ /* Try to handle string comparisons */
+ if (self->descr->type_num == PyArray_OBJECT) return result;
+ array_other = PyArray_FromObject(other,PyArray_NOTYPE, 0, 0);
+ if (PyArray_ISSTRING(self) || PyArray_ISSTRING(array_other)) {
+ result = _strings_richcompare(self, (PyArrayObject *)
+ array_other, cmp_op);
+ }
+ Py_DECREF(array_other);
+ }
+ return result;
}
static PyObject *
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index 740d3fd24..7a6ad50b5 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -1729,6 +1729,7 @@ PyArray_ConvertToCommonType(PyObject *op, int *retn)
}
+
/*MULTIARRAY_API
Numeric.choose()
*/
diff --git a/numpy/core/src/ucsnarrow.c b/numpy/core/src/ucsnarrow.c
index 9eb668b77..990d678c6 100644
--- a/numpy/core/src/ucsnarrow.c
+++ b/numpy/core/src/ucsnarrow.c
@@ -1,7 +1,8 @@
/* Functions only needed on narrow builds of Python
- for converting back and forth between the NumPy Unicode data-type (always 4-byte)
+ for converting back and forth between the NumPy Unicode data-type
+ (always 4-byte)
and the Python Unicode scalar (2-bytes on a narrow build).
- */
+*/
/* the ucs2 buffer must be large enough to hold 2*ucs4length characters
due to the use of surrogate pairs.