summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c202
1 files changed, 166 insertions, 36 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index c4303bdd3..f1c8b6b9b 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -4183,9 +4183,130 @@ _mystrncmp(char *s1, char *s2, int len1, int len2)
return 0;
}
+/* Borrowed from Numarray */
+
+#define SMALL_STRING 2048
+
+#if defined(isspace)
+#undef isspace
+#define isspace(c) ((c==' ')||(c=='\t')||(c=='\n')||(c=='\r')||(c=='\v')||(c=='\f'))
+#endif
+
+static void _rstripw(char *s, int n)
+{
+ int i;
+ for(i=strnlen(s,n)-1; i>=1; i--) /* Never strip to length 0. */
+ {
+ int c = s[i];
+ if (!c || isspace(c))
+ s[i] = 0;
+ else
+ break;
+ }
+}
+
+static void _unistripw(PyArray_UCS4 *s, int n)
+{
+ int i;
+ for(i=n-1; i>=1; i--) /* Never strip to length 0. */
+ {
+ PyArray_UCS4 c = s[i];
+ if (!c || isspace(c))
+ s[i] = 0;
+ else
+ break;
+ }
+}
+
+
+static char *
+_char_copy_n_strip(char *original, char *temp, int nc)
+{
+ if (nc > SMALL_STRING) {
+ temp = malloc(nc);
+ if (!temp) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ }
+ memcpy(temp, original, nc);
+ _rstripw(temp, nc);
+ return temp;
+}
+
+static void
+_char_release(char *ptr, int nc)
+{
+ if (nc > SMALL_STRING) {
+ free(ptr);
+ }
+}
+
+static char *
+_uni_copy_n_strip(char *original, char *temp, int nc)
+{
+ if (nc*4 > SMALL_STRING) {
+ temp = malloc(nc);
+ if (!temp) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ }
+ memcpy(temp, original, nc*sizeof(PyArray_UCS4));
+ _unistripw((PyArray_UCS4 *)temp, nc);
+ return temp;
+}
+
+static void
+_uni_release(char *ptr, int nc)
+{
+ if (nc*sizeof(PyArray_UCS4) > SMALL_STRING) {
+ free(ptr);
+ }
+}
+
+
+/* End borrowed from numarray */
+
+#define _rstrip_loop(CMP) { \
+ void *aptr, *bptr; \
+ char atemp[SMALL_STRING], btemp[SMALL_STRING]; \
+ while(size--) { \
+ aptr = stripfunc(iself->dataptr, atemp, N1); \
+ if (!aptr) return -1; \
+ bptr = stripfunc(iother->dataptr, btemp, N2); \
+ if (!bptr) { \
+ relfunc(aptr, N1); \
+ return -1; \
+ } \
+ val = cmpfunc(aptr, bptr, N1, N2); \
+ *dptr = (val CMP 0); \
+ PyArray_ITER_NEXT(iself); \
+ PyArray_ITER_NEXT(iother); \
+ dptr += 1; \
+ relfunc(aptr, N1); \
+ relfunc(bptr, N2); \
+ } \
+ }
+
+#define _reg_loop(CMP) { \
+ while(size--) { \
+ val = cmpfunc((void *)iself->dataptr, \
+ (void *)iother->dataptr, \
+ N1, N2); \
+ *dptr = (val CMP 0); \
+ PyArray_ITER_NEXT(iself); \
+ PyArray_ITER_NEXT(iother); \
+ dptr += 1; \
+ } \
+ }
+
+#define _loop(CMP) if (rstrip) _rstrip_loop(CMP) \
+ else _reg_loop(CMP)
+
static int
_compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
- int cmp_op, void *func)
+ int cmp_op, void *func, int rstrip)
{
PyArrayIterObject *iself, *iother;
Bool *dptr;
@@ -4193,6 +4314,8 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
int val;
int N1, N2;
int (*cmpfunc)(void *, void *, int, int);
+ void (*relfunc)(char *, int);
+ char* (*stripfunc)(char *, char *, int);
cmpfunc = func;
dptr = (Bool *)PyArray_DATA(result);
@@ -4204,43 +4327,48 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
if ((void *)cmpfunc == (void *)_myunincmp) {
N1 >>= 2;
N2 >>= 2;
+ stripfunc = _uni_copy_n_strip;
+ relfunc = _uni_release;
}
- while(size--) {
- val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr,
- N1, N2);
- switch (cmp_op) {
- case Py_EQ:
- *dptr = (val == 0);
- break;
- case Py_NE:
- *dptr = (val != 0);
- break;
- case Py_LT:
- *dptr = (val < 0);
- break;
- case Py_LE:
- *dptr = (val <= 0);
- break;
- case Py_GT:
- *dptr = (val > 0);
- break;
- case Py_GE:
- *dptr = (val >= 0);
- break;
- default:
- PyErr_SetString(PyExc_RuntimeError,
- "bad comparison operator");
- return -1;
- }
- PyArray_ITER_NEXT(iself);
- PyArray_ITER_NEXT(iother);
- dptr += 1;
+ else {
+ stripfunc = _char_copy_n_strip;
+ relfunc = _char_release;
+ }
+ switch (cmp_op) {
+ case Py_EQ:
+ _loop(==)
+ break;
+ case Py_NE:
+ _loop(!=)
+ break;
+ case Py_LT:
+ _loop(<)
+ break;
+ case Py_LE:
+ _loop(<=)
+ break;
+ case Py_GT:
+ _loop(>)
+ break;
+ case Py_GE:
+ _loop(>=)
+ break;
+ default:
+ PyErr_SetString(PyExc_RuntimeError,
+ "bad comparison operator");
+ return -1;
}
return 0;
}
+#undef _loop
+#undef _reg_loop
+#undef _rstrip_loop
+#undef SMALL_STRING
+
static PyObject *
-_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
+ int rstrip)
{
PyObject *result;
PyArrayMultiIterObject *mit;
@@ -4294,10 +4422,12 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
if (result == NULL) goto finish;
if (self->descr->type_num == PyArray_UNICODE) {
- val = _compare_strings(result, mit, cmp_op, _myunincmp);
+ val = _compare_strings(result, mit, cmp_op, _myunincmp,
+ rstrip);
}
else {
- val = _compare_strings(result, mit, cmp_op, _mystrncmp);
+ val = _compare_strings(result, mit, cmp_op, _mystrncmp,
+ rstrip);
}
if (val < 0) {Py_DECREF(result); result = NULL;}
@@ -4361,7 +4491,7 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
}
else { /* compare as a string */
/* assumes self and other have same descr->type */
- return _strings_richcompare(self, other, cmp_op);
+ return _strings_richcompare(self, other, cmp_op, 0);
}
}
@@ -4531,7 +4661,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
if (PyArray_ISSTRING(self) && PyArray_ISSTRING(array_other)) {
Py_DECREF(result);
result = _strings_richcompare(self, (PyArrayObject *)
- array_other, cmp_op);
+ array_other, cmp_op, 0);
}
Py_DECREF(array_other);
}