summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/oldnumeric.py11
-rw-r--r--numpy/core/src/arraymethods.c2
-rw-r--r--numpy/core/src/multiarraymodule.c28
-rw-r--r--numpy/lib/arraysetops.py2
4 files changed, 23 insertions, 20 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py
index b8685d69e..4674cfdd9 100644
--- a/numpy/core/oldnumeric.py
+++ b/numpy/core/oldnumeric.py
@@ -366,8 +366,15 @@ def nonzero(a):
try:
nonzero = a.nonzero
except AttributeError:
- return _wrapit(a, 'nonzero')
- return nonzero()
+ res = _wrapit(a, 'nonzero')
+ else:
+ res = nonzero()
+
+ if len(res) == 1:
+ return res[0]
+ else:
+ raise ValueError, "Input argument must be 1d"
+
def shape(a):
"""shape(a) returns the shape of a (as a function call which
diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c
index cf4a7faae..3af56f39d 100644
--- a/numpy/core/src/arraymethods.c
+++ b/numpy/core/src/arraymethods.c
@@ -1382,7 +1382,7 @@ array_nonzero(PyArrayObject *self, PyObject *args)
{
if (!PyArg_ParseTuple(args, "")) return NULL;
- return _ARET(PyArray_Nonzero(self));
+ return PyArray_Nonzero(self);
}
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index da314d3c8..a1a5fe910 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -811,7 +811,8 @@ PyArray_Compress(PyArrayObject *self, PyObject *condition, int axis)
res = PyArray_Nonzero(cond);
Py_DECREF(cond);
- ret = PyArray_Take(self, res, axis);
+ if (res == NULL) return res;
+ ret = PyArray_Take(self, PyTuple_GET_ITEM(res, 0), axis);
Py_DECREF(res);
return ret;
}
@@ -838,12 +839,17 @@ PyArray_Nonzero(PyArrayObject *self)
}
PyArray_ITER_RESET(it);
+ ret = PyTuple_New(n);
+ if (ret == NULL) goto fail;
+ for (j=0; j<n; j++) {
+ item = PyArray_New(self->ob_type, 1, &count,
+ PyArray_INTP, NULL, NULL, 0, 0,
+ (PyObject *)self);
+ if (item == NULL) goto fail;
+ PyTuple_SET_ITEM(ret, j, item);
+ dptr[j] = (intp *)PyArray_DATA(item);
+ }
if (n==1) {
- ret = PyArray_New(self->ob_type, 1, &count, PyArray_INTP,
- NULL, NULL, 0, 0, (PyObject *)self);
- if (ret == NULL) goto fail;
- dptr[0] = (intp *)PyArray_DATA(ret);
-
for (i=0; i<size; i++) {
if (self->descr->f->nonzero(it->dataptr, self))
*(dptr[0])++ = i;
@@ -851,16 +857,6 @@ PyArray_Nonzero(PyArrayObject *self)
}
}
else {
- ret = PyTuple_New(n);
- for (j=0; j<n; j++) {
- item = PyArray_New(self->ob_type, 1, &count,
- PyArray_INTP, NULL, NULL, 0, 0,
- (PyObject *)self);
- if (item == NULL) goto fail;
- PyTuple_SET_ITEM(ret, j, item);
- dptr[j] = (intp *)PyArray_DATA(item);
- }
-
/* reset contiguous so that coordinates gets updated */
it->contiguous = 0;
for (i=0; i<size; i++) {
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index 5b88acd6c..7bd666029 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -117,7 +117,7 @@ def setmember1d( ar1, ar2 ):
aux2 = tt.take(perm)
flag = ediff1d( aux, 1 ) == 0
- ii = numpy.where( flag * aux2 )
+ ii = numpy.where( flag * aux2 )[0]
aux = perm[ii+1]
perm[ii+1] = perm[ii]
perm[ii] = aux