diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-03-18 15:20:24 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-03-18 15:20:24 -0600 |
commit | 8968bc31944eb2af214b591d38eca3b0cea56f4b (patch) | |
tree | 70064a13231880e59207f13113ae0323dfcc595c /numpy/_array_api/_elementwise_functions.py | |
parent | ed05662905f83864a937fa67b7f762cda1277df2 (diff) | |
download | numpy-8968bc31944eb2af214b591d38eca3b0cea56f4b.tar.gz |
bitwise_left_shift and bitwise_right_shift should return the dtype of the first argument
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index cd2e8661f..a8af04c62 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -123,10 +123,13 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + # Note: the function name is different here if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') - # Note: the function name is different here - return ndarray._new(np.left_shift(x1._array, x2._array)) + # Note: The spec requires the return dtype of bitwise_left_shift to be the + # same as the first argument. np.left_shift() returns a type that is the + # type promotion of the two input types. + return ndarray._new(np.left_shift(x1._array, x2._array).astype(x1.dtype)) def bitwise_invert(x: array, /) -> array: """ @@ -155,10 +158,13 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + # Note: the function name is different here if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') - # Note: the function name is different here - return ndarray._new(np.right_shift(x1._array, x2._array)) + # Note: The spec requires the return dtype of bitwise_left_shift to be the + # same as the first argument. np.left_shift() returns a type that is the + # type promotion of the two input types. + return ndarray._new(np.right_shift(x1._array, x2._array).astype(x1.dtype)) def bitwise_xor(x1: array, x2: array, /) -> array: """ |