diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-08 16:10:27 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-08 16:10:27 -0600 |
commit | 13796236295b344ee83e79c8a33ad6205c0095db (patch) | |
tree | 518499b1d8ef636ccfc118aa124b69283c418f10 /numpy/_array_api/_array_object.py | |
parent | 01780805fabd160514a25d44972d527c3c99f8c8 (diff) | |
download | numpy-13796236295b344ee83e79c8a33ad6205c0095db.tar.gz |
Fix the __imatmul__ method in the array API namespace
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index a3de25478..8f7252160 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -648,12 +648,21 @@ class ndarray: """ Performs the operation __imatmul__. """ + # Note: NumPy does not implement __imatmul__. + if isinstance(other, (int, float, bool)): # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._promote_scalar(other) - res = self._array.__imatmul__(other._array) - return self.__class__._new(res) + # __imatmul__ can only be allowed when it would not change the shape + # of self. + other_shape = other.shape + if self.shape == () or other_shape == (): + raise ValueError("@= requires at least one dimension") + if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: + raise ValueError("@= cannot change the shape of the input array") + self._array[:] = self._array.__matmul__(other._array) + return self def __rmatmul__(self: array, other: array, /) -> array: """ |