summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2010-07-18 16:52:48 +0000
committerPauli Virtanen <pav@iki.fi>2010-07-18 16:52:48 +0000
commit84c8c655652ca8cca087aca0512521800e29f314 (patch)
tree226606cd5bd81eea96e3802c69e4190bb10ff065 /numpy/core
parent73b8b54cdf0057850cefc3b2cb0be4fad8e70fb8 (diff)
downloadnumpy-84c8c655652ca8cca087aca0512521800e29f314.tar.gz
BUG: core: make .std() and .var() respect the out= keyword (#1434)
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/calculation.c11
-rw-r--r--numpy/core/tests/test_regression.py13
2 files changed, 21 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c
index 22225347d..bc078f097 100644
--- a/numpy/core/src/multiarray/calculation.c
+++ b/numpy/core/src/multiarray/calculation.c
@@ -393,11 +393,14 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
ret = PyArray_GenericUnaryFunction((PyAO *)obj1, n_ops.sqrt);
Py_DECREF(obj1);
}
- if (ret == NULL || PyArray_CheckExact(self)) {
- return ret;
+ if (ret == NULL) {
+ return NULL;
+ }
+ if (PyArray_CheckExact(self)) {
+ goto finish;
}
if (PyArray_Check(self) && Py_TYPE(self) == Py_TYPE(ret)) {
- return ret;
+ goto finish;
}
obj1 = PyArray_EnsureArray(ret);
if (obj1 == NULL) {
@@ -405,6 +408,8 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
}
ret = PyArray_View((PyAO *)obj1, NULL, Py_TYPE(self));
Py_DECREF(obj1);
+
+finish:
if (out) {
if (PyArray_CopyAnyInto(out, (PyArrayObject *)ret) < 0) {
Py_DECREF(ret);
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 2aac2fb1c..01cf5dad4 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1325,5 +1325,18 @@ class TestRegression(TestCase):
assert_equal(type(getattr(x, name)), np.float32,
err_msg=name)
+ def test_ticket_1434(self):
+ # Check that the out= argument in var and std has an effect
+ data = np.array(((1,2,3),(4,5,6),(7,8,9)))
+ out = np.zeros((3,))
+
+ ret = data.var(axis=1, out=out)
+ assert_(ret is out)
+ assert_array_equal(ret, data.var(axis=1))
+
+ ret = data.std(axis=1, out=out)
+ assert_(ret is out)
+ assert_array_equal(ret, data.std(axis=1))
+
if __name__ == "__main__":
run_module_suite()