summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_array_object.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-08 16:10:27 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-08 16:10:27 -0600
commit13796236295b344ee83e79c8a33ad6205c0095db (patch)
tree518499b1d8ef636ccfc118aa124b69283c418f10 /numpy/_array_api/_array_object.py
parent01780805fabd160514a25d44972d527c3c99f8c8 (diff)
downloadnumpy-13796236295b344ee83e79c8a33ad6205c0095db.tar.gz
Fix the __imatmul__ method in the array API namespace
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r--numpy/_array_api/_array_object.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index a3de25478..8f7252160 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -648,12 +648,21 @@ class ndarray:
"""
Performs the operation __imatmul__.
"""
+ # Note: NumPy does not implement __imatmul__.
+
if isinstance(other, (int, float, bool)):
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._promote_scalar(other)
- res = self._array.__imatmul__(other._array)
- return self.__class__._new(res)
+ # __imatmul__ can only be allowed when it would not change the shape
+ # of self.
+ other_shape = other.shape
+ if self.shape == () or other_shape == ():
+ raise ValueError("@= requires at least one dimension")
+ if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
+ raise ValueError("@= cannot change the shape of the input array")
+ self._array[:] = self._array.__matmul__(other._array)
+ return self
def __rmatmul__(self: array, other: array, /) -> array:
"""