summaryrefslogtreecommitdiff
path: root/numpy/typing/_callable.py
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2020-10-02 23:25:14 +0200
committerGitHub <noreply@github.com>2020-10-02 15:25:14 -0600
commit47a918f644da4ed3947ee2f797c790027041f29b (patch)
treec63ba285cf88307d1ff521ae058a34c35e77836a /numpy/typing/_callable.py
parentf4a4ddd56e1c3107a8b376d2e72cd5334676627a (diff)
downloadnumpy-47a918f644da4ed3947ee2f797c790027041f29b.tar.gz
ENH: Annotate the arithmetic operations of `ndarray` and `generic` (#17273)
* ENH: Added annotations for arithmetic-based magic methods * TST: Added arithmetic tests * TST: Moved a number of tests to `arithmetic.py` * ENH: Ensure that objects annotated as `number` support arithmetic operations * MAINT: Arithmetic operations on 0d arrays return scalars * MAINT: Clarify the type of generics returned by `ufunc.__call__` * TST: Added more arithmetic tests * MAINT: Use `_CharLike` when both `str` and `bytes` are accepted * MAINT: Change the `timedelta64` baseclass to `generic` * MAINT: Add aliases for common scalar unions * MAINT: Update the defition of `_NumberLike` * MAINT: Replace `_NumberLike` with `_ComplexLike` in the `complexfloating` annotations * MAINT: Move the callback protocols to their own module * MAINT: Make `typing._callback` available at runtime * DOC: Provide further clarification about callback protocols * MAINT: Replace `_callback` with `_callable` Addresses https://github.com/numpy/numpy/pull/17273#discussion_r485821346 The use of `__call__`-defining protocols is not limited to callbacks. The module name name & docstring now reflects this. * MAINT: Removed `__add__` from `str_` and `bytes_` Most `np.bytes_` / `np.str_` methods return their builtin `bytes` / `str` counterpart. This includes addition. * MAINT: Fix the return type of boolean division Addresses https://github.com/numpy/numpy/pull/17273#discussion_r486271220 Dividing a `np.bool_` by an integer (or vice versa) always returns `float64` * MAINT: Renamed all `_<x>Arithmetic` protocols to `_<x>Op Addresses https://github.com/numpy/numpy/pull/17273#discussion_r486272745 * TST: Add tests for boolean division * ENH: Make `np.number` generic w.r.t. its precision * ENH,WIP: Add a mypy plugin for casting `np.number` instances to appropiate subclasses * Revert "ENH,WIP: Add a mypy plugin for casting `np.number` instances to appropiate subclasses" This reverts commit c526fb619d20902bfd77709c8983c7a7d5477c95. * Revert "ENH: Make `np.number` generic w.r.t. its precision" This reverts commit dbf20183cf7ff71e379cd1a165d07e1a1d643135. * MAINT: Narow the definition of `_ComplexLike` Addresses https://github.com/numpy/numpy/pull/17273#discussion_r490440238 * MAINT: Refined the return type of `unint + int` ops `unsignedinteger + signedinteger` generally returns a `signedinteger` subclass. The exception to this is `uint64 + signedinteger`, which returns `float64`. Addresses https://github.com/numpy/numpy/pull/17273#discussion_r490442023 * MAINT: Use `_IntLike` and `_FloatLike` in the definition of `_ComplexLike`
Diffstat (limited to 'numpy/typing/_callable.py')
-rw-r--r--numpy/typing/_callable.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py
new file mode 100644
index 000000000..5e14b708f
--- /dev/null
+++ b/numpy/typing/_callable.py
@@ -0,0 +1,136 @@
+"""
+A module with various ``typing.Protocol`` subclasses that implement
+the ``__call__`` magic method.
+
+See the `Mypy documentation`_ on protocols for more details.
+
+.. _`Mypy documentation`: https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols
+
+"""
+
+import sys
+from typing import Union, TypeVar, overload, Any
+
+from numpy import (
+ _BoolLike,
+ _IntLike,
+ _FloatLike,
+ _ComplexLike,
+ _NumberLike,
+ generic,
+ bool_,
+ timedelta64,
+ number,
+ integer,
+ unsignedinteger,
+ signedinteger,
+ int32,
+ int64,
+ floating,
+ float32,
+ float64,
+ complexfloating,
+ complex128,
+)
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+ HAVE_PROTOCOL = True
+else:
+ try:
+ from typing_extensions import Protocol
+ except ImportError:
+ HAVE_PROTOCOL = False
+ else:
+ HAVE_PROTOCOL = True
+
+if HAVE_PROTOCOL:
+ _NumberType = TypeVar("_NumberType", bound=number)
+ _NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number)
+ _GenericType_co = TypeVar("_GenericType_co", covariant=True, bound=generic)
+
+ class _BoolOp(Protocol[_GenericType_co]):
+ @overload
+ def __call__(self, __other: _BoolLike) -> _GenericType_co: ...
+ @overload # platform dependent
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
+ @overload
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(self, __other: _NumberType) -> _NumberType: ...
+
+ class _BoolSub(Protocol):
+ # Note that `__other: bool_` is absent here
+ @overload # platform dependent
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
+ @overload
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(self, __other: _NumberType) -> _NumberType: ...
+
+ class _BoolTrueDiv(Protocol):
+ @overload
+ def __call__(self, __other: Union[float, _IntLike, _BoolLike]) -> float64: ...
+ @overload
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(self, __other: _NumberType) -> _NumberType: ...
+
+ class _TD64Div(Protocol[_NumberType_co]):
+ @overload
+ def __call__(self, __other: timedelta64) -> _NumberType_co: ...
+ @overload
+ def __call__(self, __other: _FloatLike) -> timedelta64: ...
+
+ class _IntTrueDiv(Protocol):
+ @overload
+ def __call__(self, __other: Union[_IntLike, float]) -> floating: ...
+ @overload
+ def __call__(self, __other: complex) -> complexfloating[floating]: ...
+
+ class _UnsignedIntOp(Protocol):
+ # NOTE: `uint64 + signedinteger -> float64`
+ @overload
+ def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ @overload
+ def __call__(self, __other: Union[int, signedinteger]) -> Union[signedinteger, float64]: ...
+ @overload
+ def __call__(self, __other: float) -> floating: ...
+ @overload
+ def __call__(self, __other: complex) -> complexfloating[floating]: ...
+
+ class _SignedIntOp(Protocol):
+ @overload
+ def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ @overload
+ def __call__(self, __other: float) -> floating: ...
+ @overload
+ def __call__(self, __other: complex) -> complexfloating[floating]: ...
+
+ class _FloatOp(Protocol):
+ @overload
+ def __call__(self, __other: _FloatLike) -> floating: ...
+ @overload
+ def __call__(self, __other: complex) -> complexfloating[floating]: ...
+
+ class _ComplexOp(Protocol):
+ def __call__(self, __other: _ComplexLike) -> complexfloating[floating]: ...
+
+ class _NumberOp(Protocol):
+ def __call__(self, __other: _NumberLike) -> number: ...
+
+else:
+ _BoolOp = Any
+ _BoolSub = Any
+ _BoolTrueDiv = Any
+ _TD64Div = Any
+ _IntTrueDiv = Any
+ _UnsignedIntOp = Any
+ _SignedIntOp = Any
+ _FloatOp = Any
+ _ComplexOp = Any
+ _NumberOp = Any