summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-03-18 15:20:24 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-03-18 15:20:24 -0600
commit8968bc31944eb2af214b591d38eca3b0cea56f4b (patch)
tree70064a13231880e59207f13113ae0323dfcc595c /numpy/_array_api
parented05662905f83864a937fa67b7f762cda1277df2 (diff)
downloadnumpy-8968bc31944eb2af214b591d38eca3b0cea56f4b.tar.gz
bitwise_left_shift and bitwise_right_shift should return the dtype of the first argument
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_elementwise_functions.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index cd2e8661f..a8af04c62 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -123,10 +123,13 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ # Note: the function name is different here
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
raise TypeError('Only integer dtypes are allowed in bitwise_left_shift')
- # Note: the function name is different here
- return ndarray._new(np.left_shift(x1._array, x2._array))
+ # Note: The spec requires the return dtype of bitwise_left_shift to be the
+ # same as the first argument. np.left_shift() returns a type that is the
+ # type promotion of the two input types.
+ return ndarray._new(np.left_shift(x1._array, x2._array).astype(x1.dtype))
def bitwise_invert(x: array, /) -> array:
"""
@@ -155,10 +158,13 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ # Note: the function name is different here
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
raise TypeError('Only integer dtypes are allowed in bitwise_right_shift')
- # Note: the function name is different here
- return ndarray._new(np.right_shift(x1._array, x2._array))
+ # Note: The spec requires the return dtype of bitwise_left_shift to be the
+ # same as the first argument. np.left_shift() returns a type that is the
+ # type promotion of the two input types.
+ return ndarray._new(np.right_shift(x1._array, x2._array).astype(x1.dtype))
def bitwise_xor(x1: array, x2: array, /) -> array:
"""