summaryrefslogtreecommitdiff
path: root/numpy/core/einsumfunc.py
diff options
context:
space:
mode:
authorToshiki Kataoka <kataoka@preferred.jp>2022-06-09 22:34:21 +0900
committerGitHub <noreply@github.com>2022-06-09 07:34:21 -0600
commit59fec4619403762a5d785ad83fcbde5a230416fc (patch)
tree23a63b538e782493634e0acf2c55debdd7cad43c /numpy/core/einsumfunc.py
parent198904597cde9fb2e63afb329f77eb543b2b3ac1 (diff)
downloadnumpy-59fec4619403762a5d785ad83fcbde5a230416fc.tar.gz
BUG: use explicit einsum_path whenever it is given (#21639)
* BUG: use explicit einsum_path whenever it is given For example, allow user to specify unary contractions in binary einsum as described in #20962. The commit also adds a sanity check on user-specified einsum_path. * fix lint * MAINT: refactor logic with explicit_einsum_path
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r--numpy/core/einsumfunc.py20
1 files changed, 16 insertions, 4 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index c78d3db23..d6c5885b8 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -821,6 +821,7 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False):
if path_type is None:
path_type = False
+ explicit_einsum_path = False
memory_limit = None
# No optimization or a named path algorithm
@@ -829,7 +830,7 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False):
# Given an explicit path
elif len(path_type) and (path_type[0] == 'einsum_path'):
- pass
+ explicit_einsum_path = True
# Path tuple with memory limit
elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
@@ -898,15 +899,19 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False):
naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
# Compute the path
- if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
+ if explicit_einsum_path:
+ path = path_type[1:]
+ elif (
+ (path_type is False)
+ or (len(input_list) in [1, 2])
+ or (indices == output_set)
+ ):
# Nothing to be optimized, leave it to einsum
path = [tuple(range(len(input_list)))]
elif path_type == "greedy":
path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
elif path_type == "optimal":
path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
- elif path_type[0] == 'einsum_path':
- path = path_type[1:]
else:
raise KeyError("Path name %s not found", path_type)
@@ -955,6 +960,13 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False):
opt_cost = sum(cost_list) + 1
+ if len(input_list) != 1:
+ # Explicit "einsum_path" is usually trusted, but we detect this kind of
+ # mistake in order to prevent from returning an intermediate value.
+ raise RuntimeError(
+ "Invalid einsum_path is specified: {} more operands has to be "
+ "contracted.".format(len(input_list) - 1))
+
if einsum_call_arg:
return (operands, contraction_list)