diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-14 20:01:12 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-14 20:01:12 +0000 |
commit | c70b3c6fe0e073fc70eb8b424c30ca6c5c01ea04 (patch) | |
tree | fcf9794705f1ae3397187cfa6bbede8df48cee79 /numpy/core | |
parent | 779dc154b799a6660f7f60ef50c09fc445329999 (diff) | |
download | numpy-c70b3c6fe0e073fc70eb8b424c30ca6c5c01ea04.tar.gz |
Strip characters from chararrays during comparision
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/defchararray.py | 29 | ||||
-rw-r--r-- | numpy/core/numeric.py | 3 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 202 | ||||
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 67 | ||||
-rw-r--r-- | numpy/core/src/ufuncobject.c | 2 |
5 files changed, 262 insertions, 41 deletions
diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py index 20f26f7c3..35d206de7 100644 --- a/numpy/core/defchararray.py +++ b/numpy/core/defchararray.py @@ -1,5 +1,5 @@ from numerictypes import string_, unicode_, integer, object_ -from numeric import ndarray, broadcast, empty +from numeric import ndarray, broadcast, empty, compare_chararrays from numeric import array as narray import sys @@ -12,7 +12,8 @@ _unicode = unicode # This adds + and * operations and methods of str and unicode types # which operate on an element-by-element basis -# It also strips white-space on element retrieval +# It also strips white-space on element retrieval and on +# comparisons class chararray(ndarray): def __new__(subtype, shape, itemsize=1, unicode=False, buffer=None, @@ -44,9 +45,31 @@ class chararray(ndarray): def __getitem__(self, obj): val = ndarray.__getitem__(self, obj) if isinstance(val, (string_, unicode_)): - return val.rstrip() + temp = val.rstrip() + if len(temp) == 0: + val = val[0] + else: + val = temp return val + def __eq__(self, other): + return compare_chararrays(self, other, '==', True) + + def __ne__(self, other): + return compare_chararrays(self, other, '!=', True) + + def __ge__(self, other): + return compare_chararrays(self, other, '>=', True) + + def __le__(self, other): + return compare_chararrays(self, other, '<=', True) + + def __gt__(self, other): + return compare_chararrays(self, other, '>', True) + + def __lt__(self, other): + return compare_chararrays(self, other, '<', True) + def __add__(self, other): b = broadcast(self, other) arr = b.iters[1].base diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index d4537f435..8099a1273 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -14,7 +14,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc', 'fromiter', 'array_equal', 'array_equiv', 'indices', 'fromfunction', 'load', 'loads', 'isscalar', 'binary_repr', 'base_repr', - 'ones', 'identity', 'allclose', + 'ones', 'identity', 'allclose', 'compare_chararrays', 'seterr', 'geterr', 'setbufsize', 'getbufsize', 'seterrcall', 'geterrcall', 'flatnonzero', 'Inf', 'inf', 'infty', 'Infinity', @@ -119,6 +119,7 @@ fastCopyAndTranspose = multiarray._fastCopyAndTranspose set_numeric_ops = multiarray.set_numeric_ops can_cast = multiarray.can_cast lexsort = multiarray.lexsort +compare_chararrays = multiarray.compare_chararrays def asarray(a, dtype=None, order=None): diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index c4303bdd3..f1c8b6b9b 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -4183,9 +4183,130 @@ _mystrncmp(char *s1, char *s2, int len1, int len2) return 0; } +/* Borrowed from Numarray */ + +#define SMALL_STRING 2048 + +#if defined(isspace) +#undef isspace +#define isspace(c) ((c==' ')||(c=='\t')||(c=='\n')||(c=='\r')||(c=='\v')||(c=='\f')) +#endif + +static void _rstripw(char *s, int n) +{ + int i; + for(i=strnlen(s,n)-1; i>=1; i--) /* Never strip to length 0. */ + { + int c = s[i]; + if (!c || isspace(c)) + s[i] = 0; + else + break; + } +} + +static void _unistripw(PyArray_UCS4 *s, int n) +{ + int i; + for(i=n-1; i>=1; i--) /* Never strip to length 0. */ + { + PyArray_UCS4 c = s[i]; + if (!c || isspace(c)) + s[i] = 0; + else + break; + } +} + + +static char * +_char_copy_n_strip(char *original, char *temp, int nc) +{ + if (nc > SMALL_STRING) { + temp = malloc(nc); + if (!temp) { + PyErr_NoMemory(); + return NULL; + } + } + memcpy(temp, original, nc); + _rstripw(temp, nc); + return temp; +} + +static void +_char_release(char *ptr, int nc) +{ + if (nc > SMALL_STRING) { + free(ptr); + } +} + +static char * +_uni_copy_n_strip(char *original, char *temp, int nc) +{ + if (nc*4 > SMALL_STRING) { + temp = malloc(nc); + if (!temp) { + PyErr_NoMemory(); + return NULL; + } + } + memcpy(temp, original, nc*sizeof(PyArray_UCS4)); + _unistripw((PyArray_UCS4 *)temp, nc); + return temp; +} + +static void +_uni_release(char *ptr, int nc) +{ + if (nc*sizeof(PyArray_UCS4) > SMALL_STRING) { + free(ptr); + } +} + + +/* End borrowed from numarray */ + +#define _rstrip_loop(CMP) { \ + void *aptr, *bptr; \ + char atemp[SMALL_STRING], btemp[SMALL_STRING]; \ + while(size--) { \ + aptr = stripfunc(iself->dataptr, atemp, N1); \ + if (!aptr) return -1; \ + bptr = stripfunc(iother->dataptr, btemp, N2); \ + if (!bptr) { \ + relfunc(aptr, N1); \ + return -1; \ + } \ + val = cmpfunc(aptr, bptr, N1, N2); \ + *dptr = (val CMP 0); \ + PyArray_ITER_NEXT(iself); \ + PyArray_ITER_NEXT(iother); \ + dptr += 1; \ + relfunc(aptr, N1); \ + relfunc(bptr, N2); \ + } \ + } + +#define _reg_loop(CMP) { \ + while(size--) { \ + val = cmpfunc((void *)iself->dataptr, \ + (void *)iother->dataptr, \ + N1, N2); \ + *dptr = (val CMP 0); \ + PyArray_ITER_NEXT(iself); \ + PyArray_ITER_NEXT(iother); \ + dptr += 1; \ + } \ + } + +#define _loop(CMP) if (rstrip) _rstrip_loop(CMP) \ + else _reg_loop(CMP) + static int _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, - int cmp_op, void *func) + int cmp_op, void *func, int rstrip) { PyArrayIterObject *iself, *iother; Bool *dptr; @@ -4193,6 +4314,8 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, int val; int N1, N2; int (*cmpfunc)(void *, void *, int, int); + void (*relfunc)(char *, int); + char* (*stripfunc)(char *, char *, int); cmpfunc = func; dptr = (Bool *)PyArray_DATA(result); @@ -4204,43 +4327,48 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, if ((void *)cmpfunc == (void *)_myunincmp) { N1 >>= 2; N2 >>= 2; + stripfunc = _uni_copy_n_strip; + relfunc = _uni_release; } - while(size--) { - val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, - N1, N2); - switch (cmp_op) { - case Py_EQ: - *dptr = (val == 0); - break; - case Py_NE: - *dptr = (val != 0); - break; - case Py_LT: - *dptr = (val < 0); - break; - case Py_LE: - *dptr = (val <= 0); - break; - case Py_GT: - *dptr = (val > 0); - break; - case Py_GE: - *dptr = (val >= 0); - break; - default: - PyErr_SetString(PyExc_RuntimeError, - "bad comparison operator"); - return -1; - } - PyArray_ITER_NEXT(iself); - PyArray_ITER_NEXT(iother); - dptr += 1; + else { + stripfunc = _char_copy_n_strip; + relfunc = _char_release; + } + switch (cmp_op) { + case Py_EQ: + _loop(==) + break; + case Py_NE: + _loop(!=) + break; + case Py_LT: + _loop(<) + break; + case Py_LE: + _loop(<=) + break; + case Py_GT: + _loop(>) + break; + case Py_GE: + _loop(>=) + break; + default: + PyErr_SetString(PyExc_RuntimeError, + "bad comparison operator"); + return -1; } return 0; } +#undef _loop +#undef _reg_loop +#undef _rstrip_loop +#undef SMALL_STRING + static PyObject * -_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) +_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, + int rstrip) { PyObject *result; PyArrayMultiIterObject *mit; @@ -4294,10 +4422,12 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) if (result == NULL) goto finish; if (self->descr->type_num == PyArray_UNICODE) { - val = _compare_strings(result, mit, cmp_op, _myunincmp); + val = _compare_strings(result, mit, cmp_op, _myunincmp, + rstrip); } else { - val = _compare_strings(result, mit, cmp_op, _mystrncmp); + val = _compare_strings(result, mit, cmp_op, _mystrncmp, + rstrip); } if (val < 0) {Py_DECREF(result); result = NULL;} @@ -4361,7 +4491,7 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op) } else { /* compare as a string */ /* assumes self and other have same descr->type */ - return _strings_richcompare(self, other, cmp_op); + return _strings_richcompare(self, other, cmp_op, 0); } } @@ -4531,7 +4661,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) if (PyArray_ISSTRING(self) && PyArray_ISSTRING(array_other)) { Py_DECREF(result); result = _strings_richcompare(self, (PyArrayObject *) - array_other, cmp_op); + array_other, cmp_op, 0); } Py_DECREF(array_other); } diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 9a219ca3e..257e93258 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -6357,6 +6357,71 @@ format_longfloat(PyObject *dummy, PyObject *args, PyObject *kwds) return PyString_FromString(repr); } +static PyObject * +compare_chararrays(PyObject *dummy, PyObject *args, PyObject *kwds) +{ + PyObject *array; + PyObject *other; + PyArrayObject *newarr, *newoth; + int cmp_op; + Bool rstrip; + char *cmp_str; + Py_ssize_t strlen; + PyObject *res=NULL; + static char msg[] = \ + "comparision must be '==', '!=', '<', '>', '<=', '>='"; + + static char *kwlist[] = {"a1", "a2", "cmp", "rstrip", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOs#O&", kwlist, + &array, &other, + &cmp_str, &strlen, + PyArray_BoolConverter, &rstrip)) + return NULL; + + if (strlen < 1 || strlen > 2) goto err; + if (strlen > 1) { + if (cmp_str[1] != '=') goto err; + if (cmp_str[0] == '=') cmp_op = Py_EQ; + else if (cmp_str[0] == '!') cmp_op = Py_NE; + else if (cmp_str[0] == '<') cmp_op = Py_LE; + else if (cmp_str[0] == '>') cmp_op = Py_GE; + else goto err; + } + else { + if (cmp_str[0] == '<') cmp_op = Py_LT; + else if (cmp_str[0] == '>') cmp_op = Py_GT; + else goto err; + } + + newarr = (PyArrayObject *)PyArray_FROM_O(array); + if (newarr == NULL) return NULL; + newoth = (PyArrayObject *)PyArray_FROM_O(other); + if (newoth == NULL) { + Py_DECREF(newarr); + return NULL; + } + + if (PyArray_ISSTRING(newarr) && PyArray_ISSTRING(newoth)) { + res = _strings_richcompare(newarr, newoth, cmp_op, rstrip != 0); + } + else { + PyErr_SetString(PyExc_TypeError, + "comparison of non-string arrays"); + } + + Py_DECREF(newarr); + Py_DECREF(newoth); + return res; + + err: + PyErr_SetString(PyExc_ValueError, msg); + return NULL; +} + + + + static struct PyMethodDef array_module_methods[] = { {"_get_ndarray_c_version", (PyCFunction)array__get_ndarray_c_version, METH_VARARGS|METH_KEYWORDS, NULL}, @@ -6409,6 +6474,8 @@ static struct PyMethodDef array_module_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"format_longfloat", (PyCFunction)format_longfloat, METH_VARARGS | METH_KEYWORDS, NULL}, + {"compare_chararrays", (PyCFunction)compare_chararrays, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0} /* sentinel */ }; diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index a2b711093..d7ecb8a75 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -1761,7 +1761,7 @@ construct_reduce(PyUFuncObject *self, PyArrayObject **arr, PyArrayObject *out, PyUFuncReduceObject *loop; PyArrayObject *idarr; PyArrayObject *aar; - intp loop_i[MAX_DIMS], outsize; + intp loop_i[MAX_DIMS], outsize=0; int arg_types[3] = {otype, otype, otype}; PyArray_SCALARKIND scalars[3] = {PyArray_NOSCALAR, PyArray_NOSCALAR, PyArray_NOSCALAR}; |