diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/reduction.c | 3 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 13 |
3 files changed, 23 insertions, 24 deletions
diff --git a/numpy/core/src/umath/reduction.c b/numpy/core/src/umath/reduction.c index f69aea2d0..3f2b94a4a 100644 --- a/numpy/core/src/umath/reduction.c +++ b/numpy/core/src/umath/reduction.c @@ -483,7 +483,8 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out, if (op_view == NULL) { goto fail; } - if (PyArray_SIZE(op_view) == 0) { + /* empty op_view signals no reduction; but 0-d arrays cannot be empty */ + if ((PyArray_SIZE(op_view) == 0) || (PyArray_NDIM(operand) == 0)) { Py_DECREF(op_view); op_view = NULL; goto finish; diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 4fac30b5a..a3a164731 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -3719,31 +3719,16 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, * 'prod', et al, also allow a reduction where axis=0, even * though this is technically incorrect. */ - if (operation == UFUNC_REDUCE && - (naxes == 0 || (naxes == 1 && axes[0] == 0))) { + naxes = 0; + + if (!(operation == UFUNC_REDUCE && + (naxes == 0 || (naxes == 1 && axes[0] == 0)))) { + PyErr_Format(PyExc_TypeError, "cannot %s on a scalar", + _reduce_type[operation]); Py_XDECREF(otype); - /* If there's an output parameter, copy the value */ - if (out != NULL) { - if (PyArray_CopyInto(out, mp) < 0) { - Py_DECREF(mp); - return NULL; - } - else { - Py_DECREF(mp); - Py_INCREF(out); - return (PyObject *)out; - } - } - /* Otherwise return the array unscathed */ - else { - return PyArray_Return(mp); - } + Py_DECREF(mp); + return NULL; } - PyErr_Format(PyExc_TypeError, "cannot %s on a scalar", - _reduce_type[operation]); - Py_XDECREF(otype); - Py_DECREF(mp); - return NULL; } /* diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index a59db5562..3005da8da 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -565,12 +565,25 @@ class TestUfunc(TestCase): assert_equal(np.max(3, axis=0), 3) assert_equal(np.min(2.5, axis=0), 2.5) + # Check scalar behaviour for ufuncs without an identity + assert_equal(np.power.reduce(3), 3) + # Make sure that scalars are coming out from this operation assert_(type(np.prod(np.float32(2.5), axis=0)) is np.float32) assert_(type(np.sum(np.float32(2.5), axis=0)) is np.float32) assert_(type(np.max(np.float32(2.5), axis=0)) is np.float32) assert_(type(np.min(np.float32(2.5), axis=0)) is np.float32) + # check if scalars/0-d arrays get cast + assert_(type(np.any(0, axis=0)) is np.bool_) + + # assert that 0-d arrays get wrapped + class MyArray(np.ndarray): + pass + a = np.array(1).view(MyArray) + assert_(type(np.any(a)) is MyArray) + + def test_casting_out_param(self): # Test that it's possible to do casts on output a = np.ones((200,100), np.int64) |