summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_array_object.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r--numpy/_array_api/_array_object.py305
1 files changed, 199 insertions, 106 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index f8bad0b59..f6371fbf4 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -18,7 +18,8 @@ from __future__ import annotations
import operator
from enum import IntEnum
from ._creation_functions import asarray
-from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes
+from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes,
+ _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes)
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
if TYPE_CHECKING:
@@ -83,6 +84,52 @@ class Array:
"""
return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})"
+ # These are various helper functions to make the array behavior match the
+ # spec in places where it either deviates from or is more strict than
+ # NumPy behavior
+
+ def _check_allowed_dtypes(self, other, dtype_category, op):
+ """
+ Helper function for operators to only allow specific input dtypes
+
+ Use like
+
+ other = self._check_allowed_dtypes(other, 'numeric', '__add__')
+ if other is NotImplemented:
+ return other
+ """
+ from ._dtypes import _result_type
+
+ _dtypes = {
+ 'all': _all_dtypes,
+ 'numeric': _numeric_dtypes,
+ 'integer': _integer_dtypes,
+ 'integer or boolean': _integer_or_boolean_dtypes,
+ 'boolean': _boolean_dtypes,
+ 'floating-point': _floating_dtypes,
+ }
+
+ if self.dtype not in _dtypes[dtype_category]:
+ raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
+ elif isinstance(other, Array):
+ if other.dtype not in _dtypes[dtype_category]:
+ raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
+ else:
+ return NotImplemented
+
+ res_dtype = _result_type(self.dtype, other.dtype)
+ if op.startswith('__i'):
+ # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example,
+
+ # >>> a = np.array(1, dtype=np.int8)
+ # >>> a += np.array(1, dtype=np.int16)
+ if res_dtype != self.dtype:
+ raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}")
+
+ return other
+
# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar):
"""
@@ -270,8 +317,9 @@ class Array:
"""
Performs the operation __add__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__add__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__add__(other._array)
return self.__class__._new(res)
@@ -280,8 +328,9 @@ class Array:
"""
Performs the operation __and__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__and__(other._array)
return self.__class__._new(res)
@@ -321,6 +370,11 @@ class Array:
"""
Performs the operation __eq__.
"""
+ # Even though "all" dtypes are allowed, we still require them to be
+ # promotable with each other.
+ other = self._check_allowed_dtypes(other, 'all', '__eq__')
+ if other is NotImplemented:
+ return other
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
self, other = self._normalize_two_args(self, other)
@@ -341,8 +395,9 @@ class Array:
"""
Performs the operation __floordiv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__floordiv__(other._array)
return self.__class__._new(res)
@@ -351,8 +406,9 @@ class Array:
"""
Performs the operation __ge__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__ge__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ge__(other._array)
return self.__class__._new(res)
@@ -371,8 +427,9 @@ class Array:
"""
Performs the operation __gt__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__gt__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__gt__(other._array)
return self.__class__._new(res)
@@ -391,6 +448,8 @@ class Array:
"""
Performs the operation __invert__.
"""
+ if self.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError('Only integer or boolean dtypes are allowed in __invert__')
res = self._array.__invert__()
return self.__class__._new(res)
@@ -398,8 +457,9 @@ class Array:
"""
Performs the operation __le__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__le__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__le__(other._array)
return self.__class__._new(res)
@@ -416,8 +476,9 @@ class Array:
"""
Performs the operation __lshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__lshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lshift__(other._array)
return self.__class__._new(res)
@@ -426,8 +487,9 @@ class Array:
"""
Performs the operation __lt__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__lt__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lt__(other._array)
return self.__class__._new(res)
@@ -436,10 +498,11 @@ class Array:
"""
Performs the operation __matmul__.
"""
- if isinstance(other, (int, float, bool)):
- # matmul is not defined for scalars, but without this, we may get
- # the wrong error message from asarray.
- other = self._promote_scalar(other)
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, 'numeric', '__matmul__')
+ if other is NotImplemented:
+ return other
res = self._array.__matmul__(other._array)
return self.__class__._new(res)
@@ -447,8 +510,9 @@ class Array:
"""
Performs the operation __mod__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__mod__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mod__(other._array)
return self.__class__._new(res)
@@ -457,8 +521,9 @@ class Array:
"""
Performs the operation __mul__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__mul__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mul__(other._array)
return self.__class__._new(res)
@@ -467,6 +532,9 @@ class Array:
"""
Performs the operation __ne__.
"""
+ other = self._check_allowed_dtypes(other, 'all', '__ne__')
+ if other is NotImplemented:
+ return other
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
self, other = self._normalize_two_args(self, other)
@@ -477,6 +545,8 @@ class Array:
"""
Performs the operation __neg__.
"""
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in __neg__')
res = self._array.__neg__()
return self.__class__._new(res)
@@ -484,8 +554,9 @@ class Array:
"""
Performs the operation __or__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__or__(other._array)
return self.__class__._new(res)
@@ -494,6 +565,8 @@ class Array:
"""
Performs the operation __pos__.
"""
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in __pos__')
res = self._array.__pos__()
return self.__class__._new(res)
@@ -505,10 +578,9 @@ class Array:
"""
from ._elementwise_functions import pow
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __pow__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__pow__')
+ if other is NotImplemented:
+ return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
# arrays, so we use pow() here instead.
return pow(self, other)
@@ -517,8 +589,9 @@ class Array:
"""
Performs the operation __rshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__rshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rshift__(other._array)
return self.__class__._new(res)
@@ -537,8 +610,9 @@ class Array:
"""
Performs the operation __sub__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__sub__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__sub__(other._array)
return self.__class__._new(res)
@@ -549,10 +623,9 @@ class Array:
"""
Performs the operation __truediv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __truediv__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__truediv__(other._array)
return self.__class__._new(res)
@@ -561,8 +634,9 @@ class Array:
"""
Performs the operation __xor__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__xor__(other._array)
return self.__class__._new(res)
@@ -571,8 +645,9 @@ class Array:
"""
Performs the operation __iadd__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__iadd__')
+ if other is NotImplemented:
+ return other
self._array.__iadd__(other._array)
return self
@@ -580,8 +655,9 @@ class Array:
"""
Performs the operation __radd__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__radd__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__radd__(other._array)
return self.__class__._new(res)
@@ -590,8 +666,9 @@ class Array:
"""
Performs the operation __iand__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__')
+ if other is NotImplemented:
+ return other
self._array.__iand__(other._array)
return self
@@ -599,8 +676,9 @@ class Array:
"""
Performs the operation __rand__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rand__(other._array)
return self.__class__._new(res)
@@ -609,8 +687,9 @@ class Array:
"""
Performs the operation __ifloordiv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__')
+ if other is NotImplemented:
+ return other
self._array.__ifloordiv__(other._array)
return self
@@ -618,8 +697,9 @@ class Array:
"""
Performs the operation __rfloordiv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rfloordiv__(other._array)
return self.__class__._new(res)
@@ -628,8 +708,9 @@ class Array:
"""
Performs the operation __ilshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__ilshift__')
+ if other is NotImplemented:
+ return other
self._array.__ilshift__(other._array)
return self
@@ -637,8 +718,9 @@ class Array:
"""
Performs the operation __rlshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__rlshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rlshift__(other._array)
return self.__class__._new(res)
@@ -649,15 +731,17 @@ class Array:
"""
# Note: NumPy does not implement __imatmul__.
- if isinstance(other, (int, float, bool)):
- # matmul is not defined for scalars, but without this, we may get
- # the wrong error message from asarray.
- other = self._promote_scalar(other)
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__')
+ if other is NotImplemented:
+ return other
+
# __imatmul__ can only be allowed when it would not change the shape
# of self.
other_shape = other.shape
if self.shape == () or other_shape == ():
- raise ValueError("@= requires at least one dimension")
+ raise TypeError("@= requires at least one dimension")
if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
raise ValueError("@= cannot change the shape of the input array")
self._array[:] = self._array.__matmul__(other._array)
@@ -667,10 +751,11 @@ class Array:
"""
Performs the operation __rmatmul__.
"""
- if isinstance(other, (int, float, bool)):
- # matmul is not defined for scalars, but without this, we may get
- # the wrong error message from asarray.
- other = self._promote_scalar(other)
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__')
+ if other is NotImplemented:
+ return other
res = self._array.__rmatmul__(other._array)
return self.__class__._new(res)
@@ -678,8 +763,9 @@ class Array:
"""
Performs the operation __imod__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__imod__')
+ if other is NotImplemented:
+ return other
self._array.__imod__(other._array)
return self
@@ -687,8 +773,9 @@ class Array:
"""
Performs the operation __rmod__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rmod__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rmod__(other._array)
return self.__class__._new(res)
@@ -697,8 +784,9 @@ class Array:
"""
Performs the operation __imul__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__imul__')
+ if other is NotImplemented:
+ return other
self._array.__imul__(other._array)
return self
@@ -706,8 +794,9 @@ class Array:
"""
Performs the operation __rmul__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rmul__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rmul__(other._array)
return self.__class__._new(res)
@@ -716,8 +805,9 @@ class Array:
"""
Performs the operation __ior__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__')
+ if other is NotImplemented:
+ return other
self._array.__ior__(other._array)
return self
@@ -725,8 +815,9 @@ class Array:
"""
Performs the operation __ror__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ror__(other._array)
return self.__class__._new(res)
@@ -735,10 +826,9 @@ class Array:
"""
Performs the operation __ipow__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __pow__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__')
+ if other is NotImplemented:
+ return other
self._array.__ipow__(other._array)
return self
@@ -748,10 +838,9 @@ class Array:
"""
from ._elementwise_functions import pow
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __pow__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__')
+ if other is NotImplemented:
+ return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules
# for 0-d arrays, so we use pow() here instead.
return pow(other, self)
@@ -760,8 +849,9 @@ class Array:
"""
Performs the operation __irshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__irshift__')
+ if other is NotImplemented:
+ return other
self._array.__irshift__(other._array)
return self
@@ -769,8 +859,9 @@ class Array:
"""
Performs the operation __rrshift__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer', '__rrshift__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rrshift__(other._array)
return self.__class__._new(res)
@@ -779,8 +870,9 @@ class Array:
"""
Performs the operation __isub__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__isub__')
+ if other is NotImplemented:
+ return other
self._array.__isub__(other._array)
return self
@@ -788,8 +880,9 @@ class Array:
"""
Performs the operation __rsub__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'numeric', '__rsub__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rsub__(other._array)
return self.__class__._new(res)
@@ -798,10 +891,9 @@ class Array:
"""
Performs the operation __itruediv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __truediv__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__')
+ if other is NotImplemented:
+ return other
self._array.__itruediv__(other._array)
return self
@@ -809,10 +901,9 @@ class Array:
"""
Performs the operation __rtruediv__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
- if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in __truediv__')
+ other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rtruediv__(other._array)
return self.__class__._new(res)
@@ -821,8 +912,9 @@ class Array:
"""
Performs the operation __ixor__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__')
+ if other is NotImplemented:
+ return other
self._array.__ixor__(other._array)
return self
@@ -830,8 +922,9 @@ class Array:
"""
Performs the operation __rxor__.
"""
- if isinstance(other, (int, float, bool)):
- other = self._promote_scalar(other)
+ other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__')
+ if other is NotImplemented:
+ return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rxor__(other._array)
return self.__class__._new(res)