diff options
author | Matti Picus <matti.picus@gmail.com> | 2023-03-26 14:32:25 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-26 14:32:25 +0300 |
commit | a37978a106073eaec5cb9e0cb54785fafb639650 (patch) | |
tree | ac2ef9d4bbcf603f500051237f9c9ed318df9ff7 | |
parent | 1e292ff08d3b797fa69b889c4a3fed99970308c8 (diff) | |
parent | 0bbe7dbf267cf835efcd514283815292bd94403f (diff) | |
download | numpy-a37978a106073eaec5cb9e0cb54785fafb639650.tar.gz |
Merge pull request #21120 from BvB93/matmul
ENH: Add support for inplace matrix multiplication
-rw-r--r-- | doc/release/upcoming_changes/21120.new_feature.rst | 21 | ||||
-rw-r--r-- | numpy/__init__.pyi | 14 | ||||
-rw-r--r-- | numpy/array_api/_array_object.py | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.c | 72 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 77 |
5 files changed, 167 insertions, 31 deletions
diff --git a/doc/release/upcoming_changes/21120.new_feature.rst b/doc/release/upcoming_changes/21120.new_feature.rst new file mode 100644 index 000000000..7d4dbf743 --- /dev/null +++ b/doc/release/upcoming_changes/21120.new_feature.rst @@ -0,0 +1,21 @@ +Add support for inplace matrix multiplication +---------------------------------------------- +It is now possible to perform inplace matrix multiplication +via the ``@=`` operator. + +.. code-block:: python + + >>> import numpy as np + + >>> a = np.arange(6).reshape(3, 2) + >>> print(a) + [[0 1] + [2 3] + [4 5]] + + >>> b = np.ones((2, 2), dtype=int) + >>> a @= b + >>> print(a) + [[1 1] + [5 5] + [9 9]] diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index ee5fbb601..8627f6c60 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -1921,7 +1921,6 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): def __neg__(self: NDArray[object_]) -> Any: ... # Binary ops - # NOTE: `ndarray` does not implement `__imatmul__` @overload def __matmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc] @overload @@ -2508,6 +2507,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): @overload def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ... + @overload + def __imatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... + @overload + def __imatmul__(self: NDArray[unsignedinteger[_NBit1]], other: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[_NBit1]]: ... + @overload + def __imatmul__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ... + @overload + def __imatmul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co) -> NDArray[floating[_NBit1]]: ... + @overload + def __imatmul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co) -> NDArray[complexfloating[_NBit1, _NBit1]]: ... + @overload + def __imatmul__(self: NDArray[object_], other: Any) -> NDArray[object_]: ... + def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ... def __dlpack_device__(self) -> tuple[int, L[0]]: ... diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index eee117be6..a949b5977 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -850,23 +850,13 @@ class Array: """ Performs the operation __imatmul__. """ - # Note: NumPy does not implement __imatmul__. - # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - - # __imatmul__ can only be allowed when it would not change the shape - # of self. - other_shape = other.shape - if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") - if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: - raise ValueError("@= cannot change the shape of the input array") - self._array[:] = self._array.__matmul__(other._array) - return self + res = self._array.__imatmul__(other._array) + return self.__class__._new(res) def __rmatmul__(self: Array, other: Array, /) -> Array: """ diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 2e25152d5..c208fb203 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -53,6 +53,8 @@ static PyObject * array_inplace_remainder(PyArrayObject *m1, PyObject *m2); static PyObject * array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo)); +static PyObject * +array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2); /* * Dictionary can contain any of the numeric operations, by name. @@ -339,7 +341,6 @@ array_divmod(PyObject *m1, PyObject *m2) return PyArray_GenericBinaryFunction(m1, m2, n_ops.divmod); } -/* Need this to be version dependent on account of the slot check */ static PyObject * array_matrix_multiply(PyObject *m1, PyObject *m2) { @@ -348,13 +349,70 @@ array_matrix_multiply(PyObject *m1, PyObject *m2) } static PyObject * -array_inplace_matrix_multiply( - PyArrayObject *NPY_UNUSED(m1), PyObject *NPY_UNUSED(m2)) +array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other) { - PyErr_SetString(PyExc_TypeError, - "In-place matrix multiplication is not (yet) supported. " - "Use 'a = a @ b' instead of 'a @= b'."); - return NULL; + static PyObject *AxisError_cls = NULL; + npy_cache_import("numpy.exceptions", "AxisError", &AxisError_cls); + if (AxisError_cls == NULL) { + return NULL; + } + + INPLACE_GIVE_UP_IF_NEEDED(self, other, + nb_inplace_matrix_multiply, array_inplace_matrix_multiply); + + /* + * Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast + * if the result without `out` would have less dimensions than `a`. + * Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the + * case exactly when the second operand has both core dimensions. + * + * The error here will be confusing, but for now, we enforce this by + * passing the correct `axes=`. + */ + static PyObject *axes_1d_obj_kwargs = NULL; + static PyObject *axes_2d_obj_kwargs = NULL; + if (NPY_UNLIKELY(axes_1d_obj_kwargs == NULL)) { + axes_1d_obj_kwargs = Py_BuildValue( + "{s, [(i), (i, i), (i)]}", "axes", -1, -2, -1, -1); + if (axes_1d_obj_kwargs == NULL) { + return NULL; + } + } + if (NPY_UNLIKELY(axes_2d_obj_kwargs == NULL)) { + axes_2d_obj_kwargs = Py_BuildValue( + "{s, [(i, i), (i, i), (i, i)]}", "axes", -2, -1, -2, -1, -2, -1); + if (axes_2d_obj_kwargs == NULL) { + return NULL; + } + } + + PyObject *args = PyTuple_Pack(3, self, other, self); + if (args == NULL) { + return NULL; + } + PyObject *kwargs; + if (PyArray_NDIM(self) == 1) { + kwargs = axes_1d_obj_kwargs; + } + else { + kwargs = axes_2d_obj_kwargs; + } + PyObject *res = PyObject_Call(n_ops.matmul, args, kwargs); + Py_DECREF(args); + + if (res == NULL) { + /* + * AxisError should indicate that the axes argument didn't work out + * which should mean the second operand not being 2 dimensional. + */ + if (PyErr_ExceptionMatches(AxisError_cls)) { + PyErr_SetString(PyExc_ValueError, + "inplace matrix multiplication requires the first operand to " + "have at least one and the second at least two dimensions."); + } + } + + return res; } /* diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 94f6ef7ad..77f651659 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections.abc import tempfile import sys @@ -3729,7 +3731,7 @@ class TestBinop: 'and': (np.bitwise_and, True, int), 'xor': (np.bitwise_xor, True, int), 'or': (np.bitwise_or, True, int), - 'matmul': (np.matmul, False, float), + 'matmul': (np.matmul, True, float), # 'ge': (np.less_equal, False), # 'gt': (np.less, False), # 'le': (np.greater_equal, False), @@ -7180,16 +7182,69 @@ class TestMatmulOperator(MatmulCommon): assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc')) assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc')) -def test_matmul_inplace(): - # It would be nice to support in-place matmul eventually, but for now - # we don't have a working implementation, so better just to error out - # and nudge people to writing "a = a @ b". - a = np.eye(3) - b = np.eye(3) - assert_raises(TypeError, a.__imatmul__, b) - import operator - assert_raises(TypeError, operator.imatmul, a, b) - assert_raises(TypeError, exec, "a @= b", globals(), locals()) + +class TestMatmulInplace: + DTYPES = {} + for i in MatmulCommon.types: + for j in MatmulCommon.types: + if np.can_cast(j, i): + DTYPES[f"{i}-{j}"] = (np.dtype(i), np.dtype(j)) + + @pytest.mark.parametrize("dtype1,dtype2", DTYPES.values(), ids=DTYPES) + def test_basic(self, dtype1: np.dtype, dtype2: np.dtype) -> None: + a = np.arange(10).reshape(5, 2).astype(dtype1) + a_id = id(a) + b = np.ones((2, 2), dtype=dtype2) + + ref = a @ b + a @= b + + assert id(a) == a_id + assert a.dtype == dtype1 + assert a.shape == (5, 2) + if dtype1.kind in "fc": + np.testing.assert_allclose(a, ref) + else: + np.testing.assert_array_equal(a, ref) + + SHAPES = { + "2d_large": ((10**5, 10), (10, 10)), + "3d_large": ((10**4, 10, 10), (1, 10, 10)), + "1d": ((3,), (3,)), + "2d_1d": ((3, 3), (3,)), + "1d_2d": ((3,), (3, 3)), + "2d_broadcast": ((3, 3), (3, 1)), + "2d_broadcast_reverse": ((1, 3), (3, 3)), + "3d_broadcast1": ((3, 3, 3), (1, 3, 1)), + "3d_broadcast2": ((3, 3, 3), (1, 3, 3)), + "3d_broadcast3": ((3, 3, 3), (3, 3, 1)), + "3d_broadcast_reverse1": ((1, 3, 3), (3, 3, 3)), + "3d_broadcast_reverse2": ((3, 1, 3), (3, 3, 3)), + "3d_broadcast_reverse3": ((1, 1, 3), (3, 3, 3)), + } + + @pytest.mark.parametrize("a_shape,b_shape", SHAPES.values(), ids=SHAPES) + def test_shapes(self, a_shape: tuple[int, ...], b_shape: tuple[int, ...]): + a_size = np.product(a_shape) + a = np.arange(a_size).reshape(a_shape).astype(np.float64) + a_id = id(a) + + b_size = np.product(b_shape) + b = np.arange(b_size).reshape(b_shape) + + ref = a @ b + if ref.shape != a_shape: + with pytest.raises(ValueError): + a @= b + return + else: + a @= b + + assert id(a) == a_id + assert a.dtype.type == np.float64 + assert a.shape == a_shape + np.testing.assert_allclose(a, ref) + def test_matmul_axes(): a = np.arange(3*4*5).reshape(3, 4, 5) |