diff options
-rw-r--r-- | numpy/_array_api/_array_object.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index a3de25478..8f7252160 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -648,12 +648,21 @@ class ndarray: """ Performs the operation __imatmul__. """ + # Note: NumPy does not implement __imatmul__. + if isinstance(other, (int, float, bool)): # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._promote_scalar(other) - res = self._array.__imatmul__(other._array) - return self.__class__._new(res) + # __imatmul__ can only be allowed when it would not change the shape + # of self. + other_shape = other.shape + if self.shape == () or other_shape == (): + raise ValueError("@= requires at least one dimension") + if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: + raise ValueError("@= cannot change the shape of the input array") + self._array[:] = self._array.__matmul__(other._array) + return self def __rmatmul__(self: array, other: array, /) -> array: """ |