summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_array_object.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-09 18:08:22 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-09 18:08:22 -0600
commit5217236995f839250a148e533f4395007288cc24 (patch)
treef1797447fc0408bda58c78649c2113569dca2dfa /numpy/_array_api/_array_object.py
parent6765494edee2b90f239ae622abb5f3a7d218aa84 (diff)
downloadnumpy-5217236995f839250a148e533f4395007288cc24.tar.gz
Make the array API left and right shift do type promotion
The spec previously said it should return the type of the left argument, but this was changed to do type promotion to be consistent with all the other elementwise functions/operators.
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r--numpy/_array_api/_array_object.py28
1 files changed, 8 insertions, 20 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index 0e0544afe..6b9647626 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -410,11 +410,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
- # Note: The spec requires the return dtype of bitwise_left_shift, and
- # hence also __lshift__, to be the same as the first argument.
- # np.ndarray.__lshift__ returns a type that is the type promotion of
- # the two input types.
- res = self._array.__lshift__(other._array).astype(self.dtype)
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__lshift__(other._array)
return self.__class__._new(res)
def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
@@ -517,11 +514,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
- # Note: The spec requires the return dtype of bitwise_right_shift, and
- # hence also __rshift__, to be the same as the first argument.
- # np.ndarray.__rshift__ returns a type that is the type promotion of
- # the two input types.
- res = self._array.__rshift__(other._array).astype(self.dtype)
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rshift__(other._array)
return self.__class__._new(res)
def __setitem__(self, key, value, /):
@@ -646,11 +640,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
- # Note: The spec requires the return dtype of bitwise_left_shift, and
- # hence also __lshift__, to be the same as the first argument.
- # np.ndarray.__lshift__ returns a type that is the type promotion of
- # the two input types.
- res = self._array.__rlshift__(other._array).astype(other.dtype)
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rlshift__(other._array)
return self.__class__._new(res)
def __imatmul__(self: Array, other: Array, /) -> Array:
@@ -787,11 +778,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
- # Note: The spec requires the return dtype of bitwise_right_shift, and
- # hence also __rshift__, to be the same as the first argument.
- # np.ndarray.__rshift__ returns a type that is the type promotion of
- # the two input types.
- res = self._array.__rrshift__(other._array).astype(other.dtype)
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rrshift__(other._array)
return self.__class__._new(res)
@np.errstate(all='ignore')