summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-02-10 20:06:35 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2021-02-10 23:12:22 +0100
commit4849b002b7d1a862825f7ddcda1841031c60b665 (patch)
tree78c88bfb230d1cd6b88be9a4be2c5b4b9e920035 /numpy
parent5f14d6b77f04bd60780f6040c3382e7ad1379bb3 (diff)
downloadnumpy-4849b002b7d1a862825f7ddcda1841031c60b665.tar.gz
ENH: Add annotations for `np.core.einsumfunc`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi7
-rw-r--r--numpy/core/einsumfunc.pyi138
2 files changed, 143 insertions, 2 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 0e9deef61..1c52c7285 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -281,6 +281,11 @@ from numpy.core.arrayprint import (
printoptions as printoptions,
)
+from numpy.core.einsumfunc import (
+ einsum as einsum,
+ einsum_path as einsum_path,
+)
+
from numpy.core.numeric import (
zeros_like as zeros_like,
ones as ones,
@@ -401,8 +406,6 @@ dot: Any
dsplit: Any
dstack: Any
ediff1d: Any
-einsum: Any
-einsum_path: Any
expand_dims: Any
extract: Any
eye: Any
diff --git a/numpy/core/einsumfunc.pyi b/numpy/core/einsumfunc.pyi
new file mode 100644
index 000000000..b33aff29f
--- /dev/null
+++ b/numpy/core/einsumfunc.pyi
@@ -0,0 +1,138 @@
+import sys
+from typing import List, TypeVar, Optional, Any, overload, Union, Tuple, Sequence
+
+from numpy import (
+ ndarray,
+ dtype,
+ bool_,
+ unsignedinteger,
+ signedinteger,
+ floating,
+ complexfloating,
+ number,
+ _OrderKACF,
+)
+from numpy.typing import (
+ _ArrayOrScalar,
+ _ArrayLikeBool_co,
+ _ArrayLikeUInt_co,
+ _ArrayLikeInt_co,
+ _ArrayLikeFloat_co,
+ _ArrayLikeComplex_co,
+ _DTypeLikeBool,
+ _DTypeLikeUInt,
+ _DTypeLikeInt,
+ _DTypeLikeFloat,
+ _DTypeLikeComplex,
+ _DTypeLikeComplex_co,
+)
+
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
+_ArrayType = TypeVar(
+ "_ArrayType",
+ bound=ndarray[Any, dtype[Union[bool_, number[Any]]]],
+)
+
+_OptimizeKind = Union[
+ None, bool, Literal["greedy", "optimal"], Sequence[Any]
+]
+_CastingSafe = Literal["no", "equiv", "safe", "same_kind"]
+_CastingUnsafe = Literal["unsafe"]
+
+__all__: List[str]
+
+# TODO: Properly handle the `casting`-based combinatorics
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeBool_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeBool] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[bool_]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeUInt_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeUInt] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[unsignedinteger[Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeInt_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeInt] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[signedinteger[Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeFloat_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeFloat] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[floating[Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeComplex_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeComplex] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[complexfloating[Any, Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: Any,
+ casting: _CastingUnsafe,
+ dtype: Optional[_DTypeLikeComplex_co] = ...,
+ out: None = ...,
+ order: _OrderKACF = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[Any]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeComplex_co,
+ out: _ArrayType,
+ dtype: Optional[_DTypeLikeComplex_co] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayType: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: Any,
+ out: _ArrayType,
+ casting: _CastingUnsafe,
+ dtype: Optional[_DTypeLikeComplex_co] = ...,
+ order: _OrderKACF = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayType: ...
+
+# NOTE: `einsum_call` is a hidden kwarg unavailable for public use.
+# It is therefore excluded from the signatures below.
+# NOTE: In practice the list consists of a `str` (first element)
+# and a variable number of integer tuples.
+def einsum_path(
+ __subscripts: str,
+ *operands: _ArrayLikeComplex_co,
+ optimize: _OptimizeKind = ...,
+) -> Tuple[List[Any], str]: ...