summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2022-07-14 15:29:42 +0300
committerGitHub <noreply@github.com>2022-07-14 15:29:42 +0300
commitb8192d168315f463d96174625fd3fac92dfcc73b (patch)
treeb8d3c363ec748b119822da4416d859f074fda87c /numpy
parent2d7559243a3474b58950336b5159afa085acf473 (diff)
parent79a895d6be4da326a1d7c5a018ed3ba62adaf8e7 (diff)
downloadnumpy-b8192d168315f463d96174625fd3fac92dfcc73b.tar.gz
Merge pull request #21983 from BvB93/einsum
TYP,MAINT: Allow `einsum` subscripts to be passed via integer array-likes
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/einsumfunc.pyi18
-rw-r--r--numpy/typing/tests/data/reveal/einsumfunc.pyi3
2 files changed, 12 insertions, 9 deletions
diff --git a/numpy/core/einsumfunc.pyi b/numpy/core/einsumfunc.pyi
index e614254ca..c811a5783 100644
--- a/numpy/core/einsumfunc.pyi
+++ b/numpy/core/einsumfunc.pyi
@@ -45,7 +45,7 @@ __all__: list[str]
# Something like `is_scalar = bool(__subscripts.partition("->")[-1])`
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeBool_co,
out: None = ...,
@@ -56,7 +56,7 @@ def einsum(
) -> Any: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeUInt_co,
out: None = ...,
@@ -67,7 +67,7 @@ def einsum(
) -> Any: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeInt_co,
out: None = ...,
@@ -78,7 +78,7 @@ def einsum(
) -> Any: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeFloat_co,
out: None = ...,
@@ -89,7 +89,7 @@ def einsum(
) -> Any: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeComplex_co,
out: None = ...,
@@ -100,7 +100,7 @@ def einsum(
) -> Any: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: Any,
casting: _CastingUnsafe,
@@ -111,7 +111,7 @@ def einsum(
) -> Any: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeComplex_co,
out: _ArrayType,
@@ -122,7 +122,7 @@ def einsum(
) -> _ArrayType: ...
@overload
def einsum(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: Any,
out: _ArrayType,
@@ -137,7 +137,7 @@ def einsum(
# NOTE: In practice the list consists of a `str` (first element)
# and a variable number of integer tuples.
def einsum_path(
- subscripts: str,
+ subscripts: str | _ArrayLikeInt_co,
/,
*operands: _ArrayLikeComplex_co,
optimize: _OptimizeKind = ...,
diff --git a/numpy/typing/tests/data/reveal/einsumfunc.pyi b/numpy/typing/tests/data/reveal/einsumfunc.pyi
index 3c7146ada..d5f930149 100644
--- a/numpy/typing/tests/data/reveal/einsumfunc.pyi
+++ b/numpy/typing/tests/data/reveal/einsumfunc.pyi
@@ -30,3 +30,6 @@ reveal_type(np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f)) # E: Tuple[builtins
reveal_type(np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c)) # E: Tuple[builtins.list[Any], builtins.str]
reveal_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i)) # E: Tuple[builtins.list[Any], builtins.str]
reveal_type(np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)) # E: Tuple[builtins.list[Any], builtins.str]
+
+reveal_type(np.einsum([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i)) # E: Any
+reveal_type(np.einsum_path([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i)) # E: Tuple[builtins.list[Any], builtins.str]