summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-14 20:01:12 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-14 20:01:12 +0000
commitc70b3c6fe0e073fc70eb8b424c30ca6c5c01ea04 (patch)
treefcf9794705f1ae3397187cfa6bbede8df48cee79 /numpy/core
parent779dc154b799a6660f7f60ef50c09fc445329999 (diff)
downloadnumpy-c70b3c6fe0e073fc70eb8b424c30ca6c5c01ea04.tar.gz
Strip characters from chararrays during comparision
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/defchararray.py29
-rw-r--r--numpy/core/numeric.py3
-rw-r--r--numpy/core/src/arrayobject.c202
-rw-r--r--numpy/core/src/multiarraymodule.c67
-rw-r--r--numpy/core/src/ufuncobject.c2
5 files changed, 262 insertions, 41 deletions
diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py
index 20f26f7c3..35d206de7 100644
--- a/numpy/core/defchararray.py
+++ b/numpy/core/defchararray.py
@@ -1,5 +1,5 @@
from numerictypes import string_, unicode_, integer, object_
-from numeric import ndarray, broadcast, empty
+from numeric import ndarray, broadcast, empty, compare_chararrays
from numeric import array as narray
import sys
@@ -12,7 +12,8 @@ _unicode = unicode
# This adds + and * operations and methods of str and unicode types
# which operate on an element-by-element basis
-# It also strips white-space on element retrieval
+# It also strips white-space on element retrieval and on
+# comparisons
class chararray(ndarray):
def __new__(subtype, shape, itemsize=1, unicode=False, buffer=None,
@@ -44,9 +45,31 @@ class chararray(ndarray):
def __getitem__(self, obj):
val = ndarray.__getitem__(self, obj)
if isinstance(val, (string_, unicode_)):
- return val.rstrip()
+ temp = val.rstrip()
+ if len(temp) == 0:
+ val = val[0]
+ else:
+ val = temp
return val
+ def __eq__(self, other):
+ return compare_chararrays(self, other, '==', True)
+
+ def __ne__(self, other):
+ return compare_chararrays(self, other, '!=', True)
+
+ def __ge__(self, other):
+ return compare_chararrays(self, other, '>=', True)
+
+ def __le__(self, other):
+ return compare_chararrays(self, other, '<=', True)
+
+ def __gt__(self, other):
+ return compare_chararrays(self, other, '>', True)
+
+ def __lt__(self, other):
+ return compare_chararrays(self, other, '<', True)
+
def __add__(self, other):
b = broadcast(self, other)
arr = b.iters[1].base
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index d4537f435..8099a1273 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -14,7 +14,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc',
'fromiter', 'array_equal', 'array_equiv',
'indices', 'fromfunction',
'load', 'loads', 'isscalar', 'binary_repr', 'base_repr',
- 'ones', 'identity', 'allclose',
+ 'ones', 'identity', 'allclose', 'compare_chararrays',
'seterr', 'geterr', 'setbufsize', 'getbufsize',
'seterrcall', 'geterrcall', 'flatnonzero',
'Inf', 'inf', 'infty', 'Infinity',
@@ -119,6 +119,7 @@ fastCopyAndTranspose = multiarray._fastCopyAndTranspose
set_numeric_ops = multiarray.set_numeric_ops
can_cast = multiarray.can_cast
lexsort = multiarray.lexsort
+compare_chararrays = multiarray.compare_chararrays
def asarray(a, dtype=None, order=None):
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index c4303bdd3..f1c8b6b9b 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -4183,9 +4183,130 @@ _mystrncmp(char *s1, char *s2, int len1, int len2)
return 0;
}
+/* Borrowed from Numarray */
+
+#define SMALL_STRING 2048
+
+#if defined(isspace)
+#undef isspace
+#define isspace(c) ((c==' ')||(c=='\t')||(c=='\n')||(c=='\r')||(c=='\v')||(c=='\f'))
+#endif
+
+static void _rstripw(char *s, int n)
+{
+ int i;
+ for(i=strnlen(s,n)-1; i>=1; i--) /* Never strip to length 0. */
+ {
+ int c = s[i];
+ if (!c || isspace(c))
+ s[i] = 0;
+ else
+ break;
+ }
+}
+
+static void _unistripw(PyArray_UCS4 *s, int n)
+{
+ int i;
+ for(i=n-1; i>=1; i--) /* Never strip to length 0. */
+ {
+ PyArray_UCS4 c = s[i];
+ if (!c || isspace(c))
+ s[i] = 0;
+ else
+ break;
+ }
+}
+
+
+static char *
+_char_copy_n_strip(char *original, char *temp, int nc)
+{
+ if (nc > SMALL_STRING) {
+ temp = malloc(nc);
+ if (!temp) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ }
+ memcpy(temp, original, nc);
+ _rstripw(temp, nc);
+ return temp;
+}
+
+static void
+_char_release(char *ptr, int nc)
+{
+ if (nc > SMALL_STRING) {
+ free(ptr);
+ }
+}
+
+static char *
+_uni_copy_n_strip(char *original, char *temp, int nc)
+{
+ if (nc*4 > SMALL_STRING) {
+ temp = malloc(nc);
+ if (!temp) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ }
+ memcpy(temp, original, nc*sizeof(PyArray_UCS4));
+ _unistripw((PyArray_UCS4 *)temp, nc);
+ return temp;
+}
+
+static void
+_uni_release(char *ptr, int nc)
+{
+ if (nc*sizeof(PyArray_UCS4) > SMALL_STRING) {
+ free(ptr);
+ }
+}
+
+
+/* End borrowed from numarray */
+
+#define _rstrip_loop(CMP) { \
+ void *aptr, *bptr; \
+ char atemp[SMALL_STRING], btemp[SMALL_STRING]; \
+ while(size--) { \
+ aptr = stripfunc(iself->dataptr, atemp, N1); \
+ if (!aptr) return -1; \
+ bptr = stripfunc(iother->dataptr, btemp, N2); \
+ if (!bptr) { \
+ relfunc(aptr, N1); \
+ return -1; \
+ } \
+ val = cmpfunc(aptr, bptr, N1, N2); \
+ *dptr = (val CMP 0); \
+ PyArray_ITER_NEXT(iself); \
+ PyArray_ITER_NEXT(iother); \
+ dptr += 1; \
+ relfunc(aptr, N1); \
+ relfunc(bptr, N2); \
+ } \
+ }
+
+#define _reg_loop(CMP) { \
+ while(size--) { \
+ val = cmpfunc((void *)iself->dataptr, \
+ (void *)iother->dataptr, \
+ N1, N2); \
+ *dptr = (val CMP 0); \
+ PyArray_ITER_NEXT(iself); \
+ PyArray_ITER_NEXT(iother); \
+ dptr += 1; \
+ } \
+ }
+
+#define _loop(CMP) if (rstrip) _rstrip_loop(CMP) \
+ else _reg_loop(CMP)
+
static int
_compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
- int cmp_op, void *func)
+ int cmp_op, void *func, int rstrip)
{
PyArrayIterObject *iself, *iother;
Bool *dptr;
@@ -4193,6 +4314,8 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
int val;
int N1, N2;
int (*cmpfunc)(void *, void *, int, int);
+ void (*relfunc)(char *, int);
+ char* (*stripfunc)(char *, char *, int);
cmpfunc = func;
dptr = (Bool *)PyArray_DATA(result);
@@ -4204,43 +4327,48 @@ _compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
if ((void *)cmpfunc == (void *)_myunincmp) {
N1 >>= 2;
N2 >>= 2;
+ stripfunc = _uni_copy_n_strip;
+ relfunc = _uni_release;
}
- while(size--) {
- val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr,
- N1, N2);
- switch (cmp_op) {
- case Py_EQ:
- *dptr = (val == 0);
- break;
- case Py_NE:
- *dptr = (val != 0);
- break;
- case Py_LT:
- *dptr = (val < 0);
- break;
- case Py_LE:
- *dptr = (val <= 0);
- break;
- case Py_GT:
- *dptr = (val > 0);
- break;
- case Py_GE:
- *dptr = (val >= 0);
- break;
- default:
- PyErr_SetString(PyExc_RuntimeError,
- "bad comparison operator");
- return -1;
- }
- PyArray_ITER_NEXT(iself);
- PyArray_ITER_NEXT(iother);
- dptr += 1;
+ else {
+ stripfunc = _char_copy_n_strip;
+ relfunc = _char_release;
+ }
+ switch (cmp_op) {
+ case Py_EQ:
+ _loop(==)
+ break;
+ case Py_NE:
+ _loop(!=)
+ break;
+ case Py_LT:
+ _loop(<)
+ break;
+ case Py_LE:
+ _loop(<=)
+ break;
+ case Py_GT:
+ _loop(>)
+ break;
+ case Py_GE:
+ _loop(>=)
+ break;
+ default:
+ PyErr_SetString(PyExc_RuntimeError,
+ "bad comparison operator");
+ return -1;
}
return 0;
}
+#undef _loop
+#undef _reg_loop
+#undef _rstrip_loop
+#undef SMALL_STRING
+
static PyObject *
-_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
+ int rstrip)
{
PyObject *result;
PyArrayMultiIterObject *mit;
@@ -4294,10 +4422,12 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
if (result == NULL) goto finish;
if (self->descr->type_num == PyArray_UNICODE) {
- val = _compare_strings(result, mit, cmp_op, _myunincmp);
+ val = _compare_strings(result, mit, cmp_op, _myunincmp,
+ rstrip);
}
else {
- val = _compare_strings(result, mit, cmp_op, _mystrncmp);
+ val = _compare_strings(result, mit, cmp_op, _mystrncmp,
+ rstrip);
}
if (val < 0) {Py_DECREF(result); result = NULL;}
@@ -4361,7 +4491,7 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
}
else { /* compare as a string */
/* assumes self and other have same descr->type */
- return _strings_richcompare(self, other, cmp_op);
+ return _strings_richcompare(self, other, cmp_op, 0);
}
}
@@ -4531,7 +4661,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
if (PyArray_ISSTRING(self) && PyArray_ISSTRING(array_other)) {
Py_DECREF(result);
result = _strings_richcompare(self, (PyArrayObject *)
- array_other, cmp_op);
+ array_other, cmp_op, 0);
}
Py_DECREF(array_other);
}
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index 9a219ca3e..257e93258 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -6357,6 +6357,71 @@ format_longfloat(PyObject *dummy, PyObject *args, PyObject *kwds)
return PyString_FromString(repr);
}
+static PyObject *
+compare_chararrays(PyObject *dummy, PyObject *args, PyObject *kwds)
+{
+ PyObject *array;
+ PyObject *other;
+ PyArrayObject *newarr, *newoth;
+ int cmp_op;
+ Bool rstrip;
+ char *cmp_str;
+ Py_ssize_t strlen;
+ PyObject *res=NULL;
+ static char msg[] = \
+ "comparision must be '==', '!=', '<', '>', '<=', '>='";
+
+ static char *kwlist[] = {"a1", "a2", "cmp", "rstrip", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOs#O&", kwlist,
+ &array, &other,
+ &cmp_str, &strlen,
+ PyArray_BoolConverter, &rstrip))
+ return NULL;
+
+ if (strlen < 1 || strlen > 2) goto err;
+ if (strlen > 1) {
+ if (cmp_str[1] != '=') goto err;
+ if (cmp_str[0] == '=') cmp_op = Py_EQ;
+ else if (cmp_str[0] == '!') cmp_op = Py_NE;
+ else if (cmp_str[0] == '<') cmp_op = Py_LE;
+ else if (cmp_str[0] == '>') cmp_op = Py_GE;
+ else goto err;
+ }
+ else {
+ if (cmp_str[0] == '<') cmp_op = Py_LT;
+ else if (cmp_str[0] == '>') cmp_op = Py_GT;
+ else goto err;
+ }
+
+ newarr = (PyArrayObject *)PyArray_FROM_O(array);
+ if (newarr == NULL) return NULL;
+ newoth = (PyArrayObject *)PyArray_FROM_O(other);
+ if (newoth == NULL) {
+ Py_DECREF(newarr);
+ return NULL;
+ }
+
+ if (PyArray_ISSTRING(newarr) && PyArray_ISSTRING(newoth)) {
+ res = _strings_richcompare(newarr, newoth, cmp_op, rstrip != 0);
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "comparison of non-string arrays");
+ }
+
+ Py_DECREF(newarr);
+ Py_DECREF(newoth);
+ return res;
+
+ err:
+ PyErr_SetString(PyExc_ValueError, msg);
+ return NULL;
+}
+
+
+
+
static struct PyMethodDef array_module_methods[] = {
{"_get_ndarray_c_version", (PyCFunction)array__get_ndarray_c_version,
METH_VARARGS|METH_KEYWORDS, NULL},
@@ -6409,6 +6474,8 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"format_longfloat", (PyCFunction)format_longfloat,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"compare_chararrays", (PyCFunction)compare_chararrays,
+ METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0} /* sentinel */
};
diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c
index a2b711093..d7ecb8a75 100644
--- a/numpy/core/src/ufuncobject.c
+++ b/numpy/core/src/ufuncobject.c
@@ -1761,7 +1761,7 @@ construct_reduce(PyUFuncObject *self, PyArrayObject **arr, PyArrayObject *out,
PyUFuncReduceObject *loop;
PyArrayObject *idarr;
PyArrayObject *aar;
- intp loop_i[MAX_DIMS], outsize;
+ intp loop_i[MAX_DIMS], outsize=0;
int arg_types[3] = {otype, otype, otype};
PyArray_SCALARKIND scalars[3] = {PyArray_NOSCALAR, PyArray_NOSCALAR,
PyArray_NOSCALAR};