summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_elementwise_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-02-26 18:22:06 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-02-26 18:22:06 -0700
commit2ff635c7cbc8804a3956ddbf8165f536dffc2df5 (patch)
treea4c1f271d8b30c71f6a16e6150abcd9d5f6257ac /numpy/_array_api/_elementwise_functions.py
parentb7856e348d731551405bdf0dd41ff1b0416da129 (diff)
downloadnumpy-2ff635c7cbc8804a3956ddbf8165f536dffc2df5.tar.gz
Don't check if a dtype is in all_dtypes
The array API namespace is not going to do type checking against arbitrary objects. An object that takes an array as input should assume that it will get an array API namespace array object. Passing a NumPy array or other type of object to any of the functions is undefined behavior, unless the type signature allows for it.
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r--numpy/_array_api/_elementwise_functions.py6
1 files changed, 1 insertions, 5 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index 2357b337c..b48a38c3d 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from ._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
+from ._dtypes import (_boolean_dtypes, _floating_dtypes,
_integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes)
from ._types import array
from ._array_object import ndarray
@@ -213,8 +213,6 @@ def equal(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
- if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes:
- raise TypeError('Only array API spec dtypes are allowed in equal')
return ndarray._new(np.equal(x1._array, x2._array))
def exp(x: array, /) -> array:
@@ -443,8 +441,6 @@ def not_equal(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
- if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes:
- raise TypeError('Only array API spec dtypes are allowed in not_equal')
return ndarray._new(np.not_equal(x1._array, x2._array))
def positive(x: array, /) -> array: