summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_array_object.py305
-rw-r--r--numpy/_array_api/tests/test_array_object.py178
2 files changed, 376 insertions, 107 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index f8bad0b59..f6371fbf4 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -18,7 +18,8 @@ from __future__ import annotations
import operator
from enum import IntEnum
from ._creation_functions import asarray
-from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes
+from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes,
+ _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes)
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
if TYPE_CHECKING:
@@ -83,6 +84,52 @@ class Array:
"""
return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})"
+ # These are various helper functions to make the array behavior match the
+ # spec in places where it either deviates from or is more strict than
+ # NumPy behavior
+
+ def _check_allowed_dtypes(self, other, dtype_category, op):
+ """
+ Helper function for operators to only allow specific input dtypes
+
+ Use like
+
+ other = self._check_allowed_dtypes(other, 'numeric', '__add__')
+ if other is NotImplemented:
+ return other
+ """
+ from ._dtypes import _result_type
+
+ _dtypes = {
+ 'all': _all_dtypes,
+ 'numeric': _numeric_dtypes,
+ 'integer': _integer_dtypes,
+ 'integer or boolean': _integer_or_boolean_dtypes,
+ 'boolean': _boolean_dtypes,
+ 'floating-point': _floating_dtypes,
+ }
+
+ if self.dtype not in _dtypes[dtype_category]:
+ raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
+ elif isinstance(other, Array):
+ if other.dtype not in _dtypes[dtype_category]:
+ raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
+ else:
+ return NotImplemented
+
+ res_dtype = _result_type(self.dtype, other.dtype)
+ if op.startswith('__i'):
+ # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example,
+
+ # >>> a = np.array(1, dtype=np.int8)
+ # >>> a += np.array(1, dtype=np.int16)
+ if res_dtype != self.dtype:
+ raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}")
+
+ return other
+
# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar):
"""
@@ -270,8 +317,9 @@ class Array:
"""
Performs the operation __add__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__add__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__add__(other._array)
return self.__class__._new(res)
@@ -280,8 +328,9 @@ class Array:
"""
Performs the operation __and__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__and__(other._array)
return self.__class__._new(res)
@@ -321,6 +370,11 @@ class Array:
"""
Performs the operation __eq__.
"""
+ # Even though "all" dtypes are allowed, we still require them to be
+ # promotable with each other.
+ other = self._check_allowed_dtypes(other, 'all', '__eq__')
+ if other is NotImplemented:
+ return other
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
self, other = self._normalize_two_args(self, other)
@@ -341,8 +395,9 @@ class Array:
"""
Performs the operation __floordiv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__floordiv__(other._array)
return self.__class__._new(res)
@@ -351,8 +406,9 @@ class Array:
"""
Performs the operation __ge__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__ge__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ge__(other._array)
return self.__class__._new(res)
@@ -371,8 +427,9 @@ class Array:
"""
Performs the operation __gt__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__gt__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__gt__(other._array)
return self.__class__._new(res)
@@ -391,6 +448,8 @@ class Array:
"""
Performs the operation __invert__.
"""
+ if self.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError('Only integer or boolean dtypes are allowed in __invert__')
res = self._array.__invert__()
return self.__class__._new(res)
@@ -398,8 +457,9 @@ class Array:
"""
Performs the operation __le__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__le__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__le__(other._array)
return self.__class__._new(res)
@@ -416,8 +476,9 @@ class Array:
"""
Performs the operation __lshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__lshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lshift__(other._array)
return self.__class__._new(res)
@@ -426,8 +487,9 @@ class Array:
"""
Performs the operation __lt__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__lt__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lt__(other._array)
return self.__class__._new(res)
@@ -436,10 +498,11 @@ class Array:
"""
Performs the operation __matmul__.
"""
- if isinstance(other, (int, float, bool)):
- # matmul is not defined for scalars, but without this, we may get
- # the wrong error message from asarray.
- other = self._promote_scalar(other)
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, 'numeric', '__matmul__')
+ if other is NotImplemented:
+ return other
res = self._array.__matmul__(other._array)
return self.__class__._new(res)
@@ -447,8 +510,9 @@ class Array:
"""
Performs the operation __mod__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__mod__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mod__(other._array)
return self.__class__._new(res)
@@ -457,8 +521,9 @@ class Array:
"""
Performs the operation __mul__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__mul__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mul__(other._array)
return self.__class__._new(res)
@@ -467,6 +532,9 @@ class Array:
"""
Performs the operation __ne__.
"""
+ other = self._check_allowed_dtypes(other, 'all', '__ne__')
+ if other is NotImplemented:
+ return other
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
self, other = self._normalize_two_args(self, other)
@@ -477,6 +545,8 @@ class Array:
"""
Performs the operation __neg__.
"""
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in __neg__')
res = self._array.__neg__()
return self.__class__._new(res)
@@ -484,8 +554,9 @@ class Array:
"""
Performs the operation __or__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__or__(other._array)
return self.__class__._new(res)
@@ -494,6 +565,8 @@ class Array:
"""
Performs the operation __pos__.
"""
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in __pos__')
res = self._array.__pos__()
return self.__class__._new(res)
@@ -505,10 +578,9 @@ class Array:
"""
from ._elementwise_functions import pow
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __pow__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__pow__')
+ if other is NotImplemented:
+ return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
# arrays, so we use pow() here instead.
return pow(self, other)
@@ -517,8 +589,9 @@ class Array:
"""
Performs the operation __rshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__rshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rshift__(other._array)
return self.__class__._new(res)
@@ -537,8 +610,9 @@ class Array:
"""
Performs the operation __sub__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__sub__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__sub__(other._array)
return self.__class__._new(res)
@@ -549,10 +623,9 @@ class Array:
"""
Performs the operation __truediv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __truediv__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__truediv__(other._array)
return self.__class__._new(res)
@@ -561,8 +634,9 @@ class Array:
"""
Performs the operation __xor__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__xor__(other._array)
return self.__class__._new(res)
@@ -571,8 +645,9 @@ class Array:
"""
Performs the operation __iadd__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__iadd__')
+ if other is NotImplemented:
+ return other
self._array.__iadd__(other._array)
return self
@@ -580,8 +655,9 @@ class Array:
"""
Performs the operation __radd__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__radd__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__radd__(other._array)
return self.__class__._new(res)
@@ -590,8 +666,9 @@ class Array:
"""
Performs the operation __iand__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__')
+ if other is NotImplemented:
+ return other
self._array.__iand__(other._array)
return self
@@ -599,8 +676,9 @@ class Array:
"""
Performs the operation __rand__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rand__(other._array)
return self.__class__._new(res)
@@ -609,8 +687,9 @@ class Array:
"""
Performs the operation __ifloordiv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__')
+ if other is NotImplemented:
+ return other
self._array.__ifloordiv__(other._array)
return self
@@ -618,8 +697,9 @@ class Array:
"""
Performs the operation __rfloordiv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rfloordiv__(other._array)
return self.__class__._new(res)
@@ -628,8 +708,9 @@ class Array:
"""
Performs the operation __ilshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__ilshift__')
+ if other is NotImplemented:
+ return other
self._array.__ilshift__(other._array)
return self
@@ -637,8 +718,9 @@ class Array:
"""
Performs the operation __rlshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__rlshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rlshift__(other._array)
return self.__class__._new(res)
@@ -649,15 +731,17 @@ class Array:
"""
# Note: NumPy does not implement __imatmul__.
- if isinstance(other, (int, float, bool)):
- # matmul is not defined for scalars, but without this, we may get
- # the wrong error message from asarray.
- other = self._promote_scalar(other)
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__')
+ if other is NotImplemented:
+ return other
+
# __imatmul__ can only be allowed when it would not change the shape
# of self.
other_shape = other.shape
if self.shape == () or other_shape == ():
- raise ValueError("@= requires at least one dimension")
+ raise TypeError("@= requires at least one dimension")
if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
raise ValueError("@= cannot change the shape of the input array")
self._array[:] = self._array.__matmul__(other._array)
@@ -667,10 +751,11 @@ class Array:
"""
Performs the operation __rmatmul__.
"""
- if isinstance(other, (int, float, bool)):
- # matmul is not defined for scalars, but without this, we may get
- # the wrong error message from asarray.
- other = self._promote_scalar(other)
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__')
+ if other is NotImplemented:
+ return other
res = self._array.__rmatmul__(other._array)
return self.__class__._new(res)
@@ -678,8 +763,9 @@ class Array:
"""
Performs the operation __imod__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__imod__')
+ if other is NotImplemented:
+ return other
self._array.__imod__(other._array)
return self
@@ -687,8 +773,9 @@ class Array:
"""
Performs the operation __rmod__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rmod__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rmod__(other._array)
return self.__class__._new(res)
@@ -697,8 +784,9 @@ class Array:
"""
Performs the operation __imul__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__imul__')
+ if other is NotImplemented:
+ return other
self._array.__imul__(other._array)
return self
@@ -706,8 +794,9 @@ class Array:
"""
Performs the operation __rmul__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rmul__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rmul__(other._array)
return self.__class__._new(res)
@@ -716,8 +805,9 @@ class Array:
"""
Performs the operation __ior__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__')
+ if other is NotImplemented:
+ return other
self._array.__ior__(other._array)
return self
@@ -725,8 +815,9 @@ class Array:
"""
Performs the operation __ror__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ror__(other._array)
return self.__class__._new(res)
@@ -735,10 +826,9 @@ class Array:
"""
Performs the operation __ipow__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __pow__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__')
+ if other is NotImplemented:
+ return other
self._array.__ipow__(other._array)
return self
@@ -748,10 +838,9 @@ class Array:
"""
from ._elementwise_functions import pow
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __pow__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__')
+ if other is NotImplemented:
+ return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules
# for 0-d arrays, so we use pow() here instead.
return pow(other, self)
@@ -760,8 +849,9 @@ class Array:
"""
Performs the operation __irshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__irshift__')
+ if other is NotImplemented:
+ return other
self._array.__irshift__(other._array)
return self
@@ -769,8 +859,9 @@ class Array:
"""
Performs the operation __rrshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__rrshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rrshift__(other._array)
return self.__class__._new(res)
@@ -779,8 +870,9 @@ class Array:
"""
Performs the operation __isub__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__isub__')
+ if other is NotImplemented:
+ return other
self._array.__isub__(other._array)
return self
@@ -788,8 +880,9 @@ class Array:
"""
Performs the operation __rsub__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rsub__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rsub__(other._array)
return self.__class__._new(res)
@@ -798,10 +891,9 @@ class Array:
"""
Performs the operation __itruediv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __truediv__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__')
+ if other is NotImplemented:
+ return other
self._array.__itruediv__(other._array)
return self
@@ -809,10 +901,9 @@ class Array:
"""
Performs the operation __rtruediv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __truediv__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rtruediv__(other._array)
return self.__class__._new(res)
@@ -821,8 +912,9 @@ class Array:
"""
Performs the operation __ixor__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__')
+ if other is NotImplemented:
+ return other
self._array.__ixor__(other._array)
return self
@@ -830,8 +922,9 @@ class Array:
"""
Performs the operation __rxor__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rxor__(other._array)
return self.__class__._new(res)
diff --git a/numpy/_array_api/tests/test_array_object.py b/numpy/_array_api/tests/test_array_object.py
index 49ec3b37b..5aba2b23c 100644
--- a/numpy/_array_api/tests/test_array_object.py
+++ b/numpy/_array_api/tests/test_array_object.py
@@ -1,7 +1,10 @@
from numpy.testing import assert_raises
import numpy as np
-from .. import ones, asarray
+from .. import ones, asarray, result_type
+from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
+ _integer_dtypes, _integer_or_boolean_dtypes,
+ _numeric_dtypes, int8, int16, int32, int64, uint64)
def test_validate_index():
# The indexing tests in the official array API test suite test that the
@@ -57,3 +60,176 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None])
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])
+
+def test_operators():
+ # For every operator, we test that it works for the required type
+ # combinations and raises TypeError otherwise
+ binary_op_dtypes ={
+ '__add__': 'numeric',
+ '__and__': 'integer_or_boolean',
+ '__eq__': 'all',
+ '__floordiv__': 'numeric',
+ '__ge__': 'numeric',
+ '__gt__': 'numeric',
+ '__le__': 'numeric',
+ '__lshift__': 'integer',
+ '__lt__': 'numeric',
+ '__mod__': 'numeric',
+ '__mul__': 'numeric',
+ '__ne__': 'all',
+ '__or__': 'integer_or_boolean',
+ '__pow__': 'floating',
+ '__rshift__': 'integer',
+ '__sub__': 'numeric',
+ '__truediv__': 'floating',
+ '__xor__': 'integer_or_boolean',
+ }
+
+ # Recompute each time because of in-place ops
+ def _array_vals():
+ for d in _integer_dtypes:
+ yield asarray(1, dtype=d)
+ for d in _boolean_dtypes:
+ yield asarray(False, dtype=d)
+ for d in _floating_dtypes:
+ yield asarray(1., dtype=d)
+
+ for op, dtypes in binary_op_dtypes.items():
+ ops = [op]
+ if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']:
+ rop = '__r' + op[2:]
+ iop = '__i' + op[2:]
+ ops += [rop, iop]
+ for s in [1, 1., False]:
+ for _op in ops:
+ for a in _array_vals():
+ # Test array op scalar. From the spec, the following combinations
+ # are supported:
+
+ # - Python bool for a bool array dtype,
+ # - a Python int within the bounds of the given dtype for integer array dtypes,
+ # - a Python int or float for floating-point array dtypes
+
+ # We do not do bounds checking for int scalars, but rather use the default
+ # NumPy behavior for casting in that case.
+
+ if ((dtypes == "all"
+ or dtypes == "numeric" and a.dtype in _numeric_dtypes
+ or dtypes == "integer" and a.dtype in _integer_dtypes
+ or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
+ or dtypes == "boolean" and a.dtype in _boolean_dtypes
+ or dtypes == "floating" and a.dtype in _floating_dtypes
+ )
+ # bool is a subtype of int, which is why we avoid
+ # isinstance here.
+ and (a.dtype in _boolean_dtypes and type(s) == bool
+ or a.dtype in _integer_dtypes and type(s) == int
+ or a.dtype in _floating_dtypes and type(s) in [float, int]
+ )):
+ # Only test for no error
+ getattr(a, _op)(s)
+ else:
+ assert_raises(TypeError, lambda: getattr(a, _op)(s))
+
+ # Test array op array.
+ for _op in ops:
+ for x in _array_vals():
+ for y in _array_vals():
+ # See the promotion table in NEP 47 or the array
+ # API spec page on type promotion. Mixed kind
+ # promotion is not defined.
+ if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
+ or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
+ or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
+ or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
+ or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
+ or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
+ or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
+ or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
+ ):
+ assert_raises(TypeError, lambda: getattr(x, _op)(y))
+ # Ensure in-place operators only promote to the same dtype as the left operand.
+ elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype:
+ assert_raises(TypeError, lambda: getattr(x, _op)(y))
+ # Ensure only those dtypes that are required for every operator are allowed.
+ elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
+ or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
+ or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
+ or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes
+ or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
+ or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes)
+ or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
+ or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
+ ):
+ getattr(x, _op)(y)
+ else:
+ assert_raises(TypeError, lambda: getattr(x, _op)(y))
+
+ unary_op_dtypes ={
+ '__invert__': 'integer_or_boolean',
+ '__neg__': 'numeric',
+ '__pos__': 'numeric',
+ }
+ for op, dtypes in unary_op_dtypes.items():
+ for a in _array_vals():
+ if (dtypes == "numeric" and a.dtype in _numeric_dtypes
+ or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
+ ):
+ # Only test for no error
+ getattr(a, op)()
+ else:
+ assert_raises(TypeError, lambda: getattr(a, op)())
+
+ # Finally, matmul() must be tested separately, because it works a bit
+ # different from the other operations.
+ def _matmul_array_vals():
+ for a in _array_vals():
+ yield a
+ for d in _all_dtypes:
+ yield ones((3, 4), dtype=d)
+ yield ones((4, 2), dtype=d)
+ yield ones((4, 4), dtype=d)
+
+ # Scalars always error
+ for _op in ['__matmul__', '__rmatmul__', '__imatmul__']:
+ for s in [1, 1., False]:
+ for a in _matmul_array_vals():
+ if (type(s) in [float, int] and a.dtype in _floating_dtypes
+ or type(s) == int and a.dtype in _integer_dtypes):
+ # Type promotion is valid, but @ is not allowed on 0-D
+ # inputs, so the error is a ValueError
+ assert_raises(ValueError, lambda: getattr(a, _op)(s))
+ else:
+ assert_raises(TypeError, lambda: getattr(a, _op)(s))
+
+ for x in _matmul_array_vals():
+ for y in _matmul_array_vals():
+ if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
+ or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
+ or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
+ or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
+ or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
+ or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
+ or x.dtype in _boolean_dtypes
+ or y.dtype in _boolean_dtypes
+ ):
+ assert_raises(TypeError, lambda: x.__matmul__(y))
+ assert_raises(TypeError, lambda: y.__rmatmul__(x))
+ assert_raises(TypeError, lambda: x.__imatmul__(y))
+ elif x.shape == () or y.shape == () or x.shape[1] != y.shape[0]:
+ assert_raises(ValueError, lambda: x.__matmul__(y))
+ assert_raises(ValueError, lambda: y.__rmatmul__(x))
+ if result_type(x.dtype, y.dtype) != x.dtype:
+ assert_raises(TypeError, lambda: x.__imatmul__(y))
+ else:
+ assert_raises(ValueError, lambda: x.__imatmul__(y))
+ else:
+ x.__matmul__(y)
+ y.__rmatmul__(x)
+ if result_type(x.dtype, y.dtype) != x.dtype:
+ assert_raises(TypeError, lambda: x.__imatmul__(y))
+ elif y.shape[0] != y.shape[1]:
+ # This one fails because x @ y has a different shape from x
+ assert_raises(ValueError, lambda: x.__imatmul__(y))
+ else:
+ x.__imatmul__(y)