summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-22 18:20:44 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-22 18:20:44 -0600
commite7f6dfecccc9dc84520af1a9f0000b3b0d0f4895 (patch)
treeb42909e9f8714f478c2756f9afa6a7a8c57c1f4b /numpy
parentdeaf0bf6fc819c9c7b4dcffe0d4aee43bdc33bae (diff)
downloadnumpy-e7f6dfecccc9dc84520af1a9f0000b3b0d0f4895.tar.gz
Fix the array API trunc() to return the same dtype as the input
It is similar to floor() and ceil() but I missed it previously.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/_array_api/_elementwise_functions.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index 67fb7034d..7833ebe54 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -653,4 +653,7 @@ def trunc(x: Array, /) -> Array:
"""
if x.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in trunc')
+ if x.dtype in _integer_dtypes:
+ # Note: The return dtype of trunc is the same as the input
+ return x
return Array._new(np.trunc(x._array))