diff options
author | sasha <sasha@localhost> | 2006-02-26 04:14:32 +0000 |
---|---|---|
committer | sasha <sasha@localhost> | 2006-02-26 04:14:32 +0000 |
commit | 4156b241aa3670f923428d4e72577a9962cdf042 (patch) | |
tree | 782d97aebdc0dbf09747ff437e84fa3c100e547b /numpy/core | |
parent | 013b3968457f78caf1a7185702771f6235515db9 (diff) | |
download | numpy-4156b241aa3670f923428d4e72577a9962cdf042.tar.gz |
made subscripting return ndarray if ellipsis is present
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/arrayobject.c | 53 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 14 |
2 files changed, 58 insertions, 9 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 288071cde..212064807 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -1995,13 +1995,20 @@ array_subscript(PyArrayObject *self, PyObject *op) return NULL; } if (self->nd == 0) { - if (op == Py_Ellipsis) - return PyArray_ToScalar(self->data, self); + if (op == Py_Ellipsis) { + /* XXX: This leads to a small inconsistency + XXX: with the nd>0 case where (x[...] is x) + XXX: is false for nd>0 case. */ + Py_INCREF(self); + return (PyObject *)self; + } if (op == Py_None) return add_new_axes_0d(self, 1); if (PyTuple_Check(op)) { - if (0 == PyTuple_GET_SIZE(op)) - return PyArray_ToScalar(self->data, self); + if (0 == PyTuple_GET_SIZE(op)) { + Py_INCREF(self); + return (PyObject *)self; + } if ((nd = count_new_axes_0d(op)) == -1) return NULL; return add_new_axes_0d(self, nd); @@ -2211,7 +2218,43 @@ array_ass_sub(PyArrayObject *self, PyObject *index, PyObject *op) static PyObject * array_subscript_nice(PyArrayObject *self, PyObject *op) { - return PyArray_Return((PyArrayObject *)array_subscript(self, op)); + /* The following is just a copy of PyArray_Return with an + additional logic in the nd == 0 case. More efficient + implementation may be possible by refactoring + array_subscript */ + + PyArrayObject *mp = (PyArrayObject *)array_subscript(self, op); + + if (mp == NULL) return NULL; + + if (PyErr_Occurred()) { + Py_XDECREF(mp); + return NULL; + } + + if (!PyArray_Check(mp)) return (PyObject *)mp; + + if (mp->nd == 0) { + Bool noellipses = TRUE; + if (op == Py_Ellipsis) + noellipses = FALSE; + else if (PySequence_Check(op)) { + int n, i; + n = PySequence_Size(op); + for (i = 0; i < n; ++i) + if (PySequence_GetItem(op, i) == Py_Ellipsis) { + noellipses = FALSE; + break; + } + } + if (noellipses) { + PyObject *ret; + ret = PyArray_ToScalar(mp->data, mp); + Py_DECREF(mp); + return ret; + } + } + return (PyObject *)mp; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index ff7f25ac0..0ec264407 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -117,8 +117,8 @@ class test_zero_rank(ScipyTestCase): a,b = self.d self.failUnlessEqual(a[...], 0) self.failUnlessEqual(b[...], 'x') - self.failUnless(type(a[...]) is a.dtype.type) - self.failUnless(type(b[...]) is str) + self.failUnless(a[...] is a) + self.failUnless(b[...] is b) def check_empty_subscript(self): a,b = self.d @@ -197,7 +197,7 @@ class test_bool(ScipyTestCase): b1 = bool_(True) self.failUnless(a1 is b1) self.failUnless(array([True])[0] is a1) - self.failUnless(array(True)[...] is a1) + self.failUnless(array(True)[()] is a1) class test_methods(ScipyTestCase): @@ -207,7 +207,13 @@ class test_methods(ScipyTestCase): assert_equal(array([12.2,15.5]).round(-1), [10,20]) assert_equal(array([12.15,15.51]).round(1), [12.2,15.5]) - + +class test_subscripting(ScipyTestCase): + def check_test_zero_rank(self): + x = array([1,2,3]) + self.failUnless(isinstance(x[0], int)) + self.failUnless(type(x[0, ...]) is ndarray) + # Import tests from unicode set_local_path() from test_unicode import * |