diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-09 18:08:22 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-09 18:08:22 -0600 |
commit | 5217236995f839250a148e533f4395007288cc24 (patch) | |
tree | f1797447fc0408bda58c78649c2113569dca2dfa /numpy/_array_api/_array_object.py | |
parent | 6765494edee2b90f239ae622abb5f3a7d218aa84 (diff) | |
download | numpy-5217236995f839250a148e533f4395007288cc24.tar.gz |
Make the array API left and right shift do type promotion
The spec previously said it should return the type of the left argument, but
this was changed to do type promotion to be consistent with all the other
elementwise functions/operators.
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 28 |
1 files changed, 8 insertions, 20 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 0e0544afe..6b9647626 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -410,11 +410,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) - # Note: The spec requires the return dtype of bitwise_left_shift, and - # hence also __lshift__, to be the same as the first argument. - # np.ndarray.__lshift__ returns a type that is the type promotion of - # the two input types. - res = self._array.__lshift__(other._array).astype(self.dtype) + self, other = self._normalize_two_args(self, other) + res = self._array.__lshift__(other._array) return self.__class__._new(res) def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: @@ -517,11 +514,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) - # Note: The spec requires the return dtype of bitwise_right_shift, and - # hence also __rshift__, to be the same as the first argument. - # np.ndarray.__rshift__ returns a type that is the type promotion of - # the two input types. - res = self._array.__rshift__(other._array).astype(self.dtype) + self, other = self._normalize_two_args(self, other) + res = self._array.__rshift__(other._array) return self.__class__._new(res) def __setitem__(self, key, value, /): @@ -646,11 +640,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) - # Note: The spec requires the return dtype of bitwise_left_shift, and - # hence also __lshift__, to be the same as the first argument. - # np.ndarray.__lshift__ returns a type that is the type promotion of - # the two input types. - res = self._array.__rlshift__(other._array).astype(other.dtype) + self, other = self._normalize_two_args(self, other) + res = self._array.__rlshift__(other._array) return self.__class__._new(res) def __imatmul__(self: Array, other: Array, /) -> Array: @@ -787,11 +778,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) - # Note: The spec requires the return dtype of bitwise_right_shift, and - # hence also __rshift__, to be the same as the first argument. - # np.ndarray.__rshift__ returns a type that is the type promotion of - # the two input types. - res = self._array.__rrshift__(other._array).astype(other.dtype) + self, other = self._normalize_two_args(self, other) + res = self._array.__rrshift__(other._array) return self.__class__._new(res) @np.errstate(all='ignore') |