summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2022-06-08 09:52:14 -0600
committerGitHub <noreply@github.com>2022-06-08 09:52:14 -0600
commite05e65718b34aea0c3be393459e2d5933ab075e0 (patch)
tree47bcde37c1f781815f942bd8cddf1e902280a551 /numpy/core
parent1b7ad60e0f8d94fcb7294fb9bf8b12fa75cb43a8 (diff)
parent703a6fa6ac7f5049bb17726a1ba93f7835bf92ae (diff)
downloadnumpy-e05e65718b34aea0c3be393459e2d5933ab075e0.tar.gz
Merge pull request #21690 from seberg/fix-21673
BUG: Prevent attempted broadcasting of 0-D output operands in ufuncs
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/umath/ufunc_object.c4
-rw-r--r--numpy/core/tests/test_ufunc.py14
2 files changed, 16 insertions, 2 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 290ed24a6..fce7d61de 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1243,7 +1243,7 @@ try_trivial_single_output_loop(PyArrayMethod_Context *context,
int op_ndim = PyArray_NDIM(op[iop]);
/* Special case 0-D since we can handle broadcasting using a 0-stride */
- if (op_ndim == 0) {
+ if (op_ndim == 0 && iop < nin) {
fixed_strides[iop] = 0;
continue;
}
@@ -1254,7 +1254,7 @@ try_trivial_single_output_loop(PyArrayMethod_Context *context,
operation_shape = PyArray_SHAPE(op[iop]);
}
else if (op_ndim != operation_ndim) {
- return -2; /* dimension mismatch (except 0-d ops) */
+ return -2; /* dimension mismatch (except 0-d input ops) */
}
else if (!PyArray_CompareLists(
operation_shape, PyArray_DIMS(op[iop]), op_ndim)) {
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 43306d7cf..852044d32 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -799,6 +799,20 @@ class TestUfunc:
# the result would be just a scalar `5`, but is broadcast fully:
assert (out == 5).all()
+ @pytest.mark.parametrize(["arr", "out"], [
+ ([2], np.empty(())),
+ ([1, 2], np.empty(1)),
+ (np.ones((4, 3)), np.empty((4, 1)))],
+ ids=["(1,)->()", "(2,)->(1,)", "(4, 3)->(4, 1)"])
+ def test_out_broadcast_errors(self, arr, out):
+ # Output is (currently) allowed to broadcast inputs, but it cannot be
+ # smaller than the actual result.
+ with pytest.raises(ValueError, match="non-broadcastable"):
+ np.positive(arr, out=out)
+
+ with pytest.raises(ValueError, match="non-broadcastable"):
+ np.add(np.ones(()), arr, out=out)
+
def test_type_cast(self):
msg = "type cast"
a = np.arange(6, dtype='short').reshape((2, 3))