summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorsasha <sasha@localhost>2006-02-26 04:14:32 +0000
committersasha <sasha@localhost>2006-02-26 04:14:32 +0000
commit4156b241aa3670f923428d4e72577a9962cdf042 (patch)
tree782d97aebdc0dbf09747ff437e84fa3c100e547b /numpy/core
parent013b3968457f78caf1a7185702771f6235515db9 (diff)
downloadnumpy-4156b241aa3670f923428d4e72577a9962cdf042.tar.gz
made subscripting return ndarray if ellipsis is present
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/arrayobject.c53
-rw-r--r--numpy/core/tests/test_multiarray.py14
2 files changed, 58 insertions, 9 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 288071cde..212064807 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -1995,13 +1995,20 @@ array_subscript(PyArrayObject *self, PyObject *op)
return NULL;
}
if (self->nd == 0) {
- if (op == Py_Ellipsis)
- return PyArray_ToScalar(self->data, self);
+ if (op == Py_Ellipsis) {
+ /* XXX: This leads to a small inconsistency
+ XXX: with the nd>0 case where (x[...] is x)
+ XXX: is false for nd>0 case. */
+ Py_INCREF(self);
+ return (PyObject *)self;
+ }
if (op == Py_None)
return add_new_axes_0d(self, 1);
if (PyTuple_Check(op)) {
- if (0 == PyTuple_GET_SIZE(op))
- return PyArray_ToScalar(self->data, self);
+ if (0 == PyTuple_GET_SIZE(op)) {
+ Py_INCREF(self);
+ return (PyObject *)self;
+ }
if ((nd = count_new_axes_0d(op)) == -1)
return NULL;
return add_new_axes_0d(self, nd);
@@ -2211,7 +2218,43 @@ array_ass_sub(PyArrayObject *self, PyObject *index, PyObject *op)
static PyObject *
array_subscript_nice(PyArrayObject *self, PyObject *op)
{
- return PyArray_Return((PyArrayObject *)array_subscript(self, op));
+ /* The following is just a copy of PyArray_Return with an
+ additional logic in the nd == 0 case. More efficient
+ implementation may be possible by refactoring
+ array_subscript */
+
+ PyArrayObject *mp = (PyArrayObject *)array_subscript(self, op);
+
+ if (mp == NULL) return NULL;
+
+ if (PyErr_Occurred()) {
+ Py_XDECREF(mp);
+ return NULL;
+ }
+
+ if (!PyArray_Check(mp)) return (PyObject *)mp;
+
+ if (mp->nd == 0) {
+ Bool noellipses = TRUE;
+ if (op == Py_Ellipsis)
+ noellipses = FALSE;
+ else if (PySequence_Check(op)) {
+ int n, i;
+ n = PySequence_Size(op);
+ for (i = 0; i < n; ++i)
+ if (PySequence_GetItem(op, i) == Py_Ellipsis) {
+ noellipses = FALSE;
+ break;
+ }
+ }
+ if (noellipses) {
+ PyObject *ret;
+ ret = PyArray_ToScalar(mp->data, mp);
+ Py_DECREF(mp);
+ return ret;
+ }
+ }
+ return (PyObject *)mp;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ff7f25ac0..0ec264407 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -117,8 +117,8 @@ class test_zero_rank(ScipyTestCase):
a,b = self.d
self.failUnlessEqual(a[...], 0)
self.failUnlessEqual(b[...], 'x')
- self.failUnless(type(a[...]) is a.dtype.type)
- self.failUnless(type(b[...]) is str)
+ self.failUnless(a[...] is a)
+ self.failUnless(b[...] is b)
def check_empty_subscript(self):
a,b = self.d
@@ -197,7 +197,7 @@ class test_bool(ScipyTestCase):
b1 = bool_(True)
self.failUnless(a1 is b1)
self.failUnless(array([True])[0] is a1)
- self.failUnless(array(True)[...] is a1)
+ self.failUnless(array(True)[()] is a1)
class test_methods(ScipyTestCase):
@@ -207,7 +207,13 @@ class test_methods(ScipyTestCase):
assert_equal(array([12.2,15.5]).round(-1), [10,20])
assert_equal(array([12.15,15.51]).round(1), [12.2,15.5])
-
+
+class test_subscripting(ScipyTestCase):
+ def check_test_zero_rank(self):
+ x = array([1,2,3])
+ self.failUnless(isinstance(x[0], int))
+ self.failUnless(type(x[0, ...]) is ndarray)
+
# Import tests from unicode
set_local_path()
from test_unicode import *