diff options
Diffstat (limited to 'numpy/typing/_callable.py')
-rw-r--r-- | numpy/typing/_callable.py | 116 |
1 files changed, 86 insertions, 30 deletions
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py index 7c2ee86cb..b16891b0e 100644 --- a/numpy/typing/_callable.py +++ b/numpy/typing/_callable.py @@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details. """ import sys -from typing import Union, TypeVar, overload, Any +from typing import Union, TypeVar, overload, Any, TYPE_CHECKING, NoReturn from numpy import ( generic, @@ -25,6 +25,7 @@ from numpy import ( float32, float64, complexfloating, + complex64, complex128, ) from ._scalars import ( @@ -34,6 +35,7 @@ from ._scalars import ( _ComplexLike, _NumberLike, ) +from . import NBitBase if sys.version_info >= (3, 8): from typing import Protocol @@ -46,7 +48,9 @@ else: else: HAVE_PROTOCOL = True -if HAVE_PROTOCOL: +if TYPE_CHECKING or HAVE_PROTOCOL: + _NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase) + _NBit = TypeVar("_NBit", bound=NBitBase) _IntType = TypeVar("_IntType", bound=integer) _NumberType = TypeVar("_NumberType", bound=number) _NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number) @@ -56,7 +60,7 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: _BoolLike) -> _GenericType_co: ... @overload # platform dependent - def __call__(self, __other: int) -> Union[int32, int64]: ... + def __call__(self, __other: int) -> signedinteger[Any]: ... @overload def __call__(self, __other: float) -> float64: ... @overload @@ -68,14 +72,16 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: _BoolLike) -> _GenericType_co: ... @overload # platform dependent - def __call__(self, __other: int) -> Union[int32, int64]: ... + def __call__(self, __other: int) -> signedinteger[Any]: ... @overload def __call__(self, __other: _IntType) -> _IntType: ... class _BoolSub(Protocol): # Note that `__other: bool_` is absent here + @overload + def __call__(self, __other: bool) -> NoReturn: ... @overload # platform dependent - def __call__(self, __other: int) -> Union[int32, int64]: ... + def __call__(self, __other: int) -> signedinteger[Any]: ... @overload def __call__(self, __other: float) -> float64: ... @overload @@ -97,51 +103,101 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: _FloatLike) -> timedelta64: ... - class _IntTrueDiv(Protocol): + class _IntTrueDiv(Protocol[_NBit_co]): + @overload + def __call__(self, __other: bool) -> floating[_NBit_co]: ... @overload - def __call__(self, __other: Union[_IntLike, float]) -> floating: ... + def __call__(self, __other: int) -> floating[Any]: ... + @overload + def __call__(self, __other: float) -> float64: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__(self, __other: complex) -> complex128: ... + @overload + def __call__(self, __other: integer[_NBit]) -> floating[Union[_NBit_co, _NBit]]: ... - class _UnsignedIntOp(Protocol): + class _UnsignedIntOp(Protocol[_NBit_co]): # NOTE: `uint64 + signedinteger -> float64` @overload - def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ... + def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... + @overload + def __call__( + self, __other: Union[int, signedinteger[Any]] + ) -> Union[signedinteger[Any], float64]: ... @overload - def __call__(self, __other: Union[int, signedinteger]) -> Union[signedinteger, float64]: ... + def __call__(self, __other: float) -> float64: ... @overload - def __call__(self, __other: float) -> floating: ... + def __call__(self, __other: complex) -> complex128: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__( + self, __other: unsignedinteger[_NBit] + ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ... - class _UnsignedIntBitOp(Protocol): - # TODO: The likes of `uint64 | np.signedinteger` will fail as there - # is no signed integer type large enough to hold a `uint64` - # See https://github.com/numpy/numpy/issues/2524 + class _UnsignedIntBitOp(Protocol[_NBit_co]): + @overload + def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... + @overload + def __call__(self, __other: int) -> signedinteger[Any]: ... @overload - def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ... + def __call__(self, __other: signedinteger[Any]) -> signedinteger[Any]: ... @overload - def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + def __call__( + self, __other: unsignedinteger[_NBit] + ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ... - class _SignedIntOp(Protocol): + class _SignedIntOp(Protocol[_NBit_co]): @overload - def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... @overload - def __call__(self, __other: float) -> floating: ... + def __call__(self, __other: int) -> signedinteger[Any]: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__(self, __other: float) -> float64: ... + @overload + def __call__(self, __other: complex) -> complex128: ... + @overload + def __call__( + self, __other: signedinteger[_NBit] + ) -> signedinteger[Union[_NBit_co, _NBit]]: ... - class _SignedIntBitOp(Protocol): - def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + class _SignedIntBitOp(Protocol[_NBit_co]): + @overload + def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... + @overload + def __call__(self, __other: int) -> signedinteger[Any]: ... + @overload + def __call__( + self, __other: signedinteger[_NBit] + ) -> signedinteger[Union[_NBit_co, _NBit]]: ... - class _FloatOp(Protocol): + class _FloatOp(Protocol[_NBit_co]): @overload - def __call__(self, __other: _FloatLike) -> floating: ... + def __call__(self, __other: bool) -> floating[_NBit_co]: ... + @overload + def __call__(self, __other: int) -> floating[Any]: ... + @overload + def __call__(self, __other: float) -> float64: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__(self, __other: complex) -> complex128: ... + @overload + def __call__( + self, __other: Union[integer[_NBit], floating[_NBit]] + ) -> floating[Union[_NBit_co, _NBit]]: ... - class _ComplexOp(Protocol): - def __call__(self, __other: _ComplexLike) -> complexfloating[floating]: ... + class _ComplexOp(Protocol[_NBit_co]): + @overload + def __call__(self, __other: bool) -> complexfloating[_NBit_co, _NBit_co]: ... + @overload + def __call__(self, __other: int) -> complexfloating[Any, Any]: ... + @overload + def __call__(self, __other: Union[float, complex]) -> complex128: ... + @overload + def __call__( + self, + __other: Union[ + integer[_NBit], + floating[_NBit], + complexfloating[_NBit, _NBit], + ] + ) -> complexfloating[Union[_NBit_co, _NBit], Union[_NBit_co, _NBit]]: ... class _NumberOp(Protocol): def __call__(self, __other: _NumberLike) -> number: ... |