summaryrefslogtreecommitdiff
path: root/numpy/typing/_callable.py
blob: 5e14b708f1a6667ae7fcf51896c53e8df5432396 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
A module with various ``typing.Protocol`` subclasses that implement
the ``__call__`` magic method.

See the `Mypy documentation`_ on protocols for more details.

.. _`Mypy documentation`: https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols

"""

import sys
from typing import Union, TypeVar, overload, Any

from numpy import (
    _BoolLike,
    _IntLike,
    _FloatLike,
    _ComplexLike,
    _NumberLike,
    generic,
    bool_,
    timedelta64,
    number,
    integer,
    unsignedinteger,
    signedinteger,
    int32,
    int64,
    floating,
    float32,
    float64,
    complexfloating,
    complex128,
)

# ``typing.Protocol`` is only used from the standard library on Python >= 3.8.
# Older interpreters fall back to the ``typing_extensions`` backport, which may
# not be installed; ``HAVE_PROTOCOL`` records whether a ``Protocol`` class was
# found so the definitions below can degrade gracefully.
if sys.version_info < (3, 8):
    try:
        from typing_extensions import Protocol
        HAVE_PROTOCOL = True
    except ImportError:
        HAVE_PROTOCOL = False
else:
    from typing import Protocol
    HAVE_PROTOCOL = True

if HAVE_PROTOCOL:
    # `_NumberType` is a method-level type variable: an overload written as
    # `(__other: _NumberType) -> _NumberType` returns the operand's own
    # `number` subtype.
    _NumberType = TypeVar("_NumberType", bound=number)
    # The two covariant type variables below are the type parameters of the
    # generic protocols `_TD64Div` and `_BoolOp`, respectively.
    _NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number)
    _GenericType_co = TypeVar("_GenericType_co", covariant=True, bound=generic)

    # NOTE: type checkers try `@overload` variants top to bottom and use the
    # first one that matches, so the order of the overloads inside each
    # protocol is significant. For example, `bool` is a subclass of `int`,
    # which is why `_UnsignedIntOp` lists `bool` in an overload that precedes
    # the `int` one.

    class _BoolOp(Protocol[_GenericType_co]):
        """Callback protocol for a binary operation on a boolean scalar.

        Presumably annotates binary operators of `bool_` (TODO confirm against
        the stubs that use it). A bool-like operand yields the protocol's type
        parameter; Python `int`/`float`/`complex` operands map to fixed NumPy
        types; any other `number` operand yields its own type.
        """
        @overload
        def __call__(self, __other: _BoolLike) -> _GenericType_co: ...
        @overload  # platform dependent
        def __call__(self, __other: int) -> Union[int32, int64]: ...
        @overload
        def __call__(self, __other: float) -> float64: ...
        @overload
        def __call__(self, __other: complex) -> complex128: ...
        @overload
        def __call__(self, __other: _NumberType) -> _NumberType: ...

    class _BoolSub(Protocol):
        """Callback protocol for a boolean subtraction-like operation.

        Same as the non-bool overloads of `_BoolOp`; no overload accepts
        `bool_`, presumably so that `bool_ - bool_` (which NumPy rejects at
        runtime) is flagged by type checkers.
        """
        # Note that `__other: bool_` is absent here
        # NOTE(review): a Python `bool` is a subtype of `int` and would still
        # match the `int` overload below -- confirm that this is intended.
        @overload  # platform dependent
        def __call__(self, __other: int) -> Union[int32, int64]: ...
        @overload
        def __call__(self, __other: float) -> float64: ...
        @overload
        def __call__(self, __other: complex) -> complex128: ...
        @overload
        def __call__(self, __other: _NumberType) -> _NumberType: ...

    class _BoolTrueDiv(Protocol):
        """Callback protocol for true division of a boolean scalar.

        Real (float-, int- or bool-like) operands yield `float64`, `complex`
        yields `complex128`, and any other `number` operand yields its own type.
        """
        @overload
        def __call__(self, __other: Union[float, _IntLike, _BoolLike]) -> float64: ...
        @overload
        def __call__(self, __other: complex) -> complex128: ...
        @overload
        def __call__(self, __other: _NumberType) -> _NumberType: ...

    class _TD64Div(Protocol[_NumberType_co]):
        """Callback protocol for division involving a `timedelta64`.

        Dividing by another `timedelta64` yields the protocol's `number` type
        parameter; dividing by a float-like operand keeps `timedelta64`.
        """
        @overload
        def __call__(self, __other: timedelta64) -> _NumberType_co: ...
        @overload
        def __call__(self, __other: _FloatLike) -> timedelta64: ...

    class _IntTrueDiv(Protocol):
        """Callback protocol for true division of an integer scalar.

        Int-like or `float` operands yield `floating`; `complex` yields
        `complexfloating[floating]`.
        """
        @overload
        def __call__(self, __other: Union[_IntLike, float]) -> floating: ...
        @overload
        def __call__(self, __other: complex) -> complexfloating[floating]: ...

    class _UnsignedIntOp(Protocol):
        """Callback protocol for a binary operation on an unsigned integer.

        `bool`/`unsignedinteger` operands stay unsigned; `int`/`signedinteger`
        operands may produce a signed integer or `float64` (see NOTE below).
        """
        # NOTE: `uint64 + signedinteger -> float64`
        @overload
        def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
        @overload
        def __call__(self, __other: Union[int, signedinteger]) -> Union[signedinteger, float64]: ...
        @overload
        def __call__(self, __other: float) -> floating: ...
        @overload
        def __call__(self, __other: complex) -> complexfloating[floating]: ...

    class _SignedIntOp(Protocol):
        """Callback protocol for a binary operation on a signed integer.

        `int`/`signedinteger` operands yield `signedinteger`, `float` yields
        `floating`, and `complex` yields `complexfloating[floating]`.
        """
        @overload
        def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
        @overload
        def __call__(self, __other: float) -> floating: ...
        @overload
        def __call__(self, __other: complex) -> complexfloating[floating]: ...

    class _FloatOp(Protocol):
        """Callback protocol for a binary operation on a floating scalar.

        Float-like operands yield `floating`; `complex` yields
        `complexfloating[floating]`.
        """
        @overload
        def __call__(self, __other: _FloatLike) -> floating: ...
        @overload
        def __call__(self, __other: complex) -> complexfloating[floating]: ...

    class _ComplexOp(Protocol):
        """Callback protocol: a complex-like operand yields `complexfloating`."""
        def __call__(self, __other: _ComplexLike) -> complexfloating[floating]: ...

    class _NumberOp(Protocol):
        """Callback protocol: a number-like operand yields a generic `number`."""
        def __call__(self, __other: _NumberLike) -> number: ...

else:
    # No `Protocol` class is available (Python < 3.8 without
    # `typing_extensions`): every callback alias degrades to `Any`.
    _BoolOp = Any
    _BoolSub = Any
    _BoolTrueDiv = Any
    _TD64Div = Any
    _IntTrueDiv = Any
    _UnsignedIntOp = Any
    _SignedIntOp = Any
    _FloatOp = Any
    _ComplexOp = Any
    _NumberOp = Any