summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorMatthew Barber <quitesimplymatt@gmail.com>2022-02-09 08:35:21 +0000
committerMatthew Barber <quitesimplymatt@gmail.com>2022-02-09 08:44:07 +0000
commit995f5464b6c5d8569e159a96c6af106721a4e6d5 (patch)
treef06f7155aac38e8b912b9af66a29568aafae420c /numpy/array_api
parent8eac9a4bb5b497ca29ebb852f21169ecfd0191e1 (diff)
downloadnumpy-995f5464b6c5d8569e159a96c6af106721a4e6d5.tar.gz
Note `np.array_api.can_cast()` does not use `np.can_cast()`
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_data_type_functions.py5
-rw-r--r--numpy/array_api/tests/test_data_type_functions.py2
2 files changed, 5 insertions, 2 deletions
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index 1198ff778..1fb6062f6 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -56,10 +56,15 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
+ # Note: We avoid np.can_cast() as it has discrepancies with the array API.
+ # See https://github.com/numpy/numpy/issues/20870
try:
+ # We promote `from_` and `to` together. We then check if the promoted
+ # dtype is `to`, which indicates if `from_` can (up)cast to `to`.
dtype = _result_type(from_, to)
return to == dtype
except TypeError:
+ # _result_type() raises if the dtypes don't promote together
return False
diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py
index 3f01bb311..efe3d0abd 100644
--- a/numpy/array_api/tests/test_data_type_functions.py
+++ b/numpy/array_api/tests/test_data_type_functions.py
@@ -8,8 +8,6 @@ from numpy import array_api as xp
[
(xp.int8, xp.int16, True),
(xp.int16, xp.int8, False),
- # np.can_cast has discrepancies with the Array API
- # See https://github.com/numpy/numpy/issues/20870
(xp.bool, xp.int8, False),
(xp.asarray(0, dtype=xp.uint8), xp.int8, False),
],