summaryrefslogtreecommitdiff
path: root/numpy/typing/_callable.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/typing/_callable.py')
-rw-r--r--numpy/typing/_callable.py116
1 files changed, 86 insertions, 30 deletions
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py
index 7c2ee86cb..b16891b0e 100644
--- a/numpy/typing/_callable.py
+++ b/numpy/typing/_callable.py
@@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details.
"""
import sys
-from typing import Union, TypeVar, overload, Any
+from typing import Union, TypeVar, overload, Any, TYPE_CHECKING, NoReturn
from numpy import (
generic,
@@ -25,6 +25,7 @@ from numpy import (
float32,
float64,
complexfloating,
+ complex64,
complex128,
)
from ._scalars import (
@@ -34,6 +35,7 @@ from ._scalars import (
_ComplexLike,
_NumberLike,
)
+from . import NBitBase
if sys.version_info >= (3, 8):
from typing import Protocol
@@ -46,7 +48,9 @@ else:
else:
HAVE_PROTOCOL = True
-if HAVE_PROTOCOL:
+if TYPE_CHECKING or HAVE_PROTOCOL:
+ _NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase)
+ _NBit = TypeVar("_NBit", bound=NBitBase)
_IntType = TypeVar("_IntType", bound=integer)
_NumberType = TypeVar("_NumberType", bound=number)
_NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number)
@@ -56,7 +60,7 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: _BoolLike) -> _GenericType_co: ...
@overload # platform dependent
- def __call__(self, __other: int) -> Union[int32, int64]: ...
+ def __call__(self, __other: int) -> signedinteger[Any]: ...
@overload
def __call__(self, __other: float) -> float64: ...
@overload
@@ -68,14 +72,16 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: _BoolLike) -> _GenericType_co: ...
@overload # platform dependent
- def __call__(self, __other: int) -> Union[int32, int64]: ...
+ def __call__(self, __other: int) -> signedinteger[Any]: ...
@overload
def __call__(self, __other: _IntType) -> _IntType: ...
class _BoolSub(Protocol):
# Note that `__other: bool_` is absent here
+ @overload
+ def __call__(self, __other: bool) -> NoReturn: ...
@overload # platform dependent
- def __call__(self, __other: int) -> Union[int32, int64]: ...
+ def __call__(self, __other: int) -> signedinteger[Any]: ...
@overload
def __call__(self, __other: float) -> float64: ...
@overload
@@ -97,51 +103,101 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: _FloatLike) -> timedelta64: ...
- class _IntTrueDiv(Protocol):
+ class _IntTrueDiv(Protocol[_NBit_co]):
+ @overload
+ def __call__(self, __other: bool) -> floating[_NBit_co]: ...
@overload
- def __call__(self, __other: Union[_IntLike, float]) -> floating: ...
+ def __call__(self, __other: int) -> floating[Any]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(self, __other: integer[_NBit]) -> floating[Union[_NBit_co, _NBit]]: ...
- class _UnsignedIntOp(Protocol):
+ class _UnsignedIntOp(Protocol[_NBit_co]):
# NOTE: `uint64 + signedinteger -> float64`
@overload
- def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ...
+ @overload
+ def __call__(
+ self, __other: Union[int, signedinteger[Any]]
+ ) -> Union[signedinteger[Any], float64]: ...
@overload
- def __call__(self, __other: Union[int, signedinteger]) -> Union[signedinteger, float64]: ...
+ def __call__(self, __other: float) -> float64: ...
@overload
- def __call__(self, __other: float) -> floating: ...
+ def __call__(self, __other: complex) -> complex128: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(
+ self, __other: unsignedinteger[_NBit]
+ ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ...
- class _UnsignedIntBitOp(Protocol):
- # TODO: The likes of `uint64 | np.signedinteger` will fail as there
- # is no signed integer type large enough to hold a `uint64`
- # See https://github.com/numpy/numpy/issues/2524
+ class _UnsignedIntBitOp(Protocol[_NBit_co]):
+ @overload
+ def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> signedinteger[Any]: ...
@overload
- def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ def __call__(self, __other: signedinteger[Any]) -> signedinteger[Any]: ...
@overload
- def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ def __call__(
+ self, __other: unsignedinteger[_NBit]
+ ) -> unsignedinteger[Union[_NBit_co, _NBit]]: ...
- class _SignedIntOp(Protocol):
+ class _SignedIntOp(Protocol[_NBit_co]):
@overload
- def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ...
@overload
- def __call__(self, __other: float) -> floating: ...
+ def __call__(self, __other: int) -> signedinteger[Any]: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(self, __other: float) -> float64: ...
+ @overload
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(
+ self, __other: signedinteger[_NBit]
+ ) -> signedinteger[Union[_NBit_co, _NBit]]: ...
- class _SignedIntBitOp(Protocol):
- def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ class _SignedIntBitOp(Protocol[_NBit_co]):
+ @overload
+ def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> signedinteger[Any]: ...
+ @overload
+ def __call__(
+ self, __other: signedinteger[_NBit]
+ ) -> signedinteger[Union[_NBit_co, _NBit]]: ...
- class _FloatOp(Protocol):
+ class _FloatOp(Protocol[_NBit_co]):
@overload
- def __call__(self, __other: _FloatLike) -> floating: ...
+ def __call__(self, __other: bool) -> floating[_NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> floating[Any]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(
+ self, __other: Union[integer[_NBit], floating[_NBit]]
+ ) -> floating[Union[_NBit_co, _NBit]]: ...
- class _ComplexOp(Protocol):
- def __call__(self, __other: _ComplexLike) -> complexfloating[floating]: ...
+ class _ComplexOp(Protocol[_NBit_co]):
+ @overload
+ def __call__(self, __other: bool) -> complexfloating[_NBit_co, _NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> complexfloating[Any, Any]: ...
+ @overload
+ def __call__(self, __other: Union[float, complex]) -> complex128: ...
+ @overload
+ def __call__(
+ self,
+ __other: Union[
+ integer[_NBit],
+ floating[_NBit],
+ complexfloating[_NBit, _NBit],
+ ]
+ ) -> complexfloating[Union[_NBit_co, _NBit], Union[_NBit_co, _NBit]]: ...
class _NumberOp(Protocol):
def __call__(self, __other: _NumberLike) -> number: ...