summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_array_object.py17
-rw-r--r--numpy/array_api/_dtypes.py10
-rw-r--r--numpy/array_api/tests/test_elementwise_functions.py16
3 files changed, 17 insertions, 26 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index af70058e6..50906642d 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -98,23 +98,14 @@ class Array:
if other is NotImplemented:
return other
"""
- from ._dtypes import _result_type
-
- _dtypes = {
- 'all': _all_dtypes,
- 'numeric': _numeric_dtypes,
- 'integer': _integer_dtypes,
- 'integer or boolean': _integer_or_boolean_dtypes,
- 'boolean': _boolean_dtypes,
- 'floating-point': _floating_dtypes,
- }
-
- if self.dtype not in _dtypes[dtype_category]:
+ from ._dtypes import _result_type, _dtype_categories
+
+ if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, Array):
- if other.dtype not in _dtypes[dtype_category]:
+ if other.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
else:
return NotImplemented
diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py
index fcdb562da..07be267da 100644
--- a/numpy/array_api/_dtypes.py
+++ b/numpy/array_api/_dtypes.py
@@ -23,6 +23,16 @@ _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
+_dtype_categories = {
+ 'all': _all_dtypes,
+ 'numeric': _numeric_dtypes,
+ 'integer': _integer_dtypes,
+ 'integer or boolean': _integer_or_boolean_dtypes,
+ 'boolean': _boolean_dtypes,
+ 'floating-point': _floating_dtypes,
+}
+
+
# Note: the spec defines a restricted type promotion table compared to NumPy.
# In particular, cross-kind promotions like integer + float or boolean +
# integer are not allowed, even for functions that accept both kinds.
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py
index 994cb0bf0..2a5ddbc87 100644
--- a/numpy/array_api/tests/test_elementwise_functions.py
+++ b/numpy/array_api/tests/test_elementwise_functions.py
@@ -4,9 +4,8 @@ from numpy.testing import assert_raises
from .. import asarray, _elementwise_functions
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
-from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
- _integer_dtypes, _integer_or_boolean_dtypes,
- _numeric_dtypes)
+from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes,
+ _integer_dtypes)
def nargs(func):
return len(getfullargspec(func).args)
@@ -75,15 +74,6 @@ def test_function_types():
'trunc': 'numeric',
}
- _dtypes = {
- 'all': _all_dtypes,
- 'numeric': _numeric_dtypes,
- 'integer': _integer_dtypes,
- 'integer_or_boolean': _integer_or_boolean_dtypes,
- 'boolean': _boolean_dtypes,
- 'floating': _floating_dtypes,
- }
-
def _array_vals():
for d in _integer_dtypes:
yield asarray(1, dtype=d)
@@ -94,7 +84,7 @@ def test_function_types():
for x in _array_vals():
for func_name, types in elementwise_function_input_types.items():
- dtypes = _dtypes[types]
+ dtypes = _dtype_categories[types]
func = getattr(_elementwise_functions, func_name)
if nargs(func) == 2:
for y in _array_vals():