diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2022-06-08 09:52:14 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-08 09:52:14 -0600 |
commit | e05e65718b34aea0c3be393459e2d5933ab075e0 (patch) | |
tree | 47bcde37c1f781815f942bd8cddf1e902280a551 /numpy/core | |
parent | 1b7ad60e0f8d94fcb7294fb9bf8b12fa75cb43a8 (diff) | |
parent | 703a6fa6ac7f5049bb17726a1ba93f7835bf92ae (diff) | |
download | numpy-e05e65718b34aea0c3be393459e2d5933ab075e0.tar.gz |
Merge pull request #21690 from seberg/fix-21673
BUG: Prevent attempted broadcasting of 0-D output operands in ufuncs
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 14 |
2 files changed, 16 insertions, 2 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 290ed24a6..fce7d61de 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -1243,7 +1243,7 @@ try_trivial_single_output_loop(PyArrayMethod_Context *context, int op_ndim = PyArray_NDIM(op[iop]); /* Special case 0-D since we can handle broadcasting using a 0-stride */ - if (op_ndim == 0) { + if (op_ndim == 0 && iop < nin) { fixed_strides[iop] = 0; continue; } @@ -1254,7 +1254,7 @@ try_trivial_single_output_loop(PyArrayMethod_Context *context, operation_shape = PyArray_SHAPE(op[iop]); } else if (op_ndim != operation_ndim) { - return -2; /* dimension mismatch (except 0-d ops) */ + return -2; /* dimension mismatch (except 0-d input ops) */ } else if (!PyArray_CompareLists( operation_shape, PyArray_DIMS(op[iop]), op_ndim)) { diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 43306d7cf..852044d32 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -799,6 +799,20 @@ class TestUfunc: # the result would be just a scalar `5`, but is broadcast fully: assert (out == 5).all() + @pytest.mark.parametrize(["arr", "out"], [ + ([2], np.empty(())), + ([1, 2], np.empty(1)), + (np.ones((4, 3)), np.empty((4, 1)))], + ids=["(1,)->()", "(2,)->(1,)", "(4, 3)->(4, 1)"]) + def test_out_broadcast_errors(self, arr, out): + # Output is (currently) allowed to broadcast inputs, but it cannot be + # smaller than the actual result. + with pytest.raises(ValueError, match="non-broadcastable"): + np.positive(arr, out=out) + + with pytest.raises(ValueError, match="non-broadcastable"): + np.add(np.ones(()), arr, out=out) + def test_type_cast(self): msg = "type cast" a = np.arange(6, dtype='short').reshape((2, 3)) |