summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-04-12 22:49:00 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-04-12 22:49:00 +0000
commitec554d0caee4fb34122533eb630e7130e5568db3 (patch)
tree0d5c417ef355aaf07da68aa66c324f8e82d24f20 /numpy/core
parent5534a9987fa9b916a4c035220a3fb488bd315753 (diff)
downloadnumpy-ec554d0caee4fb34122533eb630e7130e5568db3.tar.gz
Fixed where to always return a tuple
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/oldnumeric.py11
-rw-r--r--numpy/core/src/arraymethods.c2
-rw-r--r--numpy/core/src/multiarraymodule.c28
3 files changed, 22 insertions, 19 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py
index b8685d69e..4674cfdd9 100644
--- a/numpy/core/oldnumeric.py
+++ b/numpy/core/oldnumeric.py
@@ -366,8 +366,15 @@ def nonzero(a):
try:
nonzero = a.nonzero
except AttributeError:
- return _wrapit(a, 'nonzero')
- return nonzero()
+ res = _wrapit(a, 'nonzero')
+ else:
+ res = nonzero()
+
+ if len(res) == 1:
+ return res[0]
+ else:
+ raise ValueError, "Input argument must be 1d"
+
def shape(a):
"""shape(a) returns the shape of a (as a function call which
diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c
index cf4a7faae..3af56f39d 100644
--- a/numpy/core/src/arraymethods.c
+++ b/numpy/core/src/arraymethods.c
@@ -1382,7 +1382,7 @@ array_nonzero(PyArrayObject *self, PyObject *args)
{
if (!PyArg_ParseTuple(args, "")) return NULL;
- return _ARET(PyArray_Nonzero(self));
+ return PyArray_Nonzero(self);
}
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index da314d3c8..a1a5fe910 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -811,7 +811,8 @@ PyArray_Compress(PyArrayObject *self, PyObject *condition, int axis)
res = PyArray_Nonzero(cond);
Py_DECREF(cond);
- ret = PyArray_Take(self, res, axis);
+ if (res == NULL) return res;
+ ret = PyArray_Take(self, PyTuple_GET_ITEM(res, 0), axis);
Py_DECREF(res);
return ret;
}
@@ -838,12 +839,17 @@ PyArray_Nonzero(PyArrayObject *self)
}
PyArray_ITER_RESET(it);
+ ret = PyTuple_New(n);
+ if (ret == NULL) goto fail;
+ for (j=0; j<n; j++) {
+ item = PyArray_New(self->ob_type, 1, &count,
+ PyArray_INTP, NULL, NULL, 0, 0,
+ (PyObject *)self);
+ if (item == NULL) goto fail;
+ PyTuple_SET_ITEM(ret, j, item);
+ dptr[j] = (intp *)PyArray_DATA(item);
+ }
if (n==1) {
- ret = PyArray_New(self->ob_type, 1, &count, PyArray_INTP,
- NULL, NULL, 0, 0, (PyObject *)self);
- if (ret == NULL) goto fail;
- dptr[0] = (intp *)PyArray_DATA(ret);
-
for (i=0; i<size; i++) {
if (self->descr->f->nonzero(it->dataptr, self))
*(dptr[0])++ = i;
@@ -851,16 +857,6 @@ PyArray_Nonzero(PyArrayObject *self)
}
}
else {
- ret = PyTuple_New(n);
- for (j=0; j<n; j++) {
- item = PyArray_New(self->ob_type, 1, &count,
- PyArray_INTP, NULL, NULL, 0, 0,
- (PyObject *)self);
- if (item == NULL) goto fail;
- PyTuple_SET_ITEM(ret, j, item);
- dptr[j] = (intp *)PyArray_DATA(item);
- }
-
/* reset contiguous so that coordinates gets updated */
it->contiguous = 0;
for (i=0; i<size; i++) {