diff options
author | Travis Oliphant <oliphant@enthought.com> | 2008-05-13 16:40:19 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2008-05-13 16:40:19 +0000 |
commit | e3d0fec968a54a144de203a25d52e059b7447065 (patch) | |
tree | c731f0ccabad52fc729e5b43e54d1360d092b5eb /numpy/core | |
parent | b6ef006ba71899914ddb2867eba0a95d081c937d (diff) | |
download | numpy-e3d0fec968a54a144de203a25d52e059b7447065.tar.gz |
Fix ticket #791.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/arrayobject.c | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 8 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 17 |
3 files changed, 22 insertions, 5 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 867b32b65..51abf8e02 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -1976,7 +1976,7 @@ PyArray_ToList(PyArrayObject *self) v=(PyArrayObject *)array_big_item(self, i); } else { - v = PySequence_GetItem((PyObject *)self, i); + v = (PyArrayObject *)PySequence_GetItem((PyObject *)self, i); if ((!PyArray_Check(v)) || (v->nd >= self->nd)) { PyErr_SetString(PyExc_RuntimeError, "array_item not returning smaller-" \ diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index d08b3fd76..9e1fcbe9e 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -849,7 +849,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out, if ((new = _check_axis(self, &axis, 0))==NULL) return NULL; /* Compute and reshape mean */ - obj1 = PyArray_EnsureArray(PyArray_Mean((PyAO *)new, axis, rtype, NULL)); + obj1 = PyArray_EnsureAnyArray(PyArray_Mean((PyAO *)new, axis, rtype, NULL)); if (obj1 == NULL) {Py_DECREF(new); return NULL;} n = PyArray_NDIM(new); newshape = PyTuple_New(n); @@ -865,7 +865,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out, if (obj2 == NULL) {Py_DECREF(new); return NULL;} /* Compute x = x - mx */ - obj1 = PyArray_EnsureArray(PyNumber_Subtract((PyObject *)new, obj2)); + obj1 = PyArray_EnsureAnyArray(PyNumber_Subtract((PyObject *)new, obj2)); Py_DECREF(obj2); if (obj1 == NULL) {Py_DECREF(new); return NULL;} @@ -878,7 +878,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out, Py_INCREF(obj1); } if (obj3 == NULL) {Py_DECREF(new); return NULL;} - obj2 = PyArray_EnsureArray \ + obj2 = PyArray_EnsureAnyArray \ (PyArray_GenericBinaryFunction((PyAO *)obj1, obj3, n_ops.multiply)); Py_DECREF(obj1); Py_DECREF(obj3); @@ -921,7 +921,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out, Py_DECREF(obj2); if (!variance) { - obj1 = PyArray_EnsureArray(ret); + obj1 = PyArray_EnsureAnyArray(ret); /* sqrt() */ ret = PyArray_GenericUnaryFunction((PyAO *)obj1, n_ops.sqrt); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index ae6c43f10..166c3269b 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -854,6 +854,23 @@ class TestView(NumpyTestCase): assert(isinstance(y,np.matrix)) assert_equal(y.dtype,np.int16) +class TestStats(NumpyTestCase): + def test_subclass(self): + class TestArray(np.ndarray): + def __new__(cls, data, info): + result = np.array(data) + result = result.view(cls) + result.info = info + return result + def __array_finalize__(self, obj): + self.info = getattr(obj, "info", '') + dat = TestArray([[1,2,3,4],[5,6,7,8]], 'jubba') + res = dat.mean(1) + assert res.info == dat.info + res = dat.std(1) + assert res.info == dat.info + res = dat.var(1) + assert res.info == dat.info # Import tests without matching module names set_local_path() |