diff options
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 202 |
1 files changed, 166 insertions, 36 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index c4303bdd3..f1c8b6b9b 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -4183,9 +4183,130 @@ _mystrncmp(char *s1, char *s2, int len1, int len2) return 0; } +/* Borrowed from Numarray */ + +#define SMALL_STRING 2048 + +#if defined(isspace) +#undef isspace +#define isspace(c) ((c==' ')||(c=='\t')||(c=='\n')||(c=='\r')||(c=='\v')||(c=='\f')) +#endif + +static void _rstripw(char *s, int n) +{ + int i; + for(i=strnlen(s,n)-1; i>=1; i--) /* Never strip to length 0. */ + { + int c = s[i]; + if (!c || isspace(c)) + s[i] = 0; + else + break; + } +} + +static void _unistripw(PyArray_UCS4 *s, int n) +{ + int i; + for(i=n-1; i>=1; i--) /* Never strip to length 0. */ + { + PyArray_UCS4 c = s[i]; + if (!c || isspace(c)) + s[i] = 0; + else + break; + } +} + + +static char * +_char_copy_n_strip(char *original, char *temp, int nc) +{ + if (nc > SMALL_STRING) { + temp = malloc(nc); + if (!temp) { + PyErr_NoMemory(); + return NULL; + } + } + memcpy(temp, original, nc); + _rstripw(temp, nc); + return temp; +} + +static void +_char_release(char *ptr, int nc) +{ + if (nc > SMALL_STRING) { + free(ptr); + } +} + +static char * +_uni_copy_n_strip(char *original, char *temp, int nc) +{ + if (nc*4 > SMALL_STRING) { + temp = malloc(nc); + if (!temp) { + PyErr_NoMemory(); + return NULL; + } + } + memcpy(temp, original, nc*sizeof(PyArray_UCS4)); + _unistripw((PyArray_UCS4 *)temp, nc); + return temp; +} + +static void +_uni_release(char *ptr, int nc) +{ + if (nc*sizeof(PyArray_UCS4) > SMALL_STRING) { + free(ptr); + } +} + + +/* End borrowed from numarray */ + +#define _rstrip_loop(CMP) { \ + void *aptr, *bptr; \ + char atemp[SMALL_STRING], btemp[SMALL_STRING]; \ + while(size--) { \ + aptr = stripfunc(iself->dataptr, atemp, N1); \ + if (!aptr) return -1; \ + bptr = stripfunc(iother->dataptr, btemp, N2); \ + if (!bptr) { \ + relfunc(aptr, N1); \ + return -1; \ + } \ + val = cmpfunc(aptr, bptr, N1, N2); \ + *dptr = (val CMP 0); \ + PyArray_ITER_NEXT(iself); \ + PyArray_ITER_NEXT(iother); \ + dptr += 1; \ + relfunc(aptr, N1); \ + relfunc(bptr, N2); \ + } \ + } + +#define _reg_loop(CMP) { \ + while(size--) { \ + val = cmpfunc((void *)iself->dataptr, \ + (void *)iother->dataptr, \ + N1, N2); \ + *dptr = (val CMP 0); \ + PyArray_ITER_NEXT(iself); \ + PyArray_ITER_NEXT(iother); \ + dptr += 1; \ + } \ + } + +#define _loop(CMP) if (rstrip) _rstrip_loop(CMP) \ + else _reg_loop(CMP) + static int _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, - int cmp_op, void *func) + int cmp_op, void *func, int rstrip) { PyArrayIterObject *iself, *iother; Bool *dptr; @@ -4193,6 +4314,8 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, int val; int N1, N2; int (*cmpfunc)(void *, void *, int, int); + void (*relfunc)(char *, int); + char* (*stripfunc)(char *, char *, int); cmpfunc = func; dptr = (Bool *)PyArray_DATA(result); @@ -4204,43 +4327,48 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, if ((void *)cmpfunc == (void *)_myunincmp) { N1 >>= 2; N2 >>= 2; + stripfunc = _uni_copy_n_strip; + relfunc = _uni_release; } - while(size--) { - val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, - N1, N2); - switch (cmp_op) { - case Py_EQ: - *dptr = (val == 0); - break; - case Py_NE: - *dptr = (val != 0); - break; - case Py_LT: - *dptr = (val < 0); - break; - case Py_LE: - *dptr = (val <= 0); - break; - case Py_GT: - *dptr = (val > 0); - break; - case Py_GE: - *dptr = (val >= 0); - break; - default: - PyErr_SetString(PyExc_RuntimeError, - "bad comparison operator"); - return -1; - } - PyArray_ITER_NEXT(iself); - PyArray_ITER_NEXT(iother); - dptr += 1; + else { + stripfunc = _char_copy_n_strip; + relfunc = _char_release; + } + switch (cmp_op) { + case Py_EQ: + _loop(==) + break; + case Py_NE: + _loop(!=) + break; + case Py_LT: + _loop(<) + break; + case Py_LE: + _loop(<=) + break; + case Py_GT: + _loop(>) + break; + case Py_GE: + _loop(>=) + break; + default: + PyErr_SetString(PyExc_RuntimeError, + "bad comparison operator"); + return -1; } return 0; } +#undef _loop +#undef _reg_loop +#undef _rstrip_loop +#undef SMALL_STRING + static PyObject * -_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) +_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, + int rstrip) { PyObject *result; PyArrayMultiIterObject *mit; @@ -4294,10 +4422,12 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) if (result == NULL) goto finish; if (self->descr->type_num == PyArray_UNICODE) { - val = _compare_strings(result, mit, cmp_op, _myunincmp); + val = _compare_strings(result, mit, cmp_op, _myunincmp, + rstrip); } else { - val = _compare_strings(result, mit, cmp_op, _mystrncmp); + val = _compare_strings(result, mit, cmp_op, _mystrncmp, + rstrip); } if (val < 0) {Py_DECREF(result); result = NULL;} @@ -4361,7 +4491,7 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op) } else { /* compare as a string */ /* assumes self and other have same descr->type */ - return _strings_richcompare(self, other, cmp_op); + return _strings_richcompare(self, other, cmp_op, 0); } } @@ -4531,7 +4661,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) if (PyArray_ISSTRING(self) && PyArray_ISSTRING(array_other)) { Py_DECREF(result); result = _strings_richcompare(self, (PyArrayObject *) - array_other, cmp_op); + array_other, cmp_op, 0); } Py_DECREF(array_other); } |