diff options
Diffstat (limited to 'numpy/_array_api')
-rw-r--r-- | numpy/_array_api/_array_object.py | 9 | ||||
-rw-r--r-- | numpy/_array_api/_dtypes.py | 7 |
2 files changed, 15 insertions, 1 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 2d999e2f3..505c27839 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -119,12 +119,19 @@ class Array: else: return NotImplemented + # This will raise TypeError for type combinations that are not allowed + # to promote in the spec (even if the NumPy array operator would + # promote them). res_dtype = _result_type(self.dtype, other.dtype) if op.startswith('__i'): - # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example, + # Note: NumPy will allow in-place operators in some cases where + # the type promoted operator does not match the left-hand side + # operand. For example, # >>> a = np.array(1, dtype=np.int8) # >>> a += np.array(1, dtype=np.int16) + + # The spec explicitly disallows this. if res_dtype != self.dtype: raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") diff --git a/numpy/_array_api/_dtypes.py b/numpy/_array_api/_dtypes.py index 9abe4cc83..fcdb562da 100644 --- a/numpy/_array_api/_dtypes.py +++ b/numpy/_array_api/_dtypes.py @@ -23,6 +23,13 @@ _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) _integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64) _numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) +# Note: the spec defines a restricted type promotion table compared to NumPy. +# In particular, cross-kind promotions like integer + float or boolean + +# integer are not allowed, even for functions that accept both kinds. +# Additionally, NumPy promotes signed integer + uint64 to float64, but this +# promotion is not allowed here. To be clear, Python scalar int objects are +# allowed to promote to floating-point dtypes, but only in array operators +# (see Array._promote_scalar) method in _array_object.py. _promotion_table = { (int8, int8): int8, (int8, int16): int16, |