diff options
author | Toshiki Kataoka <kataoka@preferred.jp> | 2022-06-09 22:34:21 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-09 07:34:21 -0600 |
commit | 59fec4619403762a5d785ad83fcbde5a230416fc (patch) | |
tree | 23a63b538e782493634e0acf2c55debdd7cad43c /numpy/core/einsumfunc.py | |
parent | 198904597cde9fb2e63afb329f77eb543b2b3ac1 (diff) | |
download | numpy-59fec4619403762a5d785ad83fcbde5a230416fc.tar.gz |
BUG: use explicit einsum_path whenever it is given (#21639)
* BUG: use explicit einsum_path whenever it is given
For example, allow user to specify unary contractions in binary einsum
as described in #20962.
The commit also adds a sanity check on user-specified einsum_path.
* fix lint
* MAINT: refactor logic with explicit_einsum_path
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r-- | numpy/core/einsumfunc.py | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index c78d3db23..d6c5885b8 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -821,6 +821,7 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False): if path_type is None: path_type = False + explicit_einsum_path = False memory_limit = None # No optimization or a named path algorithm @@ -829,7 +830,7 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False): # Given an explicit path elif len(path_type) and (path_type[0] == 'einsum_path'): - pass + explicit_einsum_path = True # Path tuple with memory limit elif ((len(path_type) == 2) and isinstance(path_type[0], str) and @@ -898,15 +899,19 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False): naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict) # Compute the path - if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set): + if explicit_einsum_path: + path = path_type[1:] + elif ( + (path_type is False) + or (len(input_list) in [1, 2]) + or (indices == output_set) + ): # Nothing to be optimized, leave it to einsum path = [tuple(range(len(input_list)))] elif path_type == "greedy": path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg) elif path_type == "optimal": path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg) - elif path_type[0] == 'einsum_path': - path = path_type[1:] else: raise KeyError("Path name %s not found", path_type) @@ -955,6 +960,13 @@ def einsum_path(*operands, optimize='greedy', einsum_call=False): opt_cost = sum(cost_list) + 1 + if len(input_list) != 1: + # Explicit "einsum_path" is usually trusted, but we detect this kind of + # mistake in order to prevent from returning an intermediate value. + raise RuntimeError( + "Invalid einsum_path is specified: {} more operands has to be " + "contracted.".format(len(input_list) - 1)) + if einsum_call_arg: return (operands, contraction_list) |