diff options
| author | Matthew Barber <quitesimplymatt@gmail.com> | 2022-02-09 08:35:21 +0000 |
|---|---|---|
| committer | Matthew Barber <quitesimplymatt@gmail.com> | 2022-02-09 08:44:07 +0000 |
| commit | 995f5464b6c5d8569e159a96c6af106721a4e6d5 (patch) | |
| tree | f06f7155aac38e8b912b9af66a29568aafae420c /numpy/array_api | |
| parent | 8eac9a4bb5b497ca29ebb852f21169ecfd0191e1 (diff) | |
| download | numpy-995f5464b6c5d8569e159a96c6af106721a4e6d5.tar.gz | |
Note `np.array_api.can_cast()` does not use `np.can_cast()`
Diffstat (limited to 'numpy/array_api')
| -rw-r--r-- | numpy/array_api/_data_type_functions.py | 5 | ||||
| -rw-r--r-- | numpy/array_api/tests/test_data_type_functions.py | 2 |
2 files changed, 5 insertions, 2 deletions
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 1198ff778..1fb6062f6 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -56,10 +56,15 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: raise TypeError(f"{from_=}, but should be an array_api array or dtype") if to not in _all_dtypes: raise TypeError(f"{to=}, but should be a dtype") + # Note: We avoid np.can_cast() as it has discrepancies with the array API. + # See https://github.com/numpy/numpy/issues/20870 try: + # We promote `from_` and `to` together. We then check if the promoted + # dtype is `to`, which indicates if `from_` can (up)cast to `to`. dtype = _result_type(from_, to) return to == dtype except TypeError: + # _result_type() raises if the dtypes don't promote together return False diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py index 3f01bb311..efe3d0abd 100644 --- a/numpy/array_api/tests/test_data_type_functions.py +++ b/numpy/array_api/tests/test_data_type_functions.py @@ -8,8 +8,6 @@ from numpy import array_api as xp [ (xp.int8, xp.int16, True), (xp.int16, xp.int8, False), - # np.can_cast has discrepancies with the Array API - # See https://github.com/numpy/numpy/issues/20870 (xp.bool, xp.int8, False), (xp.asarray(0, dtype=xp.uint8), xp.int8, False), ], |
