diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-02-10 20:06:35 +0100 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-02-10 23:12:22 +0100 |
commit | 4849b002b7d1a862825f7ddcda1841031c60b665 (patch) | |
tree | 78c88bfb230d1cd6b88be9a4be2c5b4b9e920035 /numpy | |
parent | 5f14d6b77f04bd60780f6040c3382e7ad1379bb3 (diff) | |
download | numpy-4849b002b7d1a862825f7ddcda1841031c60b665.tar.gz |
ENH: Add annotations for `np.core.einsumfunc`
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/__init__.pyi | 7 | ||||
-rw-r--r-- | numpy/core/einsumfunc.pyi | 138 |
2 files changed, 143 insertions, 2 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 0e9deef61..1c52c7285 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -281,6 +281,11 @@ from numpy.core.arrayprint import ( printoptions as printoptions, ) +from numpy.core.einsumfunc import ( + einsum as einsum, + einsum_path as einsum_path, +) + from numpy.core.numeric import ( zeros_like as zeros_like, ones as ones, @@ -401,8 +406,6 @@ dot: Any dsplit: Any dstack: Any ediff1d: Any -einsum: Any -einsum_path: Any expand_dims: Any extract: Any eye: Any diff --git a/numpy/core/einsumfunc.pyi b/numpy/core/einsumfunc.pyi new file mode 100644 index 000000000..b33aff29f --- /dev/null +++ b/numpy/core/einsumfunc.pyi @@ -0,0 +1,138 @@ +import sys +from typing import List, TypeVar, Optional, Any, overload, Union, Tuple, Sequence + +from numpy import ( + ndarray, + dtype, + bool_, + unsignedinteger, + signedinteger, + floating, + complexfloating, + number, + _OrderKACF, +) +from numpy.typing import ( + _ArrayOrScalar, + _ArrayLikeBool_co, + _ArrayLikeUInt_co, + _ArrayLikeInt_co, + _ArrayLikeFloat_co, + _ArrayLikeComplex_co, + _DTypeLikeBool, + _DTypeLikeUInt, + _DTypeLikeInt, + _DTypeLikeFloat, + _DTypeLikeComplex, + _DTypeLikeComplex_co, +) + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +_ArrayType = TypeVar( + "_ArrayType", + bound=ndarray[Any, dtype[Union[bool_, number[Any]]]], +) + +_OptimizeKind = Union[ + None, bool, Literal["greedy", "optimal"], Sequence[Any] +] +_CastingSafe = Literal["no", "equiv", "safe", "same_kind"] +_CastingUnsafe = Literal["unsafe"] + +__all__: List[str] + +# TODO: Properly handle the `casting`-based combinatorics +@overload +def einsum( + __subscripts: str, + *operands: _ArrayLikeBool_co, + out: None = ..., + dtype: Optional[_DTypeLikeBool] = ..., + order: _OrderKACF = ..., + casting: _CastingSafe = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayOrScalar[bool_]: ... +@overload +def einsum( + __subscripts: str, + *operands: _ArrayLikeUInt_co, + out: None = ..., + dtype: Optional[_DTypeLikeUInt] = ..., + order: _OrderKACF = ..., + casting: _CastingSafe = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayOrScalar[unsignedinteger[Any]]: ... +@overload +def einsum( + __subscripts: str, + *operands: _ArrayLikeInt_co, + out: None = ..., + dtype: Optional[_DTypeLikeInt] = ..., + order: _OrderKACF = ..., + casting: _CastingSafe = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayOrScalar[signedinteger[Any]]: ... +@overload +def einsum( + __subscripts: str, + *operands: _ArrayLikeFloat_co, + out: None = ..., + dtype: Optional[_DTypeLikeFloat] = ..., + order: _OrderKACF = ..., + casting: _CastingSafe = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayOrScalar[floating[Any]]: ... +@overload +def einsum( + __subscripts: str, + *operands: _ArrayLikeComplex_co, + out: None = ..., + dtype: Optional[_DTypeLikeComplex] = ..., + order: _OrderKACF = ..., + casting: _CastingSafe = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayOrScalar[complexfloating[Any, Any]]: ... +@overload +def einsum( + __subscripts: str, + *operands: Any, + casting: _CastingUnsafe, + dtype: Optional[_DTypeLikeComplex_co] = ..., + out: None = ..., + order: _OrderKACF = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayOrScalar[Any]: ... +@overload +def einsum( + __subscripts: str, + *operands: _ArrayLikeComplex_co, + out: _ArrayType, + dtype: Optional[_DTypeLikeComplex_co] = ..., + order: _OrderKACF = ..., + casting: _CastingSafe = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayType: ... +@overload +def einsum( + __subscripts: str, + *operands: Any, + out: _ArrayType, + casting: _CastingUnsafe, + dtype: Optional[_DTypeLikeComplex_co] = ..., + order: _OrderKACF = ..., + optimize: _OptimizeKind = ..., +) -> _ArrayType: ... + +# NOTE: `einsum_call` is a hidden kwarg unavailable for public use. +# It is therefore excluded from the signatures below. +# NOTE: In practice the list consists of a `str` (first element) +# and a variable number of integer tuples. +def einsum_path( + __subscripts: str, + *operands: _ArrayLikeComplex_co, + optimize: _OptimizeKind = ..., +) -> Tuple[List[Any], str]: ... |