summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-08-04 20:01:11 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-08-04 20:01:11 -0600
commitbc20d334b575f897157b1cf3eecda77f3e40e049 (patch)
treefb5c121332bc8078cd2115cb5a6f023f580afdf4 /numpy/array_api
parent5605d687019dc55e594d4e227747c72bebb71a3c (diff)
downloadnumpy-bc20d334b575f897157b1cf3eecda77f3e40e049.tar.gz
Move the array API dtype categories into the top level
They are not an official part of the spec but are useful for various parts of the implementation.
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_array_object.py17
-rw-r--r--numpy/array_api/_dtypes.py10
-rw-r--r--numpy/array_api/tests/test_elementwise_functions.py16
3 files changed, 17 insertions, 26 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index af70058e6..50906642d 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -98,23 +98,14 @@ class Array:
if other is NotImplemented:
return other
"""
- from ._dtypes import _result_type
-
- _dtypes = {
- 'all': _all_dtypes,
- 'numeric': _numeric_dtypes,
- 'integer': _integer_dtypes,
- 'integer or boolean': _integer_or_boolean_dtypes,
- 'boolean': _boolean_dtypes,
- 'floating-point': _floating_dtypes,
- }
-
- if self.dtype not in _dtypes[dtype_category]:
+ from ._dtypes import _result_type, _dtype_categories
+
+ if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, Array):
- if other.dtype not in _dtypes[dtype_category]:
+ if other.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
else:
return NotImplemented
diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py
index fcdb562da..07be267da 100644
--- a/numpy/array_api/_dtypes.py
+++ b/numpy/array_api/_dtypes.py
@@ -23,6 +23,16 @@ _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
+_dtype_categories = {
+ 'all': _all_dtypes,
+ 'numeric': _numeric_dtypes,
+ 'integer': _integer_dtypes,
+ 'integer or boolean': _integer_or_boolean_dtypes,
+ 'boolean': _boolean_dtypes,
+ 'floating-point': _floating_dtypes,
+}
+
+
# Note: the spec defines a restricted type promotion table compared to NumPy.
# In particular, cross-kind promotions like integer + float or boolean +
# integer are not allowed, even for functions that accept both kinds.
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py
index 994cb0bf0..2a5ddbc87 100644
--- a/numpy/array_api/tests/test_elementwise_functions.py
+++ b/numpy/array_api/tests/test_elementwise_functions.py
@@ -4,9 +4,8 @@ from numpy.testing import assert_raises
from .. import asarray, _elementwise_functions
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
-from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
- _integer_dtypes, _integer_or_boolean_dtypes,
- _numeric_dtypes)
+from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes,
+ _integer_dtypes)
def nargs(func):
return len(getfullargspec(func).args)
@@ -75,15 +74,6 @@ def test_function_types():
'trunc': 'numeric',
}
- _dtypes = {
- 'all': _all_dtypes,
- 'numeric': _numeric_dtypes,
- 'integer': _integer_dtypes,
- 'integer_or_boolean': _integer_or_boolean_dtypes,
- 'boolean': _boolean_dtypes,
- 'floating': _floating_dtypes,
- }
-
def _array_vals():
for d in _integer_dtypes:
yield asarray(1, dtype=d)
@@ -94,7 +84,7 @@ def test_function_types():
for x in _array_vals():
for func_name, types in elementwise_function_input_types.items():
- dtypes = _dtypes[types]
+ dtypes = _dtype_categories[types]
func = getattr(_elementwise_functions, func_name)
if nargs(func) == 2:
for y in _array_vals():