summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/code_generators/array_api_order.txt1
-rw-r--r--numpy/core/defchararray.py56
-rw-r--r--numpy/core/src/_sortmodule.c.src17
-rw-r--r--numpy/core/src/arrayobject.c181
-rw-r--r--numpy/core/src/multiarraymodule.c1
-rw-r--r--numpy/core/src/ucsnarrow.c5
-rw-r--r--numpy/core/tests/test_multiarray.py21
7 files changed, 225 insertions, 57 deletions
diff --git a/numpy/core/code_generators/array_api_order.txt b/numpy/core/code_generators/array_api_order.txt
index f47ab8aa5..ac6341d6e 100644
--- a/numpy/core/code_generators/array_api_order.txt
+++ b/numpy/core/code_generators/array_api_order.txt
@@ -69,3 +69,4 @@ PyArray_ScalarKind
PyArray_CanCoerceScalar
PyArray_NewFlagsObject
PyArray_CanCastScalar
+PyArray_CompareUCS4
diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py
index aeba32fc5..f79389967 100644
--- a/numpy/core/defchararray.py
+++ b/numpy/core/defchararray.py
@@ -41,34 +41,34 @@ class chararray(ndarray):
raise ValueError, "Can only create a chararray from string data."
- def _richcmpfunc(self, other, op):
- b = broadcast(self, other)
- result = empty(b.shape, dtype=bool)
- res = result.flat
- for k, val in enumerate(b):
- r1 = val[0].rstrip('\x00')
- r2 = val[1]
- res[k] = eval("r1 %s r2" % op, {'r1':r1,'r2':r2})
- return result
-
- # these should probably be moved to C
- def __eq__(self, other):
- return self._richcmpfunc(other, '==')
-
- def __ne__(self, other):
- return self._richcmpfunc(other, '!=')
-
- def __ge__(self, other):
- return self._richcmpfunc(other, '>=')
-
- def __le__(self, other):
- return self._richcmpfunc(other, '<=')
-
- def __gt__(self, other):
- return self._richcmpfunc(other, '>')
-
- def __lt__(self, other):
- return self._richcmpfunc(other, '<')
+## def _richcmpfunc(self, other, op):
+## b = broadcast(self, other)
+## result = empty(b.shape, dtype=bool)
+## res = result.flat
+## for k, val in enumerate(b):
+## r1 = val[0].rstrip('\x00')
+## r2 = val[1]
+## res[k] = eval("r1 %s r2" % op, {'r1':r1,'r2':r2})
+## return result
+
+ # these have been moved to C
+## def __eq__(self, other):
+## return self._richcmpfunc(other, '==')
+
+## def __ne__(self, other):
+## return self._richcmpfunc(other, '!=')
+
+## def __ge__(self, other):
+## return self._richcmpfunc(other, '>=')
+
+## def __le__(self, other):
+## return self._richcmpfunc(other, '<=')
+
+## def __gt__(self, other):
+## return self._richcmpfunc(other, '>')
+
+## def __lt__(self, other):
+## return self._richcmpfunc(other, '<')
def __add__(self, other):
b = broadcast(self, other)
diff --git a/numpy/core/src/_sortmodule.c.src b/numpy/core/src/_sortmodule.c.src
index 8f40e47ee..e152fe29f 100644
--- a/numpy/core/src/_sortmodule.c.src
+++ b/numpy/core/src/_sortmodule.c.src
@@ -355,23 +355,10 @@ static int
}
/**end repeat**/
-static int
-unincmp(Py_UNICODE *s1, Py_UNICODE *s2, register int len)
-{
- register Py_UNICODE c1, c2;
- while(len-- > 0) {
- c1 = *s1++;
- c2 = *s2++;
- if (c1 != c2)
- return (c1 < c2) ? -1 : 1;
- }
- return 0;
-}
-
/**begin repeat
#TYPE=STRING,UNICODE#
-#comp=strncmp,unincmp#
-#type=char *, Py_UNICODE *#
+#comp=strncmp,PyArray_CompareUCS4#
+#type=char *, PyArray_UCS4 *#
*/
static void
@TYPE@_amergesort0(intp *pl, intp *pr, @type@*v, intp *pw, int elsize)
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index b37201ee8..4fec8a726 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -3404,6 +3404,150 @@ array_str(PyArrayObject *self)
return s;
}
+/*OBJECT_API
+*/
+static int
+PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len)
+{
+ register PyArray_UCS4 c1, c2;
+ while(len-- > 0) {
+ c1 = *s1++;
+ c2 = *s2++;
+ if (c1 != c2) {
+ return (c1 < c2) ? -1 : 1;
+ }
+ }
+ return 0;
+}
+
+
+static int
+_compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
+ int cmp_op, int len, void *func)
+{
+ PyArrayIterObject *iself, *iother;
+ Bool *dptr;
+ intp size;
+ int val;
+ int (*cmpfunc)(void *, void *, size_t);
+
+ cmpfunc = func;
+ dptr = (Bool *)PyArray_DATA(result);
+ iself = multi->iters[0];
+ iother = multi->iters[1];
+ size = multi->size;
+ while(size--) {
+ val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr,
+ len);
+ switch (cmp_op) {
+ case Py_EQ:
+ *dptr = (val == 0);
+ break;
+ case Py_NE:
+ *dptr = (val != 0);
+ break;
+ case Py_LT:
+ *dptr = (val < 0);
+ break;
+ case Py_LE:
+ *dptr = (val <= 0);
+ break;
+ case Py_GT:
+ *dptr = (val > 0);
+ break;
+ case Py_GE:
+ *dptr = (val >= 0);
+ break;
+ default:
+ PyErr_SetString(PyExc_RuntimeError,
+ "bad comparison operator");
+ return -1;
+ }
+ PyArray_ITER_NEXT(iself);
+ PyArray_ITER_NEXT(iother);
+ dptr += 1;
+ }
+ return 0;
+}
+
+static PyObject *
+_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+{
+ PyObject *result;
+ PyArrayMultiIterObject *mit;
+ double prior1, prior2;
+ int N, val;
+
+ /* Cast arrays to a common type */
+ if (self->descr->type != other->descr->type) {
+ PyObject *new;
+ if (self->descr->type_num == PyArray_STRING && \
+ other->descr->type_num == PyArray_UNICODE) {
+ Py_INCREF(other);
+ Py_INCREF(other->descr);
+ new = PyArray_FromAny((PyObject *)self, other->descr,
+ 0, 0, 0, NULL);
+ if (new == NULL) return NULL;
+ self = (PyArrayObject *)new;
+ }
+ else if (self->descr->type_num == PyArray_UNICODE && \
+ other->descr->type_num == PyArray_STRING) {
+ Py_INCREF(self);
+ Py_INCREF(self->descr);
+ new = PyArray_FromAny((PyObject *)other, self->descr,
+ 0, 0, 0, NULL);
+ if (new == NULL) return NULL;
+ other = (PyArrayObject *)new;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid string data-types"
+ "in comparison");
+ return NULL;
+ }
+ }
+ else {
+ Py_INCREF(self);
+ Py_INCREF(other);
+ }
+
+ /* Broad-cast the arrays to a common shape */
+ mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other);
+ Py_DECREF(self);
+ Py_DECREF(other);
+ if (mit == NULL) return NULL;
+
+ /* Choose largest size for result array */
+ if (self->descr->elsize > other->descr->elsize)
+ N = self->descr->elsize;
+ else
+ N = other->descr->elsize;
+
+ result = PyArray_NewFromDescr(&PyArray_Type,
+ PyArray_DescrFromType(PyArray_BOOL),
+ mit->nd,
+ mit->dimensions,
+ NULL, NULL, 0,
+ (PyObject *)
+ (prior2 > prior1 ? self : other));
+ if (result == NULL) goto finish;
+
+ if (self->descr->type_num == PyArray_STRING) {
+ val = _compare_strings(result, mit, cmp_op, N, strncmp);
+ }
+ else {
+ val = _compare_strings(result, mit, cmp_op,
+ N/sizeof(PyArray_UCS4),
+ PyArray_CompareUCS4);
+ }
+
+ if (val < 0) {Py_DECREF(result); result = NULL;}
+
+ finish:
+ Py_DECREF(mit);
+ return result;
+}
+
static PyObject *
array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
@@ -3413,11 +3557,13 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
switch (cmp_op)
{
case Py_LT:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.less);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.less);
+ break;
case Py_LE:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.less_equal);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.less_equal);
+ break;
case Py_EQ:
if (other == Py_None) {
Py_INCREF(Py_False);
@@ -3462,7 +3608,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
- return result;
+ break;
case Py_NE:
if (other == Py_None) {
Py_INCREF(Py_True);
@@ -3501,16 +3647,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
- return result;
+ break;
case Py_GT:
- return PyArray_GenericBinaryFunction(self, other,
- n_ops.greater);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.greater);
+ break;
case Py_GE:
- return PyArray_GenericBinaryFunction(self,
- other,
- n_ops.greater_equal);
+ result = PyArray_GenericBinaryFunction(self, other,
+ n_ops.greater_equal);
+ break;
}
- return NULL;
+ if (result == Py_NotImplemented) {
+ /* Try to handle string comparisons */
+ if (self->descr->type_num == PyArray_OBJECT) return result;
+ array_other = PyArray_FromObject(other,PyArray_NOTYPE, 0, 0);
+ if (PyArray_ISSTRING(self) || PyArray_ISSTRING(array_other)) {
+ result = _strings_richcompare(self, (PyArrayObject *)
+ array_other, cmp_op);
+ }
+ Py_DECREF(array_other);
+ }
+ return result;
}
static PyObject *
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index 740d3fd24..7a6ad50b5 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -1729,6 +1729,7 @@ PyArray_ConvertToCommonType(PyObject *op, int *retn)
}
+
/*MULTIARRAY_API
Numeric.choose()
*/
diff --git a/numpy/core/src/ucsnarrow.c b/numpy/core/src/ucsnarrow.c
index 9eb668b77..990d678c6 100644
--- a/numpy/core/src/ucsnarrow.c
+++ b/numpy/core/src/ucsnarrow.c
@@ -1,7 +1,8 @@
/* Functions only needed on narrow builds of Python
- for converting back and forth between the NumPy Unicode data-type (always 4-byte)
+ for converting back and forth between the NumPy Unicode data-type
+ (always 4-byte)
and the Python Unicode scalar (2-bytes on a narrow build).
- */
+*/
/* the ucs2 buffer must be large enough to hold 2*ucs4length characters
due to the use of surrogate pairs.
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 47e005b75..877865546 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -257,6 +257,27 @@ class test_fancy_indexing(ScipyTestCase):
x[:,:,(0,)] = 2.0
assert_array_equal(x, array([[[2.0]]]))
+class test_string_compare(ScipyTestCase):
+ def check_string(self):
+ g1 = array(["This","is","example"])
+ g2 = array(["This","was","example"])
+ assert_array_equal(g1 == g2, [True, False, True])
+ assert_array_equal(g1 != g2, [False, True, False])
+ assert_array_equal(g1 <= g2, [True, True, True])
+ assert_array_equal(g1 >= g2, [True, False, True])
+ assert_array_equal(g1 < g2, [False, True, False])
+ assert_array_equal(g1 > g2, [False, False, False])
+
+ def check_unicode(self):
+ g1 = array([u"This",u"is",u"example"])
+ g2 = array([u"This",u"was",u"example"])
+ assert_array_equal(g1 == g2, [True, False, True])
+ assert_array_equal(g1 != g2, [False, True, False])
+ assert_array_equal(g1 <= g2, [True, True, True])
+ assert_array_equal(g1 >= g2, [True, False, True])
+ assert_array_equal(g1 < g2, [False, True, False])
+ assert_array_equal(g1 > g2, [False, False, False])
+
# Import tests from unicode
set_local_path()