diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2022-02-25 17:22:17 +0100 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-02 00:29:32 +0100 |
commit | e0ed8ceae87bcb6ac3ad9190ec409b3ed631429a (patch) | |
tree | 1d6425cabee31fdfbe1cb51afa82c8ed2b119e7f /numpy | |
parent | 9b27375050ccf0fc6f610dd8050d1867b15de187 (diff) | |
download | numpy-e0ed8ceae87bcb6ac3ad9190ec409b3ed631429a.tar.gz |
TST: Add tests for inplace matrix multiplication
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 4d3996d86..87a947313 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections.abc import tempfile import sys @@ -3693,7 +3695,7 @@ class TestBinop: 'and': (np.bitwise_and, True, int), 'xor': (np.bitwise_xor, True, int), 'or': (np.bitwise_or, True, int), - 'matmul': (np.matmul, False, float), + 'matmul': (np.matmul, True, float), # 'ge': (np.less_equal, False), # 'gt': (np.less, False), # 'le': (np.greater_equal, False), @@ -7155,16 +7157,31 @@ class TestMatmulOperator(MatmulCommon): assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc')) assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc')) -def test_matmul_inplace(): - # It would be nice to support in-place matmul eventually, but for now - # we don't have a working implementation, so better just to error out - # and nudge people to writing "a = a @ b". - a = np.eye(3) - b = np.eye(3) - assert_raises(TypeError, a.__imatmul__, b) - import operator - assert_raises(TypeError, operator.imatmul, a, b) - assert_raises(TypeError, exec, "a @= b", globals(), locals()) + +class TestMatmulInplace: + DTYPES = {} + for i in MatmulCommon.types: + for j in MatmulCommon.types: + if np.can_cast(j, i): + DTYPES[f"{i}-{j}"] = (np.dtype(i), np.dtype(j)) + + @pytest.mark.parametrize("dtype1,dtype2", DTYPES.values(), ids=DTYPES) + def test_matmul_inplace(self, dtype1: np.dtype, dtype2: np.dtype) -> None: + a = np.arange(10).reshape(5, 2).astype(dtype1) + a_id = id(a) + b = np.ones((2, 2), dtype=dtype2) + + ref = a @ b + a @= b + + assert id(a) == a_id + assert a.dtype == dtype1 + assert a.shape == (5, 2) + if dtype1.kind in "fc": + np.testing.assert_allclose(a, ref) + else: + np.testing.assert_array_equal(a, ref) + def test_matmul_axes(): a = np.arange(3*4*5).reshape(3, 4, 5) |