From 2ff635c7cbc8804a3956ddbf8165f536dffc2df5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Feb 2021 18:22:06 -0700 Subject: Don't check if a dtype is in all_dtypes The array API namespace is not going to do type checking against arbitrary objects. An object that takes an array as input should assume that it will get an array API namespace array object. Passing a NumPy array or other type of object to any of the functions is undefined behavior, unless the type signature allows for it. --- numpy/_array_api/_elementwise_functions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) (limited to 'numpy/_array_api/_elementwise_functions.py') diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index 2357b337c..b48a38c3d 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, +from ._dtypes import (_boolean_dtypes, _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes) from ._types import array from ._array_object import ndarray @@ -213,8 +213,6 @@ def equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ - if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes: - raise TypeError('Only array API spec dtypes are allowed in equal') return ndarray._new(np.equal(x1._array, x2._array)) def exp(x: array, /) -> array: @@ -443,8 +441,6 @@ def not_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ - if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes: - raise TypeError('Only array API spec dtypes are allowed in not_equal') return ndarray._new(np.not_equal(x1._array, x2._array)) def positive(x: array, /) -> array: -- cgit v1.2.1