summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_elementwise_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-04-13 18:11:09 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-04-13 18:11:09 -0600
commitd40d2bcfbe01678479fe741ab8b0ff5e431e0329 (patch)
treeba31ff2eb41e23d8802cf5ec8d3860046abcf52c /numpy/_array_api/_elementwise_functions.py
parentb75a135751e4b38f144027678d1ddc74ee4d50fc (diff)
downloadnumpy-d40d2bcfbe01678479fe741ab8b0ff5e431e0329.tar.gz
Fix ceil() and floor() in the array API to always return the same dtype
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r--numpy/_array_api/_elementwise_functions.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index aa48f440c..3ca71b53e 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -190,6 +190,9 @@ def ceil(x: array, /) -> array:
"""
if x.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in ceil')
+ if x.dtype in _integer_dtypes:
+ # Note: The return dtype of ceil is the same as the input
+ return x
return ndarray._new(np.ceil(x._array))
def cos(x: array, /) -> array:
@@ -258,6 +261,9 @@ def floor(x: array, /) -> array:
"""
if x.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in floor')
+ if x.dtype in _integer_dtypes:
+ # Note: The return dtype of floor is the same as the input
+ return x
return ndarray._new(np.floor(x._array))
def floor_divide(x1: array, x2: array, /) -> array: