diff options
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 4d3996d86..87a947313 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections.abc import tempfile import sys @@ -3693,7 +3695,7 @@ class TestBinop: 'and': (np.bitwise_and, True, int), 'xor': (np.bitwise_xor, True, int), 'or': (np.bitwise_or, True, int), - 'matmul': (np.matmul, False, float), + 'matmul': (np.matmul, True, float), # 'ge': (np.less_equal, False), # 'gt': (np.less, False), # 'le': (np.greater_equal, False), @@ -7155,16 +7157,31 @@ class TestMatmulOperator(MatmulCommon): assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc')) assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc')) -def test_matmul_inplace(): - # It would be nice to support in-place matmul eventually, but for now - # we don't have a working implementation, so better just to error out - # and nudge people to writing "a = a @ b". - a = np.eye(3) - b = np.eye(3) - assert_raises(TypeError, a.__imatmul__, b) - import operator - assert_raises(TypeError, operator.imatmul, a, b) - assert_raises(TypeError, exec, "a @= b", globals(), locals()) + +class TestMatmulInplace: + DTYPES = {} + for i in MatmulCommon.types: + for j in MatmulCommon.types: + if np.can_cast(j, i): + DTYPES[f"{i}-{j}"] = (np.dtype(i), np.dtype(j)) + + @pytest.mark.parametrize("dtype1,dtype2", DTYPES.values(), ids=DTYPES) + def test_matmul_inplace(self, dtype1: np.dtype, dtype2: np.dtype) -> None: + a = np.arange(10).reshape(5, 2).astype(dtype1) + a_id = id(a) + b = np.ones((2, 2), dtype=dtype2) + + ref = a @ b + a @= b + + assert id(a) == a_id + assert a.dtype == dtype1 + assert a.shape == (5, 2) + if dtype1.kind in "fc": + np.testing.assert_allclose(a, ref) + else: + np.testing.assert_array_equal(a, ref) + def test_matmul_axes(): a = np.arange(3*4*5).reshape(3, 4, 5) |