diff options
author | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2022-03-02 19:07:14 +0100 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-02 00:29:32 +0100 |
commit | 402545b801aedd7034d4a861d3e4b3d1fc61e2ce (patch) | |
tree | 1d8ce906e61312e14cdc876eeb6f23b5ed84a56e /numpy/core/src | |
parent | 53ad6ec73884a86e53231164f8ade8ce0053c10f (diff) | |
download | numpy-402545b801aedd7034d4a861d3e4b3d1fc61e2ce.tar.gz |
MAINT: Explicitly raise when `a @= b` would otherwise broadcast the output
Add special casing for `1d @ 1d` and `2d @ 1d` ops.
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index a149ed7cd..0b5b46fc5 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -351,9 +351,30 @@ array_matrix_multiply(PyObject *m1, PyObject *m2) static PyObject * array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2) { - INPLACE_GIVE_UP_IF_NEEDED(m1, m2, + PyArrayObject *m2_array; + int m1_ndim; + int m2_ndim; + + /* Explicitly raise a ValueError when the output would + * otherwise be broadcasted to `m1`. Three conditions must be met: + * * `m1.ndim in [1, 2]` + * * `m2.ndim == 1` + * * `m1.shape[-1] == m2.shape[0]` + */ + m2_array = (PyArrayObject *)PyArray_FromAny(m2, NULL, 0, 0, 0, NULL); + m1_ndim = PyArray_NDIM(m1); + m2_ndim = PyArray_NDIM(m2_array); + if (((m1_ndim == 1) || (m1_ndim == 2)) && (m2_ndim == 1) + && (PyArray_DIMS(m1)[m1_ndim - 1] == PyArray_DIMS(m2_array)[0])) { + PyErr_Format(PyExc_ValueError, + "output parameter has the wrong number of dimensions: " + "Found %d but expected %d", m1_ndim - 1, m1_ndim); + return NULL; + } + + INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array, nb_inplace_matrix_multiply, array_inplace_matrix_multiply); - return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.matmul); + return PyArray_GenericInplaceBinaryFunction(m1, (PyObject *)m2_array, n_ops.matmul); } /* |