summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2022-03-02 19:07:14 +0100
committerSebastian Berg <sebastianb@nvidia.com>2022-12-02 00:29:32 +0100
commit402545b801aedd7034d4a861d3e4b3d1fc61e2ce (patch)
tree1d8ce906e61312e14cdc876eeb6f23b5ed84a56e /numpy
parent53ad6ec73884a86e53231164f8ade8ce0053c10f (diff)
downloadnumpy-402545b801aedd7034d4a861d3e4b3d1fc61e2ce.tar.gz
MAINT: Explicitly raise when `a @= b` would otherwise broadcast the output
Add special casing for `1d @ 1d` and `2d @ 1d` ops.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/number.c25
-rw-r--r--numpy/core/tests/test_multiarray.py3
2 files changed, 26 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index a149ed7cd..0b5b46fc5 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -351,9 +351,30 @@ array_matrix_multiply(PyObject *m1, PyObject *m2)
static PyObject *
array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
{
- INPLACE_GIVE_UP_IF_NEEDED(m1, m2,
+ PyArrayObject *m2_array;
+ int m1_ndim;
+ int m2_ndim;
+
+ /* Explicitly raise a ValueError when the output would
+ * otherwise be broadcasted to `m1`. Three conditions must be met:
+ * * `m1.ndim in [1, 2]`
+ * * `m2.ndim == 1`
+ * * `m1.shape[-1] == m2.shape[0]`
+ */
+ m2_array = (PyArrayObject *)PyArray_FromAny(m2, NULL, 0, 0, 0, NULL);
+ m1_ndim = PyArray_NDIM(m1);
+ m2_ndim = PyArray_NDIM(m2_array);
+ if (((m1_ndim == 1) || (m1_ndim == 2)) && (m2_ndim == 1)
+ && (PyArray_DIMS(m1)[m1_ndim - 1] == PyArray_DIMS(m2_array)[0])) {
+ PyErr_Format(PyExc_ValueError,
+ "output parameter has the wrong number of dimensions: "
+ "Found %d but expected %d", m1_ndim - 1, m1_ndim);
+ return NULL;
+ }
+
+ INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array,
nb_inplace_matrix_multiply, array_inplace_matrix_multiply);
- return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.matmul);
+ return PyArray_GenericInplaceBinaryFunction(m1, (PyObject *)m2_array, n_ops.matmul);
}
/*
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a2979096b..baf1bf40f 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -7185,6 +7185,9 @@ class TestMatmulInplace:
SHAPES = {
"2d_large": ((10**5, 10), (10, 10)),
"3d_large": ((10**4, 10, 10), (1, 10, 10)),
+ "1d": ((3,), (3,)),
+ "2d_1d": ((3, 3), (3,)),
+ "1d_2d": ((3,), (3, 3)),
"2d_broadcast": ((3, 3), (3, 1)),
"2d_broadcast_reverse": ((1, 3), (3, 3)),
"3d_broadcast1": ((3, 3, 3), (1, 3, 1)),