diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/add_newdocs.py | 14 | ||||
-rw-r--r-- | numpy/core/__init__.py | 3 | ||||
-rw-r--r-- | numpy/core/einsumfunc.py | 990 | ||||
-rw-r--r-- | numpy/core/numeric.py | 4 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 866 |
6 files changed, 1574 insertions, 305 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 755cdc303..da87bc29e 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -2108,9 +2108,9 @@ add_newdoc('numpy.core', 'matmul', """) -add_newdoc('numpy.core', 'einsum', +add_newdoc('numpy.core', 'c_einsum', """ - einsum(subscripts, *operands, out=None, dtype=None, order='K', casting='safe') + c_einsum(subscripts, *operands, out=None, dtype=None, order='K', casting='safe') Evaluates the Einstein summation convention on the operands. @@ -2120,6 +2120,8 @@ add_newdoc('numpy.core', 'einsum', function is to try the examples below, which show how many common NumPy functions can be implemented as calls to `einsum`. + This is the core C function. + Parameters ---------- subscripts : str @@ -2128,10 +2130,10 @@ add_newdoc('numpy.core', 'einsum', These are the arrays for the operation. out : ndarray, optional If provided, the calculation is done into this array. - dtype : data-type, optional + dtype : {data-type, None}, optional If provided, forces the calculation to use the data type specified. Note that you may have to also give a more liberal `casting` - parameter to allow the conversions. + parameter to allow the conversions. Default is None. order : {'C', 'F', 'A', 'K'}, optional Controls the memory layout of the output. 'C' means it should be C contiguous. 'F' means it should be Fortran contiguous, @@ -2150,6 +2152,8 @@ add_newdoc('numpy.core', 'einsum', like float64 to float32, are allowed. * 'unsafe' means any data conversions may be done. + Default is 'safe'. + Returns ------- output : ndarray @@ -2157,7 +2161,7 @@ add_newdoc('numpy.core', 'einsum', See Also -------- - dot, inner, outer, tensordot + einsum, dot, inner, outer, tensordot Notes ----- diff --git a/numpy/core/__init__.py b/numpy/core/__init__.py index 1ac850002..ca2f45ece 100644 --- a/numpy/core/__init__.py +++ b/numpy/core/__init__.py @@ -50,6 +50,8 @@ from . import getlimits from .getlimits import * from . 
"""
Implementation of optimized einsum.

"""
from __future__ import division, absolute_import, print_function

# ``c_einsum`` is the core C implementation.  Companion changes in the same
# patch (other files):
#   * numpy/core/src/multiarray/multiarraymodule.c registers the C function
#     under the name ``c_einsum`` instead of ``einsum``.
#   * numpy/core/numeric.py no longer imports or exports ``einsum``.
#   * numpy/core/__init__.py imports this module and extends ``__all__``
#     with ``einsumfunc.__all__``.
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asarray, asanyarray, result_type

__all__ = ['einsum', 'einsum_path']

einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
einsum_symbols_set = set(einsum_symbols)


def _compute_size_by_dict(indices, idx_dict):
    """
    Computes the product of the elements in indices based on the dictionary
    idx_dict.

    Parameters
    ----------
    indices : iterable
        Indices to base the product on.
    idx_dict : dictionary
        Dictionary of index sizes

    Returns
    -------
    ret : int
        The resulting product.

    Examples
    --------
    >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    ret = 1
    for i in indices:
        ret *= idx_dict[i]
    return ret


def _find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    idx_contract = set()
    # Indices that must survive: the final output plus everything still
    # needed by the terms that are not part of this contraction.
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value

    new_result = idx_remain & idx_contract
    idx_removed = (idx_contract - new_result)
    remaining.append(new_result)

    return (new_result, remaining, idx_removed, idx_contract)


def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Each entry is (total_cost, positions_so_far, indices_remaining)
    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []

        # Compute all unique pairs; every partial path at this depth has the
        # same number of terms left.
        n_terms = len(input_sets) - iteration
        comb_iter = [(x, y)
                     for x in range(n_terms)
                     for y in range(x + 1, n_terms)]

        for cost, positions, remaining in full_results:
            for con in comb_iter:

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Find cost
                new_cost = _compute_size_by_dict(idx_contract, idx_dict)
                if idx_removed:
                    new_cost *= 2

                # Build (total_cost, positions, indices_remaining)
                new_cost += cost
                new_pos = positions + [con]
                iter_results.append((new_cost, new_pos, new_input_sets))

        # Update list to iterate over
        full_results = iter_results

    # If we have not found anything return single einsum contraction
    if len(full_results) == 0:
        return [tuple(range(len(input_sets)))]

    path = min(full_results, key=lambda x: x[0])[1]
    return path


def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``.  What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    if len(input_sets) == 1:
        return [(0,)]

    path = []
    for iteration in range(len(input_sets) - 1):
        iteration_results = []

        # Compute all unique pairs of the terms that are still left
        n_terms = len(input_sets)
        comb_iter = [(x, y)
                     for x in range(n_terms)
                     for y in range(x + 1, n_terms)]

        for positions in comb_iter:

            # Find the contraction
            contract = _find_contraction(positions, input_sets, output_set)
            idx_result, new_input_sets, idx_removed, idx_contract = contract

            # Sieve the results based on memory_limit
            if _compute_size_by_dict(idx_result, idx_dict) > memory_limit:
                continue

            # Build sort tuple
            removed_size = _compute_size_by_dict(idx_removed, idx_dict)
            cost = _compute_size_by_dict(idx_contract, idx_dict)
            sort = (-removed_size, cost)

            # Add contraction to possible choices
            iteration_results.append([sort, positions, new_input_sets])

        # If we did not find a new contraction contract remaining
        if len(iteration_results) == 0:
            path.append(tuple(range(len(input_sets))))
            break

        # Sort based on first index
        best = min(iteration_results, key=lambda x: x[0])
        path.append(best[1])
        input_sets = best[2]

    return path


def _sublist_to_subscript(sublist):
    """Convert one ``[int | Ellipsis, ...]`` sublist to a subscript string."""
    subscript = ""
    for s in sublist:
        if s is Ellipsis:
            subscript += "..."
        elif isinstance(s, int):
            subscript += einsum_symbols[s]
        else:
            raise TypeError("For this input type lists must contain "
                            "either int or Ellipsis")
    return subscript


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, op0, op1, ...)`` with ``subscripts`` a string,
        or ``(op0, sublist0, op1, sublist1, ..., [sublistout])``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, a subscript character is invalid, the
        '->' or ellipsis usage is malformed, an output label is missing from
        the inputs, or the number of terms differs from the number of
        operands.
    TypeError
        If a sublist contains something other than int or Ellipsis.

    Examples
    --------
    The operand list is simplified to reduce printing (the letters chosen
    for the expanded ellipsis dimensions may differ):

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b])

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: operands at even positions, sublists at odd
        # positions, optionally followed by a trailing output sublist.
        tmp_operands = list(operands)
        n_pairs = len(tmp_operands) // 2
        operand_list = tmp_operands[0:2 * n_pairs:2]
        subscript_list = tmp_operands[1:2 * n_pairs:2]
        leftover = tmp_operands[2 * n_pairs:]

        output_list = leftover[-1] if len(leftover) else None
        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join([_sublist_to_subscript(sub)
                               for sub in subscript_list])

        if output_list is not None:
            subscripts += "->"
            subscripts += _sublist_to_subscript(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        # Labels not used explicitly stand in for the ellipsis dimensions
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(len(operands[num].shape), 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Right-aligned so that broadcast dimensions line up
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)


def einsum_path(*operands, **kwargs):
    """
    einsum_path(subscripts, *operands, optimize=False)

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        If the keyword is omitted no optimization is taken (same as False).

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If an unknown keyword is given or ``optimize`` is not understood.
    ValueError
        If an operand's number of dimensions does not match its subscript
        term, or a label is used with two different sizes.
    KeyError
        If ``optimize`` names an unknown path algorithm.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
     Complete contraction: ij,jk,kl->il
     Naive scaling: 4
     Optimized scaling: 3
     Naive FLOP count: 1.600e+02
     Optimized FLOP count: 5.700e+01
     Theoretical speedup: 2.807
     Largest intermediate: 4.000e+00 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       3                kl,jk->jl                                ij,jl->il
       3                jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
                                   optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
     Complete contraction: ea,fb,abcd,gc,hd->efgh
     Naive scaling: 8
     Optimized scaling: 5
     Naive FLOP count: 8.000e+08
     Optimized FLOP count: 8.000e+05
     Theoretical speedup: 999.999
     Largest intermediate: 1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5            abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5            bcde,fb->cdef                         gc,hd,cdef->efgh
       5            cdef,gc->defg                            hd,defg->efgh
       5            defg,hd->efgh                               efgh->efgh
    """

    # Make sure all keywords are valid
    valid_contract_kwargs = ['optimize', 'einsum_call']
    unknown_kwargs = [k for k in kwargs if k not in valid_contract_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs:"
                        " %s" % unknown_kwargs)

    # Figure out what the path really is
    path_type = kwargs.pop('optimize', False)
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        pass

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Hidden option, only einsum should call this
    einsum_call_arg = kwargs.pop("einsum_call", False)

    # Python side parsing
    input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
    subscripts = input_subscripts + '->' + output_subscript

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report the whole term, not one character of the joined string
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (input_list[tnum], tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]
            if char in dimension_dict:
                if dimension_dict[char] != dim:
                    raise ValueError("Size of label '%s' for operand %d does "
                                     "not match previous terms."
                                     % (char, tnum))
            else:
                dimension_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = []
    for term in input_list + [output_subscript]:
        size_list.append(_compute_size_by_dict(term, dimension_dict))
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isnt quite right, need to look into exactly how einsum does this
    naive_cost = _compute_size_by_dict(indices, dimension_dict)
    indices_in_input = input_subscripts.replace(',', '')
    mult = max(len(input_list) - 1, 1)
    if (len(indices_in_input) - len(set(indices_in_input))):
        mult *= 2
    naive_cost *= mult

    # Compute the path
    if ((path_type is False) or (len(input_list) in [1, 2]) or
            (indices == output_set)):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        # Maximum memory should be at most out_size for this algorithm
        memory_arg = min(memory_arg, max_size)
        path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type == "optimal":
        path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type[0] == 'einsum_path':
        # list() so that an explicit path given as a tuple also works with
        # the list concatenation on return.
        path = list(path_type[1:])
    else:
        raise KeyError("Path name %s not found" % (path_type,))

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _compute_size_by_dict(idx_contract, dimension_dict)
        if idx_removed:
            cost *= 2
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # Intermediates are ordered by (dimension size, label)
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, idx_removed, einsum_str, input_list[:])
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = " Complete contraction: %s\n" % overall_contraction
    path_print += " Naive scaling: %d\n" % len(indices)
    path_print += " Optimized scaling: %d\n" % max(scale_list)
    path_print += " Naive FLOP count: %.3e\n" % naive_cost
    path_print += " Optimized FLOP count: %.3e\n" % opt_cost
    path_print += " Theoretical speedup: %3.3f\n" % speedup
    path_print += " Largest intermediate: %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)


# Rewrite einsum to handle different cases
def einsum(*operands, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional
    array operations can be represented in a simple fashion.  This function
    provides a way to compute such summations. The best way to understand this
    function is to try the examples below, which show how many common NumPy
    functions can be implemented as calls to `einsum`.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    operands : list of array_like
        These are the arrays for the operation.
    out : {ndarray, None}, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Default is False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    Notes
    -----
    .. versionadded:: 1.6.0

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Repeated subscripts labels in one operand take the diagonal.  For example,
    ``np.einsum('ii', a)`` is equivalent to ``np.trace(a)``.

    Whenever a label is repeated, it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to ``np.inner(a,b)``.  If a label appears only once,
    it is not summed, so ``np.einsum('i', a)`` produces a view of ``a``
    with no changes.

    The order of labels in the output is by default alphabetical.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose.

    The output can be controlled by specifying output subscript labels
    as well.  This specifies the label order, and allows summing to
    be disallowed or forced when desired.  The call ``np.einsum('i->', a)``
    is like ``np.sum(a, axis=-1)``, and ``np.einsum('ii->i', a)``
    is like ``np.diag(a)``.  The difference is that `einsum` does not
    allow broadcasting by default.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, you can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view.

    An alternative way to provide the subscripts and operands is as
    ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``. The examples
    below have corresponding `einsum` calls with the two parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as ``np.swapaxes(a, 0, 2)`` and
    ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    See ``np.einsum_path`` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> c.T
    array([[0, 3],
           [1, 4],
           [2, 5]])

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    >>> np.einsum('i...->...', a)
    array([50, 55, 60, 65, 70])
    >>> np.einsum(a, [0,Ellipsis], [Ellipsis])
    array([50, 55, 60, 65, 70])
    >>> np.sum(a, axis=0)
    array([50, 55, 60, 65, 70])

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[ 4400.,  4730.],
           [ 4532.,  4874.],
           [ 4664.,  5018.],
           [ 4796.,  5162.],
           [ 4928.,  5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[ 4400.,  4730.],
           [ 4532.,  4874.],
           [ 4664.,  5018.],
           [ 4796.,  5162.],
           [ 4928.,  5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[ 4400.,  4730.],
           [ 4532.,  4874.],
           [ 4664.,  5018.],
           [ 4796.,  5162.],
           [ 4928.,  5306.]])

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    >>> # since version 1.10.0
    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[ 1.,  0.,  0.],
           [ 0.,  1.,  0.],
           [ 0.,  0.,  1.]])

    """

    # Grab non-einsum kwargs
    optimize_arg = kwargs.pop('optimize', False)

    # If no optimization, run pure einsum
    if optimize_arg is False:
        return c_einsum(*operands, **kwargs)

    valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
    einsum_kwargs = {k: v for (k, v) in kwargs.items() if
                     k in valid_einsum_kwargs}

    # Make sure all keywords are valid
    valid_contract_kwargs = ['optimize'] + valid_einsum_kwargs
    unknown_kwargs = [k for k in kwargs if k not in valid_contract_kwargs]

    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Special handling if out is specified: only the final contraction may
    # write into the user supplied array.
    specified_out = False
    out_array = einsum_kwargs.pop('out', None)
    if out_array is not None:
        specified_out = True

    # Build the contraction list and operand
    operands, contraction_list = einsum_path(*operands, optimize=optimize_arg,
                                             einsum_call=True)
    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining = contraction
        # ``inds`` is sorted right to left, so popping keeps positions valid
        tmp_operands = []
        for x in inds:
            tmp_operands.append(operands.pop(x))

        # If out was specified
        if specified_out and ((num + 1) == len(contraction_list)):
            einsum_kwargs["out"] = out_array

        # Do the contraction
        new_view = c_einsum(einsum_str, *tmp_operands, **einsum_kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out_array
    else:
        return operands[0]
a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -3,282 +3,308 @@ from __future__ import division, absolute_import, print_function import numpy as np from numpy.testing import ( TestCase, run_module_suite, assert_, assert_equal, assert_array_equal, - assert_raises, suppress_warnings + assert_almost_equal, assert_raises, suppress_warnings ) +# Setup for optimize einsum +chars = 'abcdefghij' +sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3]) +global_size_dict = {} +for size, char in zip(sizes, chars): + global_size_dict[char] = size + class TestEinSum(TestCase): def test_einsum_errors(self): - # Need enough arguments - assert_raises(ValueError, np.einsum) - assert_raises(ValueError, np.einsum, "") - - # subscripts must be a string - assert_raises(TypeError, np.einsum, 0, 0) - - # out parameter must be an array - assert_raises(TypeError, np.einsum, "", 0, out='test') - - # order parameter must be a valid order - assert_raises(TypeError, np.einsum, "", 0, order='W') - - # casting parameter must be a valid casting - assert_raises(ValueError, np.einsum, "", 0, casting='blah') - - # dtype parameter must be a valid dtype - assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type') - - # other keyword arguments are rejected - assert_raises(TypeError, np.einsum, "", 0, bad_arg=0) - - # issue 4528 revealed a segfault with this call - assert_raises(TypeError, np.einsum, *(None,)*63) - - # number of operands must match count in subscripts string - assert_raises(ValueError, np.einsum, "", 0, 0) - assert_raises(ValueError, np.einsum, ",", 0, [0], [0]) - assert_raises(ValueError, np.einsum, ",", [0]) - - # can't have more subscripts than dimensions in the operand - assert_raises(ValueError, np.einsum, "i", 0) - assert_raises(ValueError, np.einsum, "ij", [0, 0]) - assert_raises(ValueError, np.einsum, "...i", 0) - assert_raises(ValueError, np.einsum, "i...j", [0, 0]) - assert_raises(ValueError, np.einsum, "i...", 0) - assert_raises(ValueError, 
np.einsum, "ij...", [0, 0]) - - # invalid ellipsis - assert_raises(ValueError, np.einsum, "i..", [0, 0]) - assert_raises(ValueError, np.einsum, ".i...", [0, 0]) - assert_raises(ValueError, np.einsum, "j->..j", [0, 0]) - assert_raises(ValueError, np.einsum, "j->.j...", [0, 0]) - - # invalid subscript character - assert_raises(ValueError, np.einsum, "i%...", [0, 0]) - assert_raises(ValueError, np.einsum, "...j$", [0, 0]) - assert_raises(ValueError, np.einsum, "i->&", [0, 0]) - - # output subscripts must appear in input - assert_raises(ValueError, np.einsum, "i->ij", [0, 0]) - - # output subscripts may only be specified once - assert_raises(ValueError, np.einsum, "ij->jij", [[0, 0], [0, 0]]) - - # dimensions much match when being collapsed - assert_raises(ValueError, np.einsum, "ii", np.arange(6).reshape(2, 3)) - assert_raises(ValueError, np.einsum, "ii->i", np.arange(6).reshape(2, 3)) - - # broadcasting to new dimensions must be enabled explicitly - assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2, 3)) - assert_raises(ValueError, np.einsum, "i->i", [[0, 1], [0, 1]], - out=np.arange(4).reshape(2, 2)) + for do_opt in [True, False]: + # Need enough arguments + assert_raises(ValueError, np.einsum, optimize=do_opt) + assert_raises(ValueError, np.einsum, "", optimize=do_opt) + + # subscripts must be a string + assert_raises(TypeError, np.einsum, 0, 0, optimize=do_opt) + + # out parameter must be an array + assert_raises(TypeError, np.einsum, "", 0, out='test', + optimize=do_opt) + + # order parameter must be a valid order + assert_raises(TypeError, np.einsum, "", 0, order='W', + optimize=do_opt) + + # casting parameter must be a valid casting + assert_raises(ValueError, np.einsum, "", 0, casting='blah', + optimize=do_opt) + + # dtype parameter must be a valid dtype + assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type', + optimize=do_opt) + + # other keyword arguments are rejected + assert_raises(TypeError, np.einsum, "", 0, bad_arg=0, + 
optimize=do_opt) + + # issue 4528 revealed a segfault with this call + assert_raises(TypeError, np.einsum, *(None,)*63, optimize=do_opt) + + # number of operands must match count in subscripts string + assert_raises(ValueError, np.einsum, "", 0, 0, optimize=do_opt) + assert_raises(ValueError, np.einsum, ",", 0, [0], [0], + optimize=do_opt) + assert_raises(ValueError, np.einsum, ",", [0], optimize=do_opt) + + # can't have more subscripts than dimensions in the operand + assert_raises(ValueError, np.einsum, "i", 0, optimize=do_opt) + assert_raises(ValueError, np.einsum, "ij", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, "...i", 0, optimize=do_opt) + assert_raises(ValueError, np.einsum, "i...j", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, "i...", 0, optimize=do_opt) + assert_raises(ValueError, np.einsum, "ij...", [0, 0], optimize=do_opt) + + # invalid ellipsis + assert_raises(ValueError, np.einsum, "i..", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, ".i...", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, "j->..j", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, "j->.j...", [0, 0], optimize=do_opt) + + # invalid subscript character + assert_raises(ValueError, np.einsum, "i%...", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, "...j$", [0, 0], optimize=do_opt) + assert_raises(ValueError, np.einsum, "i->&", [0, 0], optimize=do_opt) + + # output subscripts must appear in input + assert_raises(ValueError, np.einsum, "i->ij", [0, 0], optimize=do_opt) + + # output subscripts may only be specified once + assert_raises(ValueError, np.einsum, "ij->jij", [[0, 0], [0, 0]], + optimize=do_opt) + + # dimensions must match when being collapsed + assert_raises(ValueError, np.einsum, "ii", + np.arange(6).reshape(2, 3), optimize=do_opt) + assert_raises(ValueError, np.einsum, "ii->i", + np.arange(6).reshape(2, 3), optimize=do_opt) + + # broadcasting to new dimensions must be enabled 
explicitly + assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2, 3), + optimize=do_opt) + assert_raises(ValueError, np.einsum, "i->i", [[0, 1], [0, 1]], + out=np.arange(4).reshape(2, 2), optimize=do_opt) def test_einsum_views(self): # pass-through - a = np.arange(6) - a.shape = (2, 3) + for do_opt in [True, False]: + a = np.arange(6) + a.shape = (2, 3) + + b = np.einsum("...", a, optimize=do_opt) + assert_(b.base is a) + + b = np.einsum(a, [Ellipsis], optimize=do_opt) + assert_(b.base is a) + + b = np.einsum("ij", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, a) + + b = np.einsum(a, [0, 1], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, a) + + # output is writeable whenever input is writeable + b = np.einsum("...", a, optimize=do_opt) + assert_(b.flags['WRITEABLE']) + a.flags['WRITEABLE'] = False + b = np.einsum("...", a, optimize=do_opt) + assert_(not b.flags['WRITEABLE']) + + # transpose + a = np.arange(6) + a.shape = (2, 3) + + b = np.einsum("ji", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, a.T) + + b = np.einsum(a, [1, 0], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, a.T) + + # diagonal + a = np.arange(9) + a.shape = (3, 3) + + b = np.einsum("ii->i", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[i, i] for i in range(3)]) + + b = np.einsum(a, [0, 0], [0], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[i, i] for i in range(3)]) + + # diagonal with various ways of broadcasting an additional dimension + a = np.arange(27) + a.shape = (3, 3, 3) + + b = np.einsum("...ii->...i", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [[x[i, i] for i in range(3)] for x in a]) + + b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [[x[i, i] for i in range(3)] for x in a]) + + b = np.einsum("ii...->...i", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [[x[i, i] for i in range(3)] + for x in 
a.transpose(2, 0, 1)]) + + b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [[x[i, i] for i in range(3)] + for x in a.transpose(2, 0, 1)]) + + b = np.einsum("...ii->i...", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[:, i, i] for i in range(3)]) + + b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[:, i, i] for i in range(3)]) + + b = np.einsum("jii->ij", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[:, i, i] for i in range(3)]) + + b = np.einsum(a, [1, 0, 0], [0, 1], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[:, i, i] for i in range(3)]) + + b = np.einsum("ii...->i...", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)]) + + b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)]) + + b = np.einsum("i...i->i...", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)]) + + b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)]) + + b = np.einsum("i...i->...i", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [[x[i, i] for i in range(3)] + for x in a.transpose(1, 0, 2)]) + + b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [[x[i, i] for i in range(3)] + for x in a.transpose(1, 0, 2)]) + + # triple diagonal + a = np.arange(27) + a.shape = (3, 3, 3) + + b = np.einsum("iii->i", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[i, i, i] for i in range(3)]) + + b = np.einsum(a, [0, 0, 0], [0], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, [a[i, i, i] for i in range(3)]) - b = 
np.einsum("...", a) - assert_(b.base is a) - - b = np.einsum(a, [Ellipsis]) - assert_(b.base is a) - - b = np.einsum("ij", a) - assert_(b.base is a) - assert_equal(b, a) - - b = np.einsum(a, [0, 1]) - assert_(b.base is a) - assert_equal(b, a) - - # output is writeable whenever input is writeable - b = np.einsum("...", a) - assert_(b.flags['WRITEABLE']) - a.flags['WRITEABLE'] = False - b = np.einsum("...", a) - assert_(not b.flags['WRITEABLE']) - - # transpose - a = np.arange(6) - a.shape = (2, 3) - - b = np.einsum("ji", a) - assert_(b.base is a) - assert_equal(b, a.T) - - b = np.einsum(a, [1, 0]) - assert_(b.base is a) - assert_equal(b, a.T) - - # diagonal - a = np.arange(9) - a.shape = (3, 3) - - b = np.einsum("ii->i", a) - assert_(b.base is a) - assert_equal(b, [a[i, i] for i in range(3)]) - - b = np.einsum(a, [0, 0], [0]) - assert_(b.base is a) - assert_equal(b, [a[i, i] for i in range(3)]) - - # diagonal with various ways of broadcasting an additional dimension - a = np.arange(27) - a.shape = (3, 3, 3) - - b = np.einsum("...ii->...i", a) - assert_(b.base is a) - assert_equal(b, [[x[i, i] for i in range(3)] for x in a]) - - b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0]) - assert_(b.base is a) - assert_equal(b, [[x[i, i] for i in range(3)] for x in a]) - - b = np.einsum("ii...->...i", a) - assert_(b.base is a) - assert_equal(b, [[x[i, i] for i in range(3)] - for x in a.transpose(2, 0, 1)]) - - b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0]) - assert_(b.base is a) - assert_equal(b, [[x[i, i] for i in range(3)] - for x in a.transpose(2, 0, 1)]) - - b = np.einsum("...ii->i...", a) - assert_(b.base is a) - assert_equal(b, [a[:, i, i] for i in range(3)]) - - b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis]) - assert_(b.base is a) - assert_equal(b, [a[:, i, i] for i in range(3)]) - - b = np.einsum("jii->ij", a) - assert_(b.base is a) - assert_equal(b, [a[:, i, i] for i in range(3)]) - - b = np.einsum(a, [1, 0, 0], [0, 1]) - assert_(b.base is a) - assert_equal(b, 
[a[:, i, i] for i in range(3)]) - - b = np.einsum("ii...->i...", a) - assert_(b.base is a) - assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)]) - - b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis]) - assert_(b.base is a) - assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)]) - - b = np.einsum("i...i->i...", a) - assert_(b.base is a) - assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)]) - - b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis]) - assert_(b.base is a) - assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)]) - - b = np.einsum("i...i->...i", a) - assert_(b.base is a) - assert_equal(b, [[x[i, i] for i in range(3)] - for x in a.transpose(1, 0, 2)]) - - b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0]) - assert_(b.base is a) - assert_equal(b, [[x[i, i] for i in range(3)] - for x in a.transpose(1, 0, 2)]) - - # triple diagonal - a = np.arange(27) - a.shape = (3, 3, 3) - - b = np.einsum("iii->i", a) - assert_(b.base is a) - assert_equal(b, [a[i, i, i] for i in range(3)]) - - b = np.einsum(a, [0, 0, 0], [0]) - assert_(b.base is a) - assert_equal(b, [a[i, i, i] for i in range(3)]) + # swap axes + a = np.arange(24) + a.shape = (2, 3, 4) - # swap axes - a = np.arange(24) - a.shape = (2, 3, 4) + b = np.einsum("ijk->jik", a, optimize=do_opt) + assert_(b.base is a) + assert_equal(b, a.swapaxes(0, 1)) - b = np.einsum("ijk->jik", a) - assert_(b.base is a) - assert_equal(b, a.swapaxes(0, 1)) + b = np.einsum(a, [0, 1, 2], [1, 0, 2], optimize=do_opt) + assert_(b.base is a) + assert_equal(b, a.swapaxes(0, 1)) - b = np.einsum(a, [0, 1, 2], [1, 0, 2]) - assert_(b.base is a) - assert_equal(b, a.swapaxes(0, 1)) - - def check_einsum_sums(self, dtype): + def check_einsum_sums(self, dtype, do_opt=False): # Check various sums. Does many sizes to exercise unrolled loops. 
# sum(a, axis=-1) for n in range(1, 17): a = np.arange(n, dtype=dtype) - assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype)) - assert_equal(np.einsum(a, [0], []), + assert_equal(np.einsum("i->", a, optimize=do_opt), + np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [0], [], optimize=do_opt), np.sum(a, axis=-1).astype(dtype)) for n in range(1, 17): a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n) - assert_equal(np.einsum("...i->...", a), + assert_equal(np.einsum("...i->...", a, optimize=do_opt), np.sum(a, axis=-1).astype(dtype)) - assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis]), + assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt), np.sum(a, axis=-1).astype(dtype)) # sum(a, axis=0) for n in range(1, 17): a = np.arange(2*n, dtype=dtype).reshape(2, n) - assert_equal(np.einsum("i...->...", a), + assert_equal(np.einsum("i...->...", a, optimize=do_opt), np.sum(a, axis=0).astype(dtype)) - assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis]), + assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt), np.sum(a, axis=0).astype(dtype)) for n in range(1, 17): a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n) - assert_equal(np.einsum("i...->...", a), + assert_equal(np.einsum("i...->...", a, optimize=do_opt), np.sum(a, axis=0).astype(dtype)) - assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis]), + assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt), np.sum(a, axis=0).astype(dtype)) # trace(a) for n in range(1, 17): a = np.arange(n*n, dtype=dtype).reshape(n, n) - assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype)) - assert_equal(np.einsum(a, [0, 0]), np.trace(a).astype(dtype)) + assert_equal(np.einsum("ii", a, optimize=do_opt), + np.trace(a).astype(dtype)) + assert_equal(np.einsum(a, [0, 0], optimize=do_opt), + np.trace(a).astype(dtype)) # multiply(a, b) assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case for n in range(1, 17): - a = np.arange(3*n, dtype=dtype).reshape(3, n) 
- b = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n) - assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b)) - assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]), + a = np.arange(3 * n, dtype=dtype).reshape(3, n) + b = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n) + assert_equal(np.einsum("..., ...", a, b, optimize=do_opt), + np.multiply(a, b)) + assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis], optimize=do_opt), np.multiply(a, b)) # inner(a,b) for n in range(1, 17): - a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n) + a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n) b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b)) - assert_equal(np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0]), + assert_equal(np.einsum("...i, ...i", a, b, optimize=do_opt), np.inner(a, b)) + assert_equal(np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0], optimize=do_opt), np.inner(a, b)) for n in range(1, 11): - a = np.arange(n*3*2, dtype=dtype).reshape(n, 3, 2) + a = np.arange(n * 3 * 2, dtype=dtype).reshape(n, 3, 2) b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T) - assert_equal(np.einsum(a, [0, Ellipsis], b, [0, Ellipsis]), + assert_equal(np.einsum("i..., i...", a, b, optimize=do_opt), + np.inner(a.T, b.T).T) + assert_equal(np.einsum(a, [0, Ellipsis], b, [0, Ellipsis], optimize=do_opt), np.inner(a.T, b.T).T) # outer(a,b) for n in range(1, 17): a = np.arange(3, dtype=dtype)+1 b = np.arange(n, dtype=dtype)+1 - assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) - assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b)) + assert_equal(np.einsum("i,j", a, b, optimize=do_opt), + np.outer(a, b)) + assert_equal(np.einsum(a, [0], b, [1], optimize=do_opt), + np.outer(a, b)) # Suppress the complex warnings for the 'as f8' tests with suppress_warnings() as sup: @@ -288,62 +314,70 @@ class TestEinSum(TestCase): for n in range(1, 17): a = np.arange(4*n, dtype=dtype).reshape(4, n) b = np.arange(n, 
dtype=dtype) - assert_equal(np.einsum("ij, j", a, b), np.dot(a, b)) - assert_equal(np.einsum(a, [0, 1], b, [1]), np.dot(a, b)) + assert_equal(np.einsum("ij, j", a, b, optimize=do_opt), + np.dot(a, b)) + assert_equal(np.einsum(a, [0, 1], b, [1], optimize=do_opt), + np.dot(a, b)) c = np.arange(4, dtype=dtype) np.einsum("ij,j", a, b, out=c, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) c[...] = 0 np.einsum(a, [0, 1], b, [1], out=c, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) for n in range(1, 17): a = np.arange(4*n, dtype=dtype).reshape(4, n) b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) - assert_equal(np.einsum(a.T, [1, 0], b.T, [1]), np.dot(b.T, a.T)) + assert_equal(np.einsum("ji,j", a.T, b.T, optimize=do_opt), + np.dot(b.T, a.T)) + assert_equal(np.einsum(a.T, [1, 0], b.T, [1], optimize=do_opt), + np.dot(b.T, a.T)) c = np.arange(4, dtype=dtype) - np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') + np.einsum("ji,j", a.T, b.T, out=c, + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, - np.dot(b.T.astype('f8'), - a.T.astype('f8')).astype(dtype)) + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) c[...] 
= 0 np.einsum(a.T, [1, 0], b.T, [1], out=c, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, - np.dot(b.T.astype('f8'), - a.T.astype('f8')).astype(dtype)) + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) # matmat(a,b) / a.dot(b) where a is matrix, b is matrix for n in range(1, 17): if n < 8 or dtype != 'f2': a = np.arange(4*n, dtype=dtype).reshape(4, n) b = np.arange(n*6, dtype=dtype).reshape(n, 6) - assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) - assert_equal(np.einsum(a, [0, 1], b, [1, 2]), np.dot(a, b)) + assert_equal(np.einsum("ij,jk", a, b, optimize=do_opt), + np.dot(a, b)) + assert_equal(np.einsum(a, [0, 1], b, [1, 2], optimize=do_opt), + np.dot(a, b)) for n in range(1, 17): a = np.arange(4*n, dtype=dtype).reshape(4, n) b = np.arange(n*6, dtype=dtype).reshape(n, 6) c = np.arange(24, dtype=dtype).reshape(4, 6) - np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') + np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe', + optimize=do_opt) assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) c[...] 
= 0 np.einsum(a, [0, 1], b, [1, 2], out=c, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) # matrix triple product (note this is not currently an efficient # way to multiply 3 matrices) @@ -351,21 +385,21 @@ class TestEinSum(TestCase): b = np.arange(20, dtype=dtype).reshape(4, 5) c = np.arange(30, dtype=dtype).reshape(5, 6) if dtype != 'f2': - assert_equal(np.einsum("ij,jk,kl", a, b, c), - a.dot(b).dot(c)) - assert_equal(np.einsum(a, [0, 1], b, [1, 2], c, [2, 3]), - a.dot(b).dot(c)) + assert_equal(np.einsum("ij,jk,kl", a, b, c, optimize=do_opt), + a.dot(b).dot(c)) + assert_equal(np.einsum(a, [0, 1], b, [1, 2], c, [2, 3], + optimize=do_opt), a.dot(b).dot(c)) d = np.arange(18, dtype=dtype).reshape(3, 6) np.einsum("ij,jk,kl", a, b, c, out=d, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) tgt = a.astype('f8').dot(b.astype('f8')) tgt = tgt.dot(c.astype('f8')).astype(dtype) assert_equal(d, tgt) d[...] 
= 0 np.einsum(a, [0, 1], b, [1, 2], c, [2, 3], out=d, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) tgt = a.astype('f8').dot(b.astype('f8')) tgt = tgt.dot(c.astype('f8')).astype(dtype) assert_equal(d, tgt) @@ -375,31 +409,31 @@ class TestEinSum(TestCase): a = np.arange(60, dtype=dtype).reshape(3, 4, 5) b = np.arange(24, dtype=dtype).reshape(4, 3, 2) assert_equal(np.einsum("ijk, jil -> kl", a, b), - np.tensordot(a, b, axes=([1, 0], [0, 1]))) + np.tensordot(a, b, axes=([1, 0], [0, 1]))) assert_equal(np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3]), - np.tensordot(a, b, axes=([1, 0], [0, 1]))) + np.tensordot(a, b, axes=([1, 0], [0, 1]))) c = np.arange(10, dtype=dtype).reshape(5, 2) np.einsum("ijk,jil->kl", a, b, out=c, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), - axes=([1, 0], [0, 1])).astype(dtype)) + axes=([1, 0], [0, 1])).astype(dtype)) c[...] = 0 np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3], out=c, - dtype='f8', casting='unsafe') + dtype='f8', casting='unsafe', optimize=do_opt) assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), - axes=([1, 0], [0, 1])).astype(dtype)) + axes=([1, 0], [0, 1])).astype(dtype)) # logical_and(logical_and(a!=0, b!=0), c!=0) a = np.array([1, 3, -2, 0, 12, 13, 0, 1], dtype=dtype) b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12], dtype=dtype) c = np.array([True, True, False, True, True, False, True, True]) assert_equal(np.einsum("i,i,i->i", a, b, c, - dtype='?', casting='unsafe'), - np.logical_and(np.logical_and(a != 0, b != 0), c != 0)) + dtype='?', casting='unsafe', optimize=do_opt), + np.logical_and(np.logical_and(a != 0, b != 0), c != 0)) assert_equal(np.einsum(a, [0], b, [0], c, [0], [0], - dtype='?', casting='unsafe'), - np.logical_and(np.logical_and(a != 0, b != 0), c != 0)) + dtype='?', casting='unsafe'), + np.logical_and(np.logical_and(a != 0, b != 0), c != 0)) a = np.arange(9, dtype=dtype) 
assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a)) @@ -411,21 +445,24 @@ class TestEinSum(TestCase): for n in range(1, 25): a = np.arange(n, dtype=dtype) if np.dtype(dtype).itemsize > 1: - assert_equal(np.einsum("...,...", a, a), np.multiply(a, a)) - assert_equal(np.einsum("i,i", a, a), np.dot(a, a)) - assert_equal(np.einsum("i,->i", a, 2), 2*a) - assert_equal(np.einsum(",i->i", 2, a), 2*a) - assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a)) - assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a)) - - assert_equal(np.einsum("...,...", a[1:], a[:-1]), + assert_equal(np.einsum("...,...", a, a, optimize=do_opt), + np.multiply(a, a)) + assert_equal(np.einsum("i,i", a, a, optimize=do_opt), np.dot(a, a)) + assert_equal(np.einsum("i,->i", a, 2, optimize=do_opt), 2*a) + assert_equal(np.einsum(",i->i", 2, a, optimize=do_opt), 2*a) + assert_equal(np.einsum("i,->", a, 2, optimize=do_opt), 2*np.sum(a)) + assert_equal(np.einsum(",i->", 2, a, optimize=do_opt), 2*np.sum(a)) + + assert_equal(np.einsum("...,...", a[1:], a[:-1], optimize=do_opt), np.multiply(a[1:], a[:-1])) - assert_equal(np.einsum("i,i", a[1:], a[:-1]), + assert_equal(np.einsum("i,i", a[1:], a[:-1], optimize=do_opt), np.dot(a[1:], a[:-1])) - assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:]) - assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:]) - assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:])) - assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:])) + assert_equal(np.einsum("i,->i", a[1:], 2, optimize=do_opt), 2*a[1:]) + assert_equal(np.einsum(",i->i", 2, a[1:], optimize=do_opt), 2*a[1:]) + assert_equal(np.einsum("i,->", a[1:], 2, optimize=do_opt), + 2*np.sum(a[1:])) + assert_equal(np.einsum(",i->", 2, a[1:], optimize=do_opt), + 2*np.sum(a[1:])) # An object array, summed as the data type a = np.arange(9, dtype=object) @@ -458,9 +495,11 @@ class TestEinSum(TestCase): def test_einsum_sums_int32(self): self.check_einsum_sums('i4') + self.check_einsum_sums('i4', True) def test_einsum_sums_uint32(self): 
self.check_einsum_sums('u4') + self.check_einsum_sums('u4', True) def test_einsum_sums_int64(self): self.check_einsum_sums('i8') @@ -476,12 +515,14 @@ class TestEinSum(TestCase): def test_einsum_sums_float64(self): self.check_einsum_sums('f8') + self.check_einsum_sums('f8', True) def test_einsum_sums_longdouble(self): self.check_einsum_sums(np.longdouble) def test_einsum_sums_cfloat64(self): self.check_einsum_sums('c8') + self.check_einsum_sums('c8', True) def test_einsum_sums_cfloat128(self): self.check_einsum_sums('c16') @@ -495,12 +536,15 @@ class TestEinSum(TestCase): a = np.ones((1, 2)) b = np.ones((2, 2, 1)) assert_equal(np.einsum('ij...,j...->i...', a, b), [[[2], [2]]]) + assert_equal(np.einsum('ij...,j...->i...', a, b, optimize=True), [[[2], [2]]]) # The iterator had an issue with buffering this reduction a = np.ones((5, 12, 4, 2, 3), np.int64) b = np.ones((5, 12, 11), np.int64) assert_equal(np.einsum('ijklm,ijn,ijn->', a, b, b), - np.einsum('ijklm,ijn->', a, b)) + np.einsum('ijklm,ijn->', a, b)) + assert_equal(np.einsum('ijklm,ijn,ijn->', a, b, b, optimize=True), + np.einsum('ijklm,ijn->', a, b, optimize=True)) # Issue #2027, was a problem in the contiguous 3-argument # inner loop implementation @@ -508,8 +552,11 @@ class TestEinSum(TestCase): b = np.arange(1, 5).reshape(2, 2) c = np.arange(1, 9).reshape(4, 2) assert_equal(np.einsum('x,yx,zx->xzy', a, b, c), - [[[1, 3], [3, 9], [5, 15], [7, 21]], - [[8, 16], [16, 32], [24, 48], [32, 64]]]) + [[[1, 3], [3, 9], [5, 15], [7, 21]], + [[8, 16], [16, 32], [24, 48], [32, 64]]]) + assert_equal(np.einsum('x,yx,zx->xzy', a, b, c, optimize=True), + [[[1, 3], [3, 9], [5, 15], [7, 21]], + [[8, 16], [16, 32], [24, 48], [32, 64]]]) def test_einsum_broadcast(self): # Issue #2455 change in handling ellipsis @@ -517,23 +564,33 @@ class TestEinSum(TestCase): # only use the 'RIGHT' iteration in prepare_op_axes # adds auto broadcast on left where it belongs # broadcast on right has to be explicit + # We need to test the 
optimized parsing as well - A = np.arange(2*3*4).reshape(2,3,4) + A = np.arange(2 * 3 * 4).reshape(2, 3, 4) B = np.arange(3) - ref = np.einsum('ijk,j->ijk',A, B) - assert_equal(np.einsum('ij...,j...->ij...',A, B), ref) - assert_equal(np.einsum('ij...,...j->ij...',A, B), ref) - assert_equal(np.einsum('ij...,j->ij...',A, B), ref) # used to raise error + ref = np.einsum('ijk,j->ijk', A, B) + assert_equal(np.einsum('ij...,j...->ij...', A, B), ref) + assert_equal(np.einsum('ij...,...j->ij...', A, B), ref) + assert_equal(np.einsum('ij...,j->ij...', A, B), ref) # used to raise error + + assert_equal(np.einsum('ij...,j...->ij...', A, B, optimize=True), ref) + assert_equal(np.einsum('ij...,...j->ij...', A, B, optimize=True), ref) + assert_equal(np.einsum('ij...,j->ij...', A, B, optimize=True), ref) # used to raise error - A = np.arange(12).reshape((4,3)) - B = np.arange(6).reshape((3,2)) + A = np.arange(12).reshape((4, 3)) + B = np.arange(6).reshape((3, 2)) ref = np.einsum('ik,kj->ij', A, B) assert_equal(np.einsum('ik...,k...->i...', A, B), ref) assert_equal(np.einsum('ik...,...kj->i...j', A, B), ref) assert_equal(np.einsum('...k,kj', A, B), ref) # used to raise error assert_equal(np.einsum('ik,k...->i...', A, B), ref) # used to raise error - dims = [2,3,4,5] + assert_equal(np.einsum('ik...,k...->i...', A, B, optimize=True), ref) + assert_equal(np.einsum('ik...,...kj->i...j', A, B, optimize=True), ref) + assert_equal(np.einsum('...k,kj', A, B, optimize=True), ref) # used to raise error + assert_equal(np.einsum('ik,k...->i...', A, B, optimize=True), ref) # used to raise error + + dims = [2, 3, 4, 5] a = np.arange(np.prod(dims)).reshape(dims) v = np.arange(dims[2]) ref = np.einsum('ijkl,k->ijl', a, v) @@ -542,11 +599,17 @@ class TestEinSum(TestCase): assert_equal(np.einsum('...kl,k...', a, v), ref) # no real diff from 1st - J,K,M = 160,160,120 - A = np.arange(J*K*M).reshape(1,1,1,J,K,M) - B = np.arange(J*K*M*3).reshape(J,K,M,3) + assert_equal(np.einsum('ijkl,k', a, v, 
optimize=True), ref) + assert_equal(np.einsum('...kl,k', a, v, optimize=True), ref) # used to raise error + assert_equal(np.einsum('...kl,k...', a, v, optimize=True), ref) + + J, K, M = 160, 160, 120 + A = np.arange(J * K * M).reshape(1, 1, 1, J, K, M) + B = np.arange(J * K * M * 3).reshape(J, K, M, 3) ref = np.einsum('...lmn,...lmno->...o', A, B) assert_equal(np.einsum('...lmn,lmno->...o', A, B), ref) # used to raise error + assert_equal(np.einsum('...lmn,lmno->...o', A, B, + optimize=True), ref) # used to raise error def test_einsum_fixedstridebug(self): # Issue #4485 obscure einsum bug @@ -565,22 +628,22 @@ class TestEinSum(TestCase): # used by einsum, is 8192, and 3*2731 = 8193, is larger than that # and results in a mismatch between the buffering and the # striding for operand A. - A = np.arange(2*3).reshape(2,3).astype(np.float32) - B = np.arange(2*3*2731).reshape(2,3,2731).astype(np.int16) - es = np.einsum('cl,cpx->lpx', A, B) - tp = np.tensordot(A, B, axes=(0, 0)) - assert_equal(es, tp) + A = np.arange(2 * 3).reshape(2, 3).astype(np.float32) + B = np.arange(2 * 3 * 2731).reshape(2, 3, 2731).astype(np.int16) + es = np.einsum('cl, cpx->lpx', A, B) + tp = np.tensordot(A, B, axes=(0, 0)) + assert_equal(es, tp) # The following is the original test case from the bug report, # made repeatable by changing random arrays to aranges. - A = np.arange(3*3).reshape(3,3).astype(np.float64) - B = np.arange(3*3*64*64).reshape(3,3,64,64).astype(np.float32) - es = np.einsum('cl,cpxy->lpxy', A,B) - tp = np.tensordot(A,B, axes=(0,0)) + A = np.arange(3 * 3).reshape(3, 3).astype(np.float64) + B = np.arange(3 * 3 * 64 * 64).reshape(3, 3, 64, 64).astype(np.float32) + es = np.einsum('cl, cpxy->lpxy', A, B) + tp = np.tensordot(A, B, axes=(0, 0)) assert_equal(es, tp) def test_einsum_fixed_collapsingbug(self): # Issue #5147. - # The bug only occurred when output argument of einssum was used. + # The bug only occurred when output argument of einsum was used. 
x = np.random.normal(0, 1, (5, 5, 5, 5)) y1 = np.zeros((5, 5)) np.einsum('aabb->ab', x, out=y1) @@ -617,10 +680,219 @@ class TestEinSum(TestCase): a = np.zeros((16, 1, 1), dtype=np.bool_)[:2] a[...] = True out = np.zeros((16, 1, 1), dtype=np.bool_)[:2] - tgt = np.ones((2,1,1), dtype=np.bool_) + tgt = np.ones((2, 1, 1), dtype=np.bool_) res = np.einsum('...ij,...jk->...ik', a, a, out=out) assert_equal(res, tgt) + def optimize_compare(self, string): + # Tests all paths of the optimization function against + # conventional einsum + operands = [string] + terms = string.split('->')[0].split(',') + for term in terms: + dims = [global_size_dict[x] for x in term] + operands.append(np.random.rand(*dims)) + + noopt = np.einsum(*operands, optimize=False) + opt = np.einsum(*operands, optimize='greedy') + assert_almost_equal(opt, noopt) + opt = np.einsum(*operands, optimize='optimal') + assert_almost_equal(opt, noopt) + + def test_hadamard_like_products(self): + # Hadamard outer products + self.optimize_compare('a,ab,abc->abc') + self.optimize_compare('a,b,ab->ab') + + def test_index_transformations(self): + # Simple index transformation cases + self.optimize_compare('ea,fb,gc,hd,abcd->efgh') + self.optimize_compare('ea,fb,abcd,gc,hd->efgh') + self.optimize_compare('abcd,ea,fb,gc,hd->efgh') + + def test_complex(self): + # Long test cases + self.optimize_compare('acdf,jbje,gihb,hfac,gfac,gifabc,hfac') + self.optimize_compare('acdf,jbje,gihb,hfac,gfac,gifabc,hfac') + self.optimize_compare('cd,bdhe,aidb,hgca,gc,hgibcd,hgac') + self.optimize_compare('abhe,hidj,jgba,hiab,gab') + self.optimize_compare('bde,cdh,agdb,hica,ibd,hgicd,hiac') + self.optimize_compare('chd,bde,agbc,hiad,hgc,hgi,hiad') + self.optimize_compare('chd,bde,agbc,hiad,bdi,cgh,agdb') + self.optimize_compare('bdhe,acad,hiab,agac,hibd') + + def test_collapse(self): + # Inner products + self.optimize_compare('ab,ab,c->') + self.optimize_compare('ab,ab,c->c') + self.optimize_compare('ab,ab,cd,cd->') + 
self.optimize_compare('ab,ab,cd,cd->ac') + self.optimize_compare('ab,ab,cd,cd->cd') + self.optimize_compare('ab,ab,cd,cd,ef,ef->') + + def test_expand(self): + # Outer products + self.optimize_compare('ab,cd,ef->abcdef') + self.optimize_compare('ab,cd,ef->acdf') + self.optimize_compare('ab,cd,de->abcde') + self.optimize_compare('ab,cd,de->be') + self.optimize_compare('ab,bcd,cd->abcd') + self.optimize_compare('ab,bcd,cd->abd') + + def test_edge_cases(self): + # Difficult edge cases for optimization + self.optimize_compare('eb,cb,fb->cef') + self.optimize_compare('dd,fb,be,cdb->cef') + self.optimize_compare('bca,cdb,dbf,afc->') + self.optimize_compare('dcc,fce,ea,dbf->ab') + self.optimize_compare('fdf,cdd,ccd,afe->ae') + self.optimize_compare('abcd,ad') + self.optimize_compare('ed,fcd,ff,bcf->be') + self.optimize_compare('baa,dcf,af,cde->be') + self.optimize_compare('bd,db,eac->ace') + self.optimize_compare('fff,fae,bef,def->abd') + self.optimize_compare('efc,dbc,acf,fd->abe') + self.optimize_compare('ba,ac,da->bcd') + + def test_inner_product(self): + # Inner products + self.optimize_compare('ab,ab') + self.optimize_compare('ab,ba') + self.optimize_compare('abc,abc') + self.optimize_compare('abc,bac') + self.optimize_compare('abc,cba') + + def test_random_cases(self): + # Randomly built test cases + self.optimize_compare('aab,fa,df,ecc->bde') + self.optimize_compare('ecb,fef,bad,ed->ac') + self.optimize_compare('bcf,bbb,fbf,fc->') + self.optimize_compare('bb,ff,be->e') + self.optimize_compare('bcb,bb,fc,fff->') + self.optimize_compare('fbb,dfd,fc,fc->') + self.optimize_compare('afd,ba,cc,dc->bf') + self.optimize_compare('adb,bc,fa,cfc->d') + self.optimize_compare('bbd,bda,fc,db->acf') + self.optimize_compare('dba,ead,cad->bce') + self.optimize_compare('aef,fbc,dca->bde') + + +class TestEinSumPath(TestCase): + def build_operands(self, string): + + # Builds views based off initial operands + operands = [string] + terms = string.split('->')[0].split(',') + for term in 
terms: + dims = [global_size_dict[x] for x in term] + operands.append(np.random.rand(*dims)) + + return operands + + def assert_path_equal(self, comp, benchmark): + # Checks if list of tuples are equivalent + ret = (len(comp) == len(benchmark)) + assert_(ret) + for pos in range(len(comp) - 1): + ret &= isinstance(comp[pos + 1], tuple) + ret &= (comp[pos + 1] == benchmark[pos + 1]) + assert_(ret) + + def test_memory_contraints(self): + # Ensure memory constraints are satisfied + + outer_test = self.build_operands('a,b,c->abc') + + path, path_str = np.einsum_path(*outer_test, optimize=('greedy', 0)) + self.assert_path_equal(path, ['einsum_path', (0, 1, 2)]) + + path, path_str = np.einsum_path(*outer_test, optimize=('optimal', 0)) + self.assert_path_equal(path, ['einsum_path', (0, 1, 2)]) + + long_test = self.build_operands('acdf,jbje,gihb,hfac') + path, path_str = np.einsum_path(*long_test, optimize=('greedy', 0)) + self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)]) + + path, path_str = np.einsum_path(*long_test, optimize=('optimal', 0)) + self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)]) + + def test_long_paths(self): + # Long complex cases + + # Long test 1 + long_test1 = self.build_operands('acdf,jbje,gihb,hfac,gfac,gifabc,hfac') + path, path_str = np.einsum_path(*long_test1, optimize='greedy') + self.assert_path_equal(path, ['einsum_path', + (1, 4), (2, 4), (1, 4), (1, 3), (1, 2), (0, 1)]) + + path, path_str = np.einsum_path(*long_test1, optimize='optimal') + self.assert_path_equal(path, ['einsum_path', + (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)]) + + # Long test 2 + long_test2 = self.build_operands('chd,bde,agbc,hiad,bdi,cgh,agdb') + path, path_str = np.einsum_path(*long_test2, optimize='greedy') + self.assert_path_equal(path, ['einsum_path', + (3, 4), (0, 3), (3, 4), (1, 3), (1, 2), (0, 1)]) + + path, path_str = np.einsum_path(*long_test2, optimize='optimal') + self.assert_path_equal(path, ['einsum_path', + (0, 5), (1, 4), (3, 4), (1, 
3), (1, 2), (0, 1)]) + + def test_edge_paths(self): + # Difficult edge cases + + # Edge test1 + edge_test1 = self.build_operands('eb,cb,fb->cef') + path, path_str = np.einsum_path(*edge_test1, optimize='greedy') + self.assert_path_equal(path, ['einsum_path', (0, 2), (0, 1)]) + + path, path_str = np.einsum_path(*edge_test1, optimize='optimal') + self.assert_path_equal(path, ['einsum_path', (0, 2), (0, 1)]) + + # Edge test2 + edge_test2 = self.build_operands('dd,fb,be,cdb->cef') + path, path_str = np.einsum_path(*edge_test2, optimize='greedy') + self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 1), (0, 1)]) + + path, path_str = np.einsum_path(*edge_test2, optimize='optimal') + self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 1), (0, 1)]) + + # Edge test3 + edge_test3 = self.build_operands('bca,cdb,dbf,afc->') + path, path_str = np.einsum_path(*edge_test3, optimize='greedy') + self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)]) + + path, path_str = np.einsum_path(*edge_test3, optimize='optimal') + self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)]) + + # Edge test4 + edge_test4 = self.build_operands('dcc,fce,ea,dbf->ab') + path, path_str = np.einsum_path(*edge_test4, optimize='greedy') + self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)]) + + path, path_str = np.einsum_path(*edge_test4, optimize='optimal') + self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)]) + + def test_path_type_input(self): + # Test explicit path handeling + path_test = self.build_operands('dcc,fce,ea,dbf->ab') + + path, path_str = np.einsum_path(*path_test, optimize=False) + self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)]) + + path, path_str = np.einsum_path(*path_test, optimize=True) + self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)]) + + exp_path = ['einsum_path', (0, 2), (0, 2), (0, 1)] + path, path_str = np.einsum_path(*path_test, optimize=exp_path) + 
self.assert_path_equal(path, exp_path) + + # Double check einsum works on the input path + noopt = np.einsum(*path_test, optimize=False) + opt = np.einsum(*path_test, optimize=exp_path) + assert_almost_equal(noopt, opt) + if __name__ == "__main__": run_module_suite() |