summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
authorsasha <sasha@localhost>2006-02-26 04:14:32 +0000
committersasha <sasha@localhost>2006-02-26 04:14:32 +0000
commit4156b241aa3670f923428d4e72577a9962cdf042 (patch)
tree782d97aebdc0dbf09747ff437e84fa3c100e547b /numpy/core/src/arrayobject.c
parent013b3968457f78caf1a7185702771f6235515db9 (diff)
downloadnumpy-4156b241aa3670f923428d4e72577a9962cdf042.tar.gz
made subscripting return ndarray if ellipsis is present
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c53
1 files changed, 48 insertions, 5 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 288071cde..212064807 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -1995,13 +1995,20 @@ array_subscript(PyArrayObject *self, PyObject *op)
return NULL;
}
if (self->nd == 0) {
- if (op == Py_Ellipsis)
- return PyArray_ToScalar(self->data, self);
+ if (op == Py_Ellipsis) {
+ /* XXX: This leads to a small inconsistency
+ XXX: with the nd>0 case where (x[...] is x)
+ XXX: is false for nd>0 case. */
+ Py_INCREF(self);
+ return (PyObject *)self;
+ }
if (op == Py_None)
return add_new_axes_0d(self, 1);
if (PyTuple_Check(op)) {
- if (0 == PyTuple_GET_SIZE(op))
- return PyArray_ToScalar(self->data, self);
+ if (0 == PyTuple_GET_SIZE(op)) {
+ Py_INCREF(self);
+ return (PyObject *)self;
+ }
if ((nd = count_new_axes_0d(op)) == -1)
return NULL;
return add_new_axes_0d(self, nd);
@@ -2211,7 +2218,43 @@ array_ass_sub(PyArrayObject *self, PyObject *index, PyObject *op)
static PyObject *
array_subscript_nice(PyArrayObject *self, PyObject *op)
{
- return PyArray_Return((PyArrayObject *)array_subscript(self, op));
+ /* The following is just a copy of PyArray_Return with an
+ additional logic in the nd == 0 case. More efficient
+ implementation may be possible by refactoring
+ array_subscript */
+
+ PyArrayObject *mp = (PyArrayObject *)array_subscript(self, op);
+
+ if (mp == NULL) return NULL;
+
+ if (PyErr_Occurred()) {
+ Py_XDECREF(mp);
+ return NULL;
+ }
+
+ if (!PyArray_Check(mp)) return (PyObject *)mp;
+
+ if (mp->nd == 0) {
+ Bool noellipses = TRUE;
+ if (op == Py_Ellipsis)
+ noellipses = FALSE;
+ else if (PySequence_Check(op)) {
+ int n, i;
+ n = PySequence_Size(op);
+ for (i = 0; i < n; ++i)
+ if (PySequence_GetItem(op, i) == Py_Ellipsis) {
+ noellipses = FALSE;
+ break;
+ }
+ }
+ if (noellipses) {
+ PyObject *ret;
+ ret = PyArray_ToScalar(mp->data, mp);
+ Py_DECREF(mp);
+ return ret;
+ }
+ }
+ return (PyObject *)mp;
}