summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2008-05-13 16:40:19 +0000
committerTravis Oliphant <oliphant@enthought.com>2008-05-13 16:40:19 +0000
commite3d0fec968a54a144de203a25d52e059b7447065 (patch)
treec731f0ccabad52fc729e5b43e54d1360d092b5eb /numpy/core
parentb6ef006ba71899914ddb2867eba0a95d081c937d (diff)
downloadnumpy-e3d0fec968a54a144de203a25d52e059b7447065.tar.gz
Fix ticket #791.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/arrayobject.c2
-rw-r--r--numpy/core/src/multiarraymodule.c8
-rw-r--r--numpy/core/tests/test_multiarray.py17
3 files changed, 22 insertions, 5 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 867b32b65..51abf8e02 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -1976,7 +1976,7 @@ PyArray_ToList(PyArrayObject *self)
v=(PyArrayObject *)array_big_item(self, i);
}
else {
- v = PySequence_GetItem((PyObject *)self, i);
+ v = (PyArrayObject *)PySequence_GetItem((PyObject *)self, i);
if ((!PyArray_Check(v)) || (v->nd >= self->nd)) {
PyErr_SetString(PyExc_RuntimeError,
"array_item not returning smaller-" \
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index d08b3fd76..9e1fcbe9e 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -849,7 +849,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
if ((new = _check_axis(self, &axis, 0))==NULL) return NULL;
/* Compute and reshape mean */
- obj1 = PyArray_EnsureArray(PyArray_Mean((PyAO *)new, axis, rtype, NULL));
+ obj1 = PyArray_EnsureAnyArray(PyArray_Mean((PyAO *)new, axis, rtype, NULL));
if (obj1 == NULL) {Py_DECREF(new); return NULL;}
n = PyArray_NDIM(new);
newshape = PyTuple_New(n);
@@ -865,7 +865,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
if (obj2 == NULL) {Py_DECREF(new); return NULL;}
/* Compute x = x - mx */
- obj1 = PyArray_EnsureArray(PyNumber_Subtract((PyObject *)new, obj2));
+ obj1 = PyArray_EnsureAnyArray(PyNumber_Subtract((PyObject *)new, obj2));
Py_DECREF(obj2);
if (obj1 == NULL) {Py_DECREF(new); return NULL;}
@@ -878,7 +878,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
Py_INCREF(obj1);
}
if (obj3 == NULL) {Py_DECREF(new); return NULL;}
- obj2 = PyArray_EnsureArray \
+ obj2 = PyArray_EnsureAnyArray \
(PyArray_GenericBinaryFunction((PyAO *)obj1, obj3, n_ops.multiply));
Py_DECREF(obj1);
Py_DECREF(obj3);
@@ -921,7 +921,7 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
Py_DECREF(obj2);
if (!variance) {
- obj1 = PyArray_EnsureArray(ret);
+ obj1 = PyArray_EnsureAnyArray(ret);
/* sqrt() */
ret = PyArray_GenericUnaryFunction((PyAO *)obj1, n_ops.sqrt);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ae6c43f10..166c3269b 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -854,6 +854,23 @@ class TestView(NumpyTestCase):
assert(isinstance(y,np.matrix))
assert_equal(y.dtype,np.int16)
+class TestStats(NumpyTestCase):
+ def test_subclass(self):
+ class TestArray(np.ndarray):
+ def __new__(cls, data, info):
+ result = np.array(data)
+ result = result.view(cls)
+ result.info = info
+ return result
+ def __array_finalize__(self, obj):
+ self.info = getattr(obj, "info", '')
+ dat = TestArray([[1,2,3,4],[5,6,7,8]], 'jubba')
+ res = dat.mean(1)
+ assert res.info == dat.info
+ res = dat.std(1)
+ assert res.info == dat.info
+ res = dat.var(1)
+ assert res.info == dat.info
# Import tests without matching module names
set_local_path()