summaryrefslogtreecommitdiff
path: root/numpy/typing/_array_like.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/typing/_array_like.py')
-rw-r--r--numpy/typing/_array_like.py133
1 files changed, 110 insertions, 23 deletions
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py
index a1a604239..3bdbed8f8 100644
--- a/numpy/typing/_array_like.py
+++ b/numpy/typing/_array_like.py
@@ -1,29 +1,67 @@
+from __future__ import annotations
+
import sys
-from typing import Any, overload, Sequence, TYPE_CHECKING, Union
+from typing import Any, Sequence, TYPE_CHECKING, Union, TypeVar, Generic
+
+from numpy import (
+ ndarray,
+ dtype,
+ generic,
+ bool_,
+ unsignedinteger,
+ integer,
+ floating,
+ complexfloating,
+ number,
+ timedelta64,
+ datetime64,
+ object_,
+ void,
+ str_,
+ bytes_,
+)
-from numpy import ndarray
-from ._scalars import _ScalarLike
+from . import _HAS_TYPING_EXTENSIONS
from ._dtype_like import DTypeLike
if sys.version_info >= (3, 8):
from typing import Protocol
- HAVE_PROTOCOL = True
-else:
- try:
- from typing_extensions import Protocol
- except ImportError:
- HAVE_PROTOCOL = False
- else:
- HAVE_PROTOCOL = True
-
-if TYPE_CHECKING or HAVE_PROTOCOL:
- class _SupportsArray(Protocol):
- @overload
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
- @overload
- def __array__(self, dtype: DTypeLike = ...) -> ndarray: ...
+elif _HAS_TYPING_EXTENSIONS:
+ from typing_extensions import Protocol
+
+_T = TypeVar("_T")
+_ScalarType = TypeVar("_ScalarType", bound=generic)
+_DType = TypeVar("_DType", bound="dtype[Any]")
+_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")
+
+if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS or sys.version_info >= (3, 8):
+ # The `_SupportsArray` protocol only cares about the default dtype
+ # (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
+ # array.
+ # Concrete implementations of the protocol are responsible for adding
+ # any and all remaining overloads
+ class _SupportsArray(Protocol[_DType_co]):
+ def __array__(self) -> ndarray[Any, _DType_co]: ...
else:
- _SupportsArray = Any
+ class _SupportsArray(Generic[_DType_co]): ...
+
+# TODO: Wait for support for recursive types
+_NestedSequence = Union[
+ _T,
+ Sequence[_T],
+ Sequence[Sequence[_T]],
+ Sequence[Sequence[Sequence[_T]]],
+ Sequence[Sequence[Sequence[Sequence[_T]]]],
+]
+_RecursiveSequence = Sequence[Sequence[Sequence[Sequence[Sequence[Any]]]]]
+
+# A union representing array-like objects; consists of two typevars:
+# One representing types that can be parametrized w.r.t. `np.dtype`
+# and another one for the rest
+_ArrayLike = Union[
+ _NestedSequence[_SupportsArray[_DType]],
+ _NestedSequence[_T],
+]
# TODO: support buffer protocols once
#
@@ -33,8 +71,57 @@ else:
#
# https://github.com/python/typing/issues/593
ArrayLike = Union[
- _ScalarLike,
- Sequence[_ScalarLike],
- Sequence[Sequence[Any]], # TODO: Wait for support for recursive types
- _SupportsArray,
+ _RecursiveSequence,
+ _ArrayLike[
+ "dtype[Any]",
+ Union[bool, int, float, complex, str, bytes]
+ ],
+]
+
+# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
+# given the casting rules `same_kind`
+_ArrayLikeBool_co = _ArrayLike[
+ "dtype[bool_]",
+ bool,
+]
+_ArrayLikeUInt_co = _ArrayLike[
+ "dtype[Union[bool_, unsignedinteger[Any]]]",
+ bool,
+]
+_ArrayLikeInt_co = _ArrayLike[
+ "dtype[Union[bool_, integer[Any]]]",
+ Union[bool, int],
+]
+_ArrayLikeFloat_co = _ArrayLike[
+ "dtype[Union[bool_, integer[Any], floating[Any]]]",
+ Union[bool, int, float],
+]
+_ArrayLikeComplex_co = _ArrayLike[
+ "dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
+ Union[bool, int, float, complex],
+]
+_ArrayLikeNumber_co = _ArrayLike[
+ "dtype[Union[bool_, number[Any]]]",
+ Union[bool, int, float, complex],
+]
+_ArrayLikeTD64_co = _ArrayLike[
+ "dtype[Union[bool_, integer[Any], timedelta64]]",
+ Union[bool, int],
+]
+_ArrayLikeDT64_co = _NestedSequence[_SupportsArray["dtype[datetime64]"]]
+_ArrayLikeObject_co = _NestedSequence[_SupportsArray["dtype[object_]"]]
+
+_ArrayLikeVoid_co = _NestedSequence[_SupportsArray["dtype[void]"]]
+_ArrayLikeStr_co = _ArrayLike[
+ "dtype[str_]",
+ str,
+]
+_ArrayLikeBytes_co = _ArrayLike[
+ "dtype[bytes_]",
+ bytes,
+]
+
+_ArrayLikeInt = _ArrayLike[
+ "dtype[integer[Any]]",
+ int,
]