summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2023-03-26 14:32:25 +0300
committerGitHub <noreply@github.com>2023-03-26 14:32:25 +0300
commita37978a106073eaec5cb9e0cb54785fafb639650 (patch)
treeac2ef9d4bbcf603f500051237f9c9ed318df9ff7
parent1e292ff08d3b797fa69b889c4a3fed99970308c8 (diff)
parent0bbe7dbf267cf835efcd514283815292bd94403f (diff)
downloadnumpy-a37978a106073eaec5cb9e0cb54785fafb639650.tar.gz
Merge pull request #21120 from BvB93/matmul
ENH: Add support for inplace matrix multiplication
-rw-r--r--doc/release/upcoming_changes/21120.new_feature.rst21
-rw-r--r--numpy/__init__.pyi14
-rw-r--r--numpy/array_api/_array_object.py14
-rw-r--r--numpy/core/src/multiarray/number.c72
-rw-r--r--numpy/core/tests/test_multiarray.py77
5 files changed, 167 insertions, 31 deletions
diff --git a/doc/release/upcoming_changes/21120.new_feature.rst b/doc/release/upcoming_changes/21120.new_feature.rst
new file mode 100644
index 000000000..7d4dbf743
--- /dev/null
+++ b/doc/release/upcoming_changes/21120.new_feature.rst
@@ -0,0 +1,21 @@
+Add support for inplace matrix multiplication
+----------------------------------------------
+It is now possible to perform inplace matrix multiplication
+via the ``@=`` operator.
+
+.. code-block:: python
+
+    >>> import numpy as np
+
+    >>> a = np.arange(6).reshape(3, 2)
+    >>> print(a)
+    [[0 1]
+     [2 3]
+     [4 5]]
+
+    >>> b = np.ones((2, 2), dtype=int)
+    >>> a @= b
+    >>> print(a)
+    [[1 1]
+     [5 5]
+     [9 9]]
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index ee5fbb601..8627f6c60 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1921,7 +1921,6 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __neg__(self: NDArray[object_]) -> Any: ...
# Binary ops
- # NOTE: `ndarray` does not implement `__imatmul__`
@overload
def __matmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
@overload
@@ -2508,6 +2507,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
+ # In-place matrix multiplication (`@=`): one overload per dtype family of
+ # `self`; each overload returns the same array type that `self` has.
+ @overload
+ def __imatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...
+ @overload
+ def __imatmul__(self: NDArray[unsignedinteger[_NBit1]], other: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[_NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co) -> NDArray[floating[_NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
+
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
def __dlpack_device__(self) -> tuple[int, L[0]]: ...
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index eee117be6..a949b5977 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -850,23 +850,13 @@ class Array:
"""
Performs the operation __imatmul__.
"""
- # Note: NumPy does not implement __imatmul__.
-
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
if other is NotImplemented:
return other
-
- # __imatmul__ can only be allowed when it would not change the shape
- # of self.
- other_shape = other.shape
- if self.shape == () or other_shape == ():
- raise ValueError("@= requires at least one dimension")
- if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
- raise ValueError("@= cannot change the shape of the input array")
- self._array[:] = self._array.__matmul__(other._array)
- return self
+ res = self._array.__imatmul__(other._array)
+ return self.__class__._new(res)
def __rmatmul__(self: Array, other: Array, /) -> Array:
"""
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 2e25152d5..c208fb203 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -53,6 +53,8 @@ static PyObject *
array_inplace_remainder(PyArrayObject *m1, PyObject *m2);
static PyObject *
array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo));
+static PyObject *
+array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2);
/*
* Dictionary can contain any of the numeric operations, by name.
@@ -339,7 +341,6 @@ array_divmod(PyObject *m1, PyObject *m2)
return PyArray_GenericBinaryFunction(m1, m2, n_ops.divmod);
}
-/* Need this to be version dependent on account of the slot check */
static PyObject *
array_matrix_multiply(PyObject *m1, PyObject *m2)
{
@@ -348,13 +349,70 @@ array_matrix_multiply(PyObject *m1, PyObject *m2)
}
+/*
+ * In-place matrix multiplication, `self @= other`; passed below as the
+ * `nb_inplace_matrix_multiply` slot implementation.
+ *
+ * Calls `n_ops.matmul(self, other, self, axes=...)`, i.e. `self` is given a
+ * second time as the third positional argument so that the product is
+ * written back into it. Returns whatever that call returns, or NULL with
+ * an exception set. An AxisError raised by the call is replaced by a
+ * ValueError describing the dimensionality requirement.
+ */
static PyObject *
-array_inplace_matrix_multiply(
- PyArrayObject *NPY_UNUSED(m1), PyObject *NPY_UNUSED(m2))
+array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other)
{
- PyErr_SetString(PyExc_TypeError,
- "In-place matrix multiplication is not (yet) supported. "
- "Use 'a = a @ b' instead of 'a @= b'.");
- return NULL;
+ /*
+ * AxisError is looked up through a static pointer so that it is available
+ * for the exception check after the matmul call below.
+ */
+ static PyObject *AxisError_cls = NULL;
+ npy_cache_import("numpy.exceptions", "AxisError", &AxisError_cls);
+ if (AxisError_cls == NULL) {
+ return NULL;
+ }
+
+ /*
+ * NOTE(review): macro defined elsewhere; presumably it returns early so
+ * that `other` can handle the operation itself -- confirm against its
+ * definition.
+ */
+ INPLACE_GIVE_UP_IF_NEEDED(self, other,
+ nb_inplace_matrix_multiply, array_inplace_matrix_multiply);
+
+ /*
+ * Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
+ * if the result without `out` would have fewer dimensions than `a`.
+ * Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
+ * case exactly when the second operand has both core dimensions.
+ *
+ * The error here will be confusing, but for now, we enforce this by
+ * passing the correct `axes=`.
+ */
+ /* Both keyword dicts are built once, on first use, and then reused. */
+ static PyObject *axes_1d_obj_kwargs = NULL;
+ static PyObject *axes_2d_obj_kwargs = NULL;
+ if (NPY_UNLIKELY(axes_1d_obj_kwargs == NULL)) {
+ axes_1d_obj_kwargs = Py_BuildValue(
+ "{s, [(i), (i, i), (i)]}", "axes", -1, -2, -1, -1);
+ if (axes_1d_obj_kwargs == NULL) {
+ return NULL;
+ }
+ }
+ if (NPY_UNLIKELY(axes_2d_obj_kwargs == NULL)) {
+ axes_2d_obj_kwargs = Py_BuildValue(
+ "{s, [(i, i), (i, i), (i, i)]}", "axes", -2, -1, -2, -1, -2, -1);
+ if (axes_2d_obj_kwargs == NULL) {
+ return NULL;
+ }
+ }
+
+ /* Positional arguments: (self, other, self). */
+ PyObject *args = PyTuple_Pack(3, self, other, self);
+ if (args == NULL) {
+ return NULL;
+ }
+ /*
+ * `kwargs` only aliases one of the static dicts above, which is why it is
+ * not DECREF'd after the call.
+ */
+ PyObject *kwargs;
+ if (PyArray_NDIM(self) == 1) {
+ kwargs = axes_1d_obj_kwargs;
+ }
+ else {
+ kwargs = axes_2d_obj_kwargs;
+ }
+ PyObject *res = PyObject_Call(n_ops.matmul, args, kwargs);
+ Py_DECREF(args);
+
+ if (res == NULL) {
+ /*
+ * AxisError should indicate that the axes argument didn't work out
+ * which should mean the second operand not being 2 dimensional.
+ */
+ if (PyErr_ExceptionMatches(AxisError_cls)) {
+ PyErr_SetString(PyExc_ValueError,
+ "inplace matrix multiplication requires the first operand to "
+ "have at least one and the second at least two dimensions.");
+ }
+ }
+
+ /* NULL here means an exception is set (possibly the ValueError above). */
+ return res;
}
/*
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 94f6ef7ad..77f651659 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import collections.abc
import tempfile
import sys
@@ -3729,7 +3731,7 @@ class TestBinop:
'and': (np.bitwise_and, True, int),
'xor': (np.bitwise_xor, True, int),
'or': (np.bitwise_or, True, int),
- 'matmul': (np.matmul, False, float),
+ 'matmul': (np.matmul, True, float),
# 'ge': (np.less_equal, False),
# 'gt': (np.less, False),
# 'le': (np.greater_equal, False),
@@ -7180,16 +7182,69 @@ class TestMatmulOperator(MatmulCommon):
assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc'))
assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc'))
-def test_matmul_inplace():
- # It would be nice to support in-place matmul eventually, but for now
- # we don't have a working implementation, so better just to error out
- # and nudge people to writing "a = a @ b".
- a = np.eye(3)
- b = np.eye(3)
- assert_raises(TypeError, a.__imatmul__, b)
- import operator
- assert_raises(TypeError, operator.imatmul, a, b)
- assert_raises(TypeError, exec, "a @= b", globals(), locals())
+
+class TestMatmulInplace:
+    """Tests for in-place matrix multiplication (``a @= b``) on ndarrays."""
+
+    # Maps a readable pytest id "i-j" to the dtype pair (i, j), for every
+    # pair of entries in ``MatmulCommon.types`` where ``np.can_cast(j, i)``
+    # holds.
+    DTYPES = {}
+    for i in MatmulCommon.types:
+        for j in MatmulCommon.types:
+            if np.can_cast(j, i):
+                DTYPES[f"{i}-{j}"] = (np.dtype(i), np.dtype(j))
+
+    @pytest.mark.parametrize("dtype1,dtype2", DTYPES.values(), ids=DTYPES)
+    def test_basic(self, dtype1: np.dtype, dtype2: np.dtype) -> None:
+        """Compare ``a @= b`` with ``a @ b`` for a (5, 2) by (2, 2) product.
+
+        ``a`` must remain the same object and keep its dtype and shape.
+        """
+        a = np.arange(10).reshape(5, 2).astype(dtype1)
+        a_id = id(a)
+        b = np.ones((2, 2), dtype=dtype2)
+
+        ref = a @ b
+        a @= b
+
+        assert id(a) == a_id
+        assert a.dtype == dtype1
+        assert a.shape == (5, 2)
+        # Float and complex kinds are compared with a tolerance, everything
+        # else exactly.
+        if dtype1.kind in "fc":
+            np.testing.assert_allclose(a, ref)
+        else:
+            np.testing.assert_array_equal(a, ref)
+
+    # Maps a pytest id to ``(a_shape, b_shape)``. Pairs for which ``a @ b``
+    # does not have the shape of ``a`` are expected to raise in `test_shapes`.
+    SHAPES = {
+        "2d_large": ((10**5, 10), (10, 10)),
+        "3d_large": ((10**4, 10, 10), (1, 10, 10)),
+        "1d": ((3,), (3,)),
+        "2d_1d": ((3, 3), (3,)),
+        "1d_2d": ((3,), (3, 3)),
+        "2d_broadcast": ((3, 3), (3, 1)),
+        "2d_broadcast_reverse": ((1, 3), (3, 3)),
+        "3d_broadcast1": ((3, 3, 3), (1, 3, 1)),
+        "3d_broadcast2": ((3, 3, 3), (1, 3, 3)),
+        "3d_broadcast3": ((3, 3, 3), (3, 3, 1)),
+        "3d_broadcast_reverse1": ((1, 3, 3), (3, 3, 3)),
+        "3d_broadcast_reverse2": ((3, 1, 3), (3, 3, 3)),
+        "3d_broadcast_reverse3": ((1, 1, 3), (3, 3, 3)),
+    }
+
+    @pytest.mark.parametrize("a_shape,b_shape", SHAPES.values(), ids=SHAPES)
+    def test_shapes(
+        self, a_shape: tuple[int, ...], b_shape: tuple[int, ...]
+    ) -> None:
+        """Check ``a @= b`` for many shape combinations.
+
+        When ``a @ b`` has a shape other than ``a_shape`` the in-place form
+        must raise ``ValueError``; otherwise it must match ``a @ b`` while
+        keeping the identity, dtype and shape of ``a``.
+        """
+        # `np.prod` rather than the alias `np.product`, which is deprecated
+        # since NumPy 1.25 and no longer exists in NumPy 2.0.
+        a_size = np.prod(a_shape)
+        a = np.arange(a_size).reshape(a_shape).astype(np.float64)
+        a_id = id(a)
+
+        b_size = np.prod(b_shape)
+        b = np.arange(b_size).reshape(b_shape)
+
+        ref = a @ b
+        if ref.shape != a_shape:
+            with pytest.raises(ValueError):
+                a @= b
+            return
+
+        a @= b
+
+        assert id(a) == a_id
+        assert a.dtype.type == np.float64
+        assert a.shape == a_shape
+        np.testing.assert_allclose(a, ref)
+
def test_matmul_axes():
a = np.arange(3*4*5).reshape(3, 4, 5)