diff options
Diffstat (limited to 'numpy/typing/_callable.py')
-rw-r--r-- | numpy/typing/_callable.py | 249 |
1 files changed, 130 insertions, 119 deletions
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py index c703df28a..8f911da3b 100644 --- a/numpy/typing/_callable.py +++ b/numpy/typing/_callable.py @@ -8,6 +8,8 @@ See the `Mypy documentation`_ on protocols for more details. """ +from __future__ import annotations + import sys from typing import ( Union, @@ -21,6 +23,7 @@ from typing import ( from numpy import ( ndarray, + dtype, generic, bool_, timedelta64, @@ -29,38 +32,34 @@ from numpy import ( unsignedinteger, signedinteger, int8, + int_, floating, float64, complexfloating, complex128, ) +from ._nbit import _NBitInt, _NBitDouble from ._scalars import ( - _BoolLike, - _IntLike, - _FloatLike, - _ComplexLike, - _NumberLike, + _BoolLike_co, + _IntLike_co, + _FloatLike_co, + _NumberLike_co, ) -from . import NBitBase -from ._array_like import ArrayLike +from . import NBitBase, _HAS_TYPING_EXTENSIONS +from ._generic_alias import NDArray if sys.version_info >= (3, 8): from typing import Protocol - HAVE_PROTOCOL = True -else: - try: - from typing_extensions import Protocol - except ImportError: - HAVE_PROTOCOL = False - else: - HAVE_PROTOCOL = True +elif _HAS_TYPING_EXTENSIONS: + from typing_extensions import Protocol -if TYPE_CHECKING or HAVE_PROTOCOL: - _T = TypeVar("_T") - _2Tuple = Tuple[_T, _T] +if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS or sys.version_info >= (3, 8): + _T1 = TypeVar("_T1") + _T2 = TypeVar("_T2") + _2Tuple = Tuple[_T1, _T1] - _NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase) - _NBit = TypeVar("_NBit", bound=NBitBase) + _NBit1 = TypeVar("_NBit1", bound=NBitBase) + _NBit2 = TypeVar("_NBit2", bound=NBitBase) _IntType = TypeVar("_IntType", bound=integer) _FloatType = TypeVar("_FloatType", bound=floating) @@ -70,9 +69,9 @@ if TYPE_CHECKING or HAVE_PROTOCOL: class _BoolOp(Protocol[_GenericType_co]): @overload - def __call__(self, __other: _BoolLike) -> _GenericType_co: ... + def __call__(self, __other: _BoolLike_co) -> _GenericType_co: ... @overload # platform dependent - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> int_: ... @overload def __call__(self, __other: float) -> float64: ... @overload @@ -82,9 +81,9 @@ if TYPE_CHECKING or HAVE_PROTOCOL: class _BoolBitOp(Protocol[_GenericType_co]): @overload - def __call__(self, __other: _BoolLike) -> _GenericType_co: ... + def __call__(self, __other: _BoolLike_co) -> _GenericType_co: ... @overload # platform dependent - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> int_: ... @overload def __call__(self, __other: _IntType) -> _IntType: ... @@ -93,7 +92,7 @@ if TYPE_CHECKING or HAVE_PROTOCOL: @overload def __call__(self, __other: bool) -> NoReturn: ... @overload # platform dependent - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> int_: ... @overload def __call__(self, __other: float) -> float64: ... @overload @@ -103,7 +102,7 @@ if TYPE_CHECKING or HAVE_PROTOCOL: class _BoolTrueDiv(Protocol): @overload - def __call__(self, __other: Union[float, _IntLike, _BoolLike]) -> float64: ... + def __call__(self, __other: float | _IntLike_co) -> float64: ... @overload def __call__(self, __other: complex) -> complex128: ... @overload @@ -111,9 +110,9 @@ if TYPE_CHECKING or HAVE_PROTOCOL: class _BoolMod(Protocol): @overload - def __call__(self, __other: _BoolLike) -> int8: ... + def __call__(self, __other: _BoolLike_co) -> int8: ... @overload # platform dependent - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> int_: ... @overload def __call__(self, __other: float) -> float64: ... @overload @@ -123,11 +122,11 @@ if TYPE_CHECKING or HAVE_PROTOCOL: class _BoolDivMod(Protocol): @overload - def __call__(self, __other: _BoolLike) -> _2Tuple[int8]: ... + def __call__(self, __other: _BoolLike_co) -> _2Tuple[int8]: ... @overload # platform dependent - def __call__(self, __other: int) -> _2Tuple[signedinteger[Any]]: ... + def __call__(self, __other: int) -> _2Tuple[int_]: ... @overload - def __call__(self, __other: float) -> _2Tuple[float64]: ... + def __call__(self, __other: float) -> _2Tuple[floating[_NBit1 | _NBitDouble]]: ... @overload def __call__(self, __other: _IntType) -> _2Tuple[_IntType]: ... @overload @@ -137,188 +136,200 @@ if TYPE_CHECKING or HAVE_PROTOCOL: @overload def __call__(self, __other: timedelta64) -> _NumberType_co: ... @overload - def __call__(self, __other: _FloatLike) -> timedelta64: ... + def __call__(self, __other: _BoolLike_co) -> NoReturn: ... + @overload + def __call__(self, __other: _FloatLike_co) -> timedelta64: ... - class _IntTrueDiv(Protocol[_NBit_co]): + class _IntTrueDiv(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> floating[_NBit_co]: ... + def __call__(self, __other: bool) -> floating[_NBit1]: ... @overload - def __call__(self, __other: int) -> floating[Any]: ... + def __call__(self, __other: int) -> floating[_NBit1 | _NBitInt]: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload - def __call__(self, __other: complex) -> complex128: ... + def __call__( + self, __other: complex + ) -> complexfloating[_NBit1 | _NBitDouble, _NBit1 | _NBitDouble]: ... @overload - def __call__(self, __other: integer[_NBit]) -> floating[Union[_NBit_co, _NBit]]: ... + def __call__(self, __other: integer[_NBit2]) -> floating[_NBit1 | _NBit2]: ... - class _UnsignedIntOp(Protocol[_NBit_co]): + class _UnsignedIntOp(Protocol[_NBit1]): # NOTE: `uint64 + signedinteger -> float64` @overload - def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... + def __call__(self, __other: bool) -> unsignedinteger[_NBit1]: ... @overload def __call__( - self, __other: Union[int, signedinteger[Any]] - ) -> Union[signedinteger[Any], float64]: ... + self, __other: int | signedinteger[Any] + ) -> Any: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload - def __call__(self, __other: complex) -> complex128: ... + def __call__( + self, __other: complex + ) -> complexfloating[_NBit1 | _NBitDouble, _NBit1 | _NBitDouble]: ... @overload def __call__( - self, __other: unsignedinteger[_NBit] - ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ... + self, __other: unsignedinteger[_NBit2] + ) -> unsignedinteger[_NBit1 | _NBit2]: ... - class _UnsignedIntBitOp(Protocol[_NBit_co]): + class _UnsignedIntBitOp(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... + def __call__(self, __other: bool) -> unsignedinteger[_NBit1]: ... @overload def __call__(self, __other: int) -> signedinteger[Any]: ... @overload def __call__(self, __other: signedinteger[Any]) -> signedinteger[Any]: ... @overload def __call__( - self, __other: unsignedinteger[_NBit] - ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ... + self, __other: unsignedinteger[_NBit2] + ) -> unsignedinteger[_NBit1 | _NBit2]: ... - class _UnsignedIntMod(Protocol[_NBit_co]): + class _UnsignedIntMod(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... + def __call__(self, __other: bool) -> unsignedinteger[_NBit1]: ... @overload def __call__( - self, __other: Union[int, signedinteger[Any]] - ) -> Union[signedinteger[Any], float64]: ... + self, __other: int | signedinteger[Any] + ) -> Any: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload def __call__( - self, __other: unsignedinteger[_NBit] - ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ... + self, __other: unsignedinteger[_NBit2] + ) -> unsignedinteger[_NBit1 | _NBit2]: ... - class _UnsignedIntDivMod(Protocol[_NBit_co]): + class _UnsignedIntDivMod(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> _2Tuple[signedinteger[_NBit_co]]: ... + def __call__(self, __other: bool) -> _2Tuple[signedinteger[_NBit1]]: ... @overload def __call__( - self, __other: Union[int, signedinteger[Any]] - ) -> Union[_2Tuple[signedinteger[Any]], _2Tuple[float64]]: ... + self, __other: int | signedinteger[Any] + ) -> _2Tuple[Any]: ... @overload - def __call__(self, __other: float) -> _2Tuple[float64]: ... + def __call__(self, __other: float) -> _2Tuple[floating[_NBit1 | _NBitDouble]]: ... @overload def __call__( - self, __other: unsignedinteger[_NBit] - ) -> _2Tuple[unsignedinteger[Union[_NBit_co, _NBit]]]: ... + self, __other: unsignedinteger[_NBit2] + ) -> _2Tuple[unsignedinteger[_NBit1 | _NBit2]]: ... - class _SignedIntOp(Protocol[_NBit_co]): + class _SignedIntOp(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... + def __call__(self, __other: bool) -> signedinteger[_NBit1]: ... @overload - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> signedinteger[_NBit1 | _NBitInt]: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload - def __call__(self, __other: complex) -> complex128: ... + def __call__( + self, __other: complex + ) -> complexfloating[_NBit1 | _NBitDouble, _NBit1 | _NBitDouble]: ... @overload def __call__( - self, __other: signedinteger[_NBit] - ) -> signedinteger[Union[_NBit_co, _NBit]]: ... + self, __other: signedinteger[_NBit2] + ) -> signedinteger[_NBit1 | _NBit2]: ... - class _SignedIntBitOp(Protocol[_NBit_co]): + class _SignedIntBitOp(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... + def __call__(self, __other: bool) -> signedinteger[_NBit1]: ... @overload - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> signedinteger[_NBit1 | _NBitInt]: ... @overload def __call__( - self, __other: signedinteger[_NBit] - ) -> signedinteger[Union[_NBit_co, _NBit]]: ... + self, __other: signedinteger[_NBit2] + ) -> signedinteger[_NBit1 | _NBit2]: ... - class _SignedIntMod(Protocol[_NBit_co]): + class _SignedIntMod(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... + def __call__(self, __other: bool) -> signedinteger[_NBit1]: ... @overload - def __call__(self, __other: int) -> signedinteger[Any]: ... + def __call__(self, __other: int) -> signedinteger[_NBit1 | _NBitInt]: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload def __call__( - self, __other: signedinteger[_NBit] - ) -> signedinteger[Union[_NBit_co, _NBit]]: ... + self, __other: signedinteger[_NBit2] + ) -> signedinteger[_NBit1 | _NBit2]: ... - class _SignedIntDivMod(Protocol[_NBit_co]): + class _SignedIntDivMod(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> _2Tuple[signedinteger[_NBit_co]]: ... + def __call__(self, __other: bool) -> _2Tuple[signedinteger[_NBit1]]: ... @overload - def __call__(self, __other: int) -> _2Tuple[signedinteger[Any]]: ... + def __call__(self, __other: int) -> _2Tuple[signedinteger[_NBit1 | _NBitInt]]: ... @overload - def __call__(self, __other: float) -> _2Tuple[float64]: ... + def __call__(self, __other: float) -> _2Tuple[floating[_NBit1 | _NBitDouble]]: ... @overload def __call__( - self, __other: signedinteger[_NBit] - ) -> _2Tuple[signedinteger[Union[_NBit_co, _NBit]]]: ... + self, __other: signedinteger[_NBit2] + ) -> _2Tuple[signedinteger[_NBit1 | _NBit2]]: ... - class _FloatOp(Protocol[_NBit_co]): + class _FloatOp(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> floating[_NBit_co]: ... + def __call__(self, __other: bool) -> floating[_NBit1]: ... @overload - def __call__(self, __other: int) -> floating[Any]: ... + def __call__(self, __other: int) -> floating[_NBit1 | _NBitInt]: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload - def __call__(self, __other: complex) -> complex128: ... + def __call__( + self, __other: complex + ) -> complexfloating[_NBit1 | _NBitDouble, _NBit1 | _NBitDouble]: ... @overload def __call__( - self, __other: Union[integer[_NBit], floating[_NBit]] - ) -> floating[Union[_NBit_co, _NBit]]: ... + self, __other: integer[_NBit2] | floating[_NBit2] + ) -> floating[_NBit1 | _NBit2]: ... - class _FloatMod(Protocol[_NBit_co]): + class _FloatMod(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> floating[_NBit_co]: ... + def __call__(self, __other: bool) -> floating[_NBit1]: ... @overload - def __call__(self, __other: int) -> floating[Any]: ... + def __call__(self, __other: int) -> floating[_NBit1 | _NBitInt]: ... @overload - def __call__(self, __other: float) -> float64: ... + def __call__(self, __other: float) -> floating[_NBit1 | _NBitDouble]: ... @overload def __call__( - self, __other: Union[integer[_NBit], floating[_NBit]] - ) -> floating[Union[_NBit_co, _NBit]]: ... + self, __other: integer[_NBit2] | floating[_NBit2] + ) -> floating[_NBit1 | _NBit2]: ... - class _FloatDivMod(Protocol[_NBit_co]): + class _FloatDivMod(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> _2Tuple[floating[_NBit_co]]: ... + def __call__(self, __other: bool) -> _2Tuple[floating[_NBit1]]: ... @overload - def __call__(self, __other: int) -> _2Tuple[floating[Any]]: ... + def __call__(self, __other: int) -> _2Tuple[floating[_NBit1 | _NBitInt]]: ... @overload - def __call__(self, __other: float) -> _2Tuple[float64]: ... + def __call__(self, __other: float) -> _2Tuple[floating[_NBit1 | _NBitDouble]]: ... @overload def __call__( - self, __other: Union[integer[_NBit], floating[_NBit]] - ) -> _2Tuple[floating[Union[_NBit_co, _NBit]]]: ... + self, __other: integer[_NBit2] | floating[_NBit2] + ) -> _2Tuple[floating[_NBit1 | _NBit2]]: ... - class _ComplexOp(Protocol[_NBit_co]): + class _ComplexOp(Protocol[_NBit1]): @overload - def __call__(self, __other: bool) -> complexfloating[_NBit_co, _NBit_co]: ... + def __call__(self, __other: bool) -> complexfloating[_NBit1, _NBit1]: ... @overload - def __call__(self, __other: int) -> complexfloating[Any, Any]: ... + def __call__(self, __other: int) -> complexfloating[_NBit1 | _NBitInt, _NBit1 | _NBitInt]: ... @overload - def __call__(self, __other: Union[float, complex]) -> complex128: ... + def __call__( + self, __other: complex + ) -> complexfloating[_NBit1 | _NBitDouble, _NBit1 | _NBitDouble]: ... @overload def __call__( self, __other: Union[ - integer[_NBit], - floating[_NBit], - complexfloating[_NBit, _NBit], + integer[_NBit2], + floating[_NBit2], + complexfloating[_NBit2, _NBit2], ] - ) -> complexfloating[Union[_NBit_co, _NBit], Union[_NBit_co, _NBit]]: ... + ) -> complexfloating[_NBit1 | _NBit2, _NBit1 | _NBit2]: ... class _NumberOp(Protocol): - def __call__(self, __other: _NumberLike) -> number: ... + def __call__(self, __other: _NumberLike_co) -> Any: ... - class _ComparisonOp(Protocol[_T]): + class _ComparisonOp(Protocol[_T1, _T2]): @overload - def __call__(self, __other: _T) -> bool_: ... + def __call__(self, __other: _T1) -> bool_: ... @overload - def __call__(self, __other: ArrayLike) -> Union[ndarray, bool_]: ... + def __call__(self, __other: _T2) -> NDArray[bool_]: ... else: _BoolOp = Any |