diff options
author | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2022-03-02 19:07:14 +0100 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-02 00:29:32 +0100 |
commit | 402545b801aedd7034d4a861d3e4b3d1fc61e2ce (patch) | |
tree | 1d8ce906e61312e14cdc876eeb6f23b5ed84a56e /numpy | |
parent | 53ad6ec73884a86e53231164f8ade8ce0053c10f (diff) | |
download | numpy-402545b801aedd7034d4a861d3e4b3d1fc61e2ce.tar.gz |
MAINT: Explicitly raise when `a @= b` would otherwise broadcast the output
Add special casing for `1d @ 1d` and `2d @ 1d` ops.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 25 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 3 |
2 files changed, 26 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index a149ed7cd..0b5b46fc5 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -351,9 +351,30 @@ array_matrix_multiply(PyObject *m1, PyObject *m2) static PyObject * array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2) { - INPLACE_GIVE_UP_IF_NEEDED(m1, m2, + PyArrayObject *m2_array; + int m1_ndim; + int m2_ndim; + + /* Explicitly raise a ValueError when the output would + * otherwise be broadcasted to `m1`. Three conditions must be met: + * * `m1.ndim in [1, 2]` + * * `m2.ndim == 1` + * * `m1.shape[-1] == m2.shape[0]` + */ + m2_array = (PyArrayObject *)PyArray_FromAny(m2, NULL, 0, 0, 0, NULL); + m1_ndim = PyArray_NDIM(m1); + m2_ndim = PyArray_NDIM(m2_array); + if (((m1_ndim == 1) || (m1_ndim == 2)) && (m2_ndim == 1) + && (PyArray_DIMS(m1)[m1_ndim - 1] == PyArray_DIMS(m2_array)[0])) { + PyErr_Format(PyExc_ValueError, + "output parameter has the wrong number of dimensions: " + "Found %d but expected %d", m1_ndim - 1, m1_ndim); + return NULL; + } + + INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array, nb_inplace_matrix_multiply, array_inplace_matrix_multiply); - return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.matmul); + return PyArray_GenericInplaceBinaryFunction(m1, (PyObject *)m2_array, n_ops.matmul); } /* diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index a2979096b..baf1bf40f 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -7185,6 +7185,9 @@ class TestMatmulInplace: SHAPES = { "2d_large": ((10**5, 10), (10, 10)), "3d_large": ((10**4, 10, 10), (1, 10, 10)), + "1d": ((3,), (3,)), + "2d_1d": ((3, 3), (3,)), + "1d_2d": ((3,), (3, 3)), "2d_broadcast": ((3, 3), (3, 1)), "2d_broadcast_reverse": ((1, 3), (3, 3)), "3d_broadcast1": ((3, 3, 3), (1, 3, 1)), |