Diffstat (limited to 'numpy/core')
-rw-r--r--  numpy/core/__init__.py                          3
-rw-r--r--  numpy/core/einsumfunc.py                      990
-rw-r--r--  numpy/core/numeric.py                           4
-rw-r--r--  numpy/core/src/multiarray/multiarraymodule.c    2
-rw-r--r--  numpy/core/tests/test_einsum.py               866
5 files changed, 1565 insertions, 300 deletions
diff --git a/numpy/core/__init__.py b/numpy/core/__init__.py
index 1ac850002..ca2f45ece 100644
--- a/numpy/core/__init__.py
+++ b/numpy/core/__init__.py
@@ -50,6 +50,8 @@ from . import getlimits
from .getlimits import *
from . import shape_base
from .shape_base import *
+from . import einsumfunc
+from .einsumfunc import *
del nt
from .fromnumeric import amax as max, amin as min, round_ as round
@@ -64,6 +66,7 @@ __all__ += function_base.__all__
__all__ += machar.__all__
__all__ += getlimits.__all__
__all__ += shape_base.__all__
+__all__ += einsumfunc.__all__
from numpy.testing.nosetester import _numpy_tester
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
new file mode 100644
index 000000000..97eb7924f
--- /dev/null
+++ b/numpy/core/einsumfunc.py
@@ -0,0 +1,990 @@
+"""
+Implementation of optimized einsum.
+
+"""
+from __future__ import division, absolute_import, print_function
+
+from numpy.core.multiarray import c_einsum
+from numpy.core.numeric import asarray, asanyarray, result_type
+
+__all__ = ['einsum', 'einsum_path']
+
+einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+einsum_symbols_set = set(einsum_symbols)
+
+
+def _compute_size_by_dict(indices, idx_dict):
+ """
+ Computes the product of the elements in indices based on the dictionary
+ idx_dict.
+
+ Parameters
+ ----------
+ indices : iterable
+ Indices to base the product on.
+ idx_dict : dictionary
+ Dictionary of index sizes
+
+ Returns
+ -------
+ ret : int
+ The resulting product.
+
+ Examples
+ --------
+ >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
+ 90
+
+ """
+ ret = 1
+ for i in indices:
+ ret *= idx_dict[i]
+ return ret
+
+
+def _find_contraction(positions, input_sets, output_set):
+ """
+ Finds the contraction for a given set of input and output sets.
+
+ Parameters
+ ----------
+ positions : iterable
+ Integer positions of terms used in the contraction.
+ input_sets : list
+ List of sets that represent the lhs side of the einsum subscript
+ output_set : set
+ Set that represents the rhs side of the overall einsum subscript
+
+ Returns
+ -------
+ new_result : set
+ The indices of the resulting contraction
+ remaining : list
+ List of sets that have not been contracted; the new set is appended to
+ the end of this list
+ idx_removed : set
+ Indices removed from the entire contraction
+ idx_contraction : set
+ The indices used in the current contraction
+
+ Examples
+ --------
+
+ # A simple dot product test case
+ >>> pos = (0, 1)
+ >>> isets = [set('ab'), set('bc')]
+ >>> oset = set('ac')
+ >>> _find_contraction(pos, isets, oset)
+ ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
+
+ # A more complex case with additional terms in the contraction
+ >>> pos = (0, 2)
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
+ >>> oset = set('ac')
+ >>> _find_contraction(pos, isets, oset)
+ ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
+ """
+
+ idx_contract = set()
+ idx_remain = output_set.copy()
+ remaining = []
+ for ind, value in enumerate(input_sets):
+ if ind in positions:
+ idx_contract |= value
+ else:
+ remaining.append(value)
+ idx_remain |= value
+
+ new_result = idx_remain & idx_contract
+ idx_removed = (idx_contract - new_result)
+ remaining.append(new_result)
+
+ return (new_result, remaining, idx_removed, idx_contract)
+
+
+def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
+ """
+ Computes all possible pair contractions, sieves the results based
+ on ``memory_limit`` and returns the lowest cost path. This algorithm
+ scales factorially with respect to the number of elements in the list ``input_sets``.
+
+ Parameters
+ ----------
+ input_sets : list
+ List of sets that represent the lhs side of the einsum subscript
+ output_set : set
+ Set that represents the rhs side of the overall einsum subscript
+ idx_dict : dictionary
+ Dictionary of index sizes
+ memory_limit : int
+ The maximum number of elements in a temporary array
+
+ Returns
+ -------
+ path : list
+ The optimal contraction order within the memory limit constraint.
+
+ Examples
+ --------
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
+ >>> oset = set('')
+ >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
+ >>> _optimal_path(isets, oset, idx_sizes, 5000)
+ [(0, 2), (0, 1)]
+ """
+
+ full_results = [(0, [], input_sets)]
+ for iteration in range(len(input_sets) - 1):
+ iter_results = []
+
+ # Compute all unique pairs
+ comb_iter = []
+ for x in range(len(input_sets) - iteration):
+ for y in range(x + 1, len(input_sets) - iteration):
+ comb_iter.append((x, y))
+
+ for curr in full_results:
+ cost, positions, remaining = curr
+ for con in comb_iter:
+
+ # Find the contraction
+ cont = _find_contraction(con, remaining, output_set)
+ new_result, new_input_sets, idx_removed, idx_contract = cont
+
+ # Sieve the results based on memory_limit
+ new_size = _compute_size_by_dict(new_result, idx_dict)
+ if new_size > memory_limit:
+ continue
+
+ # Find cost
+ new_cost = _compute_size_by_dict(idx_contract, idx_dict)
+ if idx_removed:
+ new_cost *= 2
+
+ # Build (total_cost, positions, indices_remaining)
+ new_cost += cost
+ new_pos = positions + [con]
+ iter_results.append((new_cost, new_pos, new_input_sets))
+
+ # Update list to iterate over
+ full_results = iter_results
+
+ # If we have not found anything return single einsum contraction
+ if len(full_results) == 0:
+ return [tuple(range(len(input_sets)))]
+
+ path = min(full_results, key=lambda x: x[0])[1]
+ return path
+
+
+def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
+ """
+ Finds the path by contracting the best pair until the input list is
+ exhausted. The best pair is found by minimizing the tuple
+ ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
+ matrix multiplication or inner product operations, then Hadamard-like
+ operations, and finally outer operations. Outer products are limited by
+ ``memory_limit``. This algorithm scales cubically with respect to the
+ number of elements in the list ``input_sets``.
+
+ Parameters
+ ----------
+ input_sets : list
+ List of sets that represent the lhs side of the einsum subscript
+ output_set : set
+ Set that represents the rhs side of the overall einsum subscript
+ idx_dict : dictionary
+ Dictionary of index sizes
+ memory_limit : int
+ The maximum number of elements in a temporary array
+
+ Returns
+ -------
+ path : list
+ The greedy contraction order within the memory limit constraint.
+
+ Examples
+ --------
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
+ >>> oset = set('')
+ >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
+ >>> _greedy_path(isets, oset, idx_sizes, 5000)
+ [(0, 2), (0, 1)]
+ """
+
+ if len(input_sets) == 1:
+ return [(0,)]
+
+ path = []
+ for iteration in range(len(input_sets) - 1):
+ iteration_results = []
+ comb_iter = []
+
+ # Compute all unique pairs
+ for x in range(len(input_sets)):
+ for y in range(x + 1, len(input_sets)):
+ comb_iter.append((x, y))
+
+ for positions in comb_iter:
+
+ # Find the contraction
+ contract = _find_contraction(positions, input_sets, output_set)
+ idx_result, new_input_sets, idx_removed, idx_contract = contract
+
+ # Sieve the results based on memory_limit
+ if _compute_size_by_dict(idx_result, idx_dict) > memory_limit:
+ continue
+
+ # Build sort tuple
+ removed_size = _compute_size_by_dict(idx_removed, idx_dict)
+ cost = _compute_size_by_dict(idx_contract, idx_dict)
+ sort = (-removed_size, cost)
+
+ # Add contraction to possible choices
+ iteration_results.append([sort, positions, new_input_sets])
+
+ # If we did not find a new contraction contract remaining
+ if len(iteration_results) == 0:
+ path.append(tuple(range(len(input_sets))))
+ break
+
+ # Sort based on first index
+ best = min(iteration_results, key=lambda x: x[0])
+ path.append(best[1])
+ input_sets = best[2]
+
+ return path
+
+
+def _parse_einsum_input(operands):
+ """
+ A reproduction of the C-side einsum parsing in Python.
+
+ Returns
+ -------
+ input_strings : str
+ Parsed input strings
+ output_string : str
+ Parsed output string
+ operands : list of array_like
+ The operands to use in the numpy contraction
+
+ Examples
+ --------
+ The operand list is simplified to reduce printing:
+
+ >>> a = np.random.rand(4, 4)
+ >>> b = np.random.rand(4, 4, 4)
+ >>> _parse_einsum_input(('...a,...a->...', a, b))
+ ('za,xza', 'xz', [a, b])
+
+ >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
+ ('za,xza', 'xz', [a, b])
+ """
+
+ if len(operands) == 0:
+ raise ValueError("No input operands")
+
+ if isinstance(operands[0], str):
+ subscripts = operands[0].replace(" ", "")
+ operands = [asanyarray(v) for v in operands[1:]]
+
+ # Ensure all characters are valid
+ for s in subscripts:
+ if s in '.,->':
+ continue
+ if s not in einsum_symbols:
+ raise ValueError("Character %s is not a valid symbol." % s)
+
+ else:
+ tmp_operands = list(operands)
+ operand_list = []
+ subscript_list = []
+ for p in range(len(operands) // 2):
+ operand_list.append(tmp_operands.pop(0))
+ subscript_list.append(tmp_operands.pop(0))
+
+ output_list = tmp_operands[-1] if len(tmp_operands) else None
+ operands = [asanyarray(v) for v in operand_list]
+ subscripts = ""
+ last = len(subscript_list) - 1
+ for num, sub in enumerate(subscript_list):
+ for s in sub:
+ if s is Ellipsis:
+ subscripts += "..."
+ elif isinstance(s, int):
+ subscripts += einsum_symbols[s]
+ else:
+ raise TypeError("For this input type lists must contain "
+ "either int or Ellipsis")
+ if num != last:
+ subscripts += ","
+
+ if output_list is not None:
+ subscripts += "->"
+ for s in output_list:
+ if s is Ellipsis:
+ subscripts += "..."
+ elif isinstance(s, int):
+ subscripts += einsum_symbols[s]
+ else:
+ raise TypeError("For this input type lists must contain "
+ "either int or Ellipsis")
+ # Check for proper "->"
+ if ("-" in subscripts) or (">" in subscripts):
+ invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
+ if invalid or (subscripts.count("->") != 1):
+ raise ValueError("Subscripts can only contain one '->'.")
+
+ # Parse ellipses
+ if "." in subscripts:
+ used = subscripts.replace(".", "").replace(",", "").replace("->", "")
+ unused = list(einsum_symbols_set - set(used))
+ ellipse_inds = "".join(unused)
+ longest = 0
+
+ if "->" in subscripts:
+ input_tmp, output_sub = subscripts.split("->")
+ split_subscripts = input_tmp.split(",")
+ out_sub = True
+ else:
+ split_subscripts = subscripts.split(',')
+ out_sub = False
+
+ for num, sub in enumerate(split_subscripts):
+ if "." in sub:
+ if (sub.count(".") != 3) or (sub.count("...") != 1):
+ raise ValueError("Invalid Ellipses.")
+
+ # Take into account numerical values
+ if operands[num].shape == ():
+ ellipse_count = 0
+ else:
+ ellipse_count = max(len(operands[num].shape), 1)
+ ellipse_count -= (len(sub) - 3)
+
+ if ellipse_count > longest:
+ longest = ellipse_count
+
+ if ellipse_count < 0:
+ raise ValueError("Ellipses lengths do not match.")
+ elif ellipse_count == 0:
+ split_subscripts[num] = sub.replace('...', '')
+ else:
+ rep_inds = ellipse_inds[-ellipse_count:]
+ split_subscripts[num] = sub.replace('...', rep_inds)
+
+ subscripts = ",".join(split_subscripts)
+ if longest == 0:
+ out_ellipse = ""
+ else:
+ out_ellipse = ellipse_inds[-longest:]
+
+ if out_sub:
+ subscripts += "->" + output_sub.replace("...", out_ellipse)
+ else:
+ # Special care for ellipses when no output is specified
+ output_subscript = ""
+ tmp_subscripts = subscripts.replace(",", "")
+ for s in sorted(set(tmp_subscripts)):
+ if s not in (einsum_symbols):
+ raise ValueError("Character %s is not a valid symbol." % s)
+ if tmp_subscripts.count(s) == 1:
+ output_subscript += s
+ normal_inds = ''.join(sorted(set(output_subscript) -
+ set(out_ellipse)))
+
+ subscripts += "->" + out_ellipse + normal_inds
+
+ # Build output string if it does not exist
+ if "->" in subscripts:
+ input_subscripts, output_subscript = subscripts.split("->")
+ else:
+ input_subscripts = subscripts
+ # Build output subscripts
+ tmp_subscripts = subscripts.replace(",", "")
+ output_subscript = ""
+ for s in sorted(set(tmp_subscripts)):
+ if s not in einsum_symbols:
+ raise ValueError("Character %s is not a valid symbol." % s)
+ if tmp_subscripts.count(s) == 1:
+ output_subscript += s
+
+ # Make sure output subscripts are in the input
+ for char in output_subscript:
+ if char not in input_subscripts:
+ raise ValueError("Output character %s did not appear in the input"
+ % char)
+
+ # Make sure the number of operands matches the number of terms
+ if len(input_subscripts.split(',')) != len(operands):
+ raise ValueError("Number of einsum subscripts must be equal to the "
+ "number of operands.")
+
+ return (input_subscripts, output_subscript, operands)
+
+
+def einsum_path(*operands, **kwargs):
+ """
+ einsum_path(subscripts, *operands, optimize='greedy')
+
+ Evaluates the lowest cost contraction order for an einsum expression by
+ considering the creation of intermediate arrays.
+
+ Parameters
+ ----------
+ subscripts : str
+ Specifies the subscripts for summation.
+ *operands : list of array_like
+ These are the arrays for the operation.
+ optimize : {bool, list, tuple, 'greedy', 'optimal'}
+ Choose the type of path. If a tuple is provided, the second argument is
+ assumed to be the maximum intermediate size created. If only a single
+ argument is provided, the largest input or output array size is used
+ as a maximum intermediate size.
+
+ * if a list is given that starts with ``einsum_path``, uses this as the
+ contraction path
+ * if False, no optimization is performed
+ * if True, defaults to the 'greedy' algorithm
+ * 'optimal' An algorithm that combinatorially explores all possible
+ ways of contracting the listed tensors and chooses the least costly
+ path. Scales exponentially with the number of terms in the
+ contraction.
+ * 'greedy' An algorithm that chooses the best pair contraction
+ at each step. Effectively, this algorithm searches the largest inner,
+ Hadamard, and then outer products at each step. Scales cubically with
+ the number of terms in the contraction. Equivalent to the 'optimal'
+ path for most contractions.
+
+ Default is 'greedy'.
+
+ Returns
+ -------
+ path : list of tuples
+ A list representation of the einsum path.
+ string_repr : str
+ A printable representation of the einsum path.
+
+ Notes
+ -----
+ The resulting path indicates which terms of the input contraction should be
+ contracted first; the result of this contraction is then appended to the
+ end of the contraction list. This list can then be iterated over until all
+ intermediate contractions are complete.
+
+ See Also
+ --------
+ einsum, linalg.multi_dot
+
+ Examples
+ --------
+
+ We can begin with a chain dot example. In this case, it is optimal to
+ contract the ``b`` and ``c`` tensors first as represented by the first
+ element of the path ``(1, 2)``. The resulting tensor is added to the end
+ of the contraction and the remaining contraction ``(0, 1)`` is then
+ completed.
+
+ >>> a = np.random.rand(2, 2)
+ >>> b = np.random.rand(2, 5)
+ >>> c = np.random.rand(5, 2)
+ >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
+ >>> print(path_info[0])
+ ['einsum_path', (1, 2), (0, 1)]
+ >>> print(path_info[1])
+ Complete contraction: ij,jk,kl->il
+ Naive scaling: 4
+ Optimized scaling: 3
+ Naive FLOP count: 1.600e+02
+ Optimized FLOP count: 5.600e+01
+ Theoretical speedup: 2.857
+ Largest intermediate: 4.000e+00 elements
+ -------------------------------------------------------------------------
+ scaling current remaining
+ -------------------------------------------------------------------------
+ 3 kl,jk->jl ij,jl->il
+ 3 jl,ij->il il->il
+
+
+ A more complex index transformation example.
+
+ >>> I = np.random.rand(10, 10, 10, 10)
+ >>> C = np.random.rand(10, 10)
+ >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
+ optimize='greedy')
+
+ >>> print(path_info[0])
+ ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
+ >>> print(path_info[1])
+ Complete contraction: ea,fb,abcd,gc,hd->efgh
+ Naive scaling: 8
+ Optimized scaling: 5
+ Naive FLOP count: 8.000e+08
+ Optimized FLOP count: 8.000e+05
+ Theoretical speedup: 1000.000
+ Largest intermediate: 1.000e+04 elements
+ --------------------------------------------------------------------------
+ scaling current remaining
+ --------------------------------------------------------------------------
+ 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
+ 5 bcde,fb->cdef gc,hd,cdef->efgh
+ 5 cdef,gc->defg hd,defg->efgh
+ 5 defg,hd->efgh efgh->efgh
+ """
+
+ # Make sure all keywords are valid
+ valid_contract_kwargs = ['optimize', 'einsum_call']
+ unknown_kwargs = [k for (k, v) in kwargs.items() if k
+ not in valid_contract_kwargs]
+ if len(unknown_kwargs):
+ raise TypeError("Did not understand the following kwargs:"
+ " %s" % unknown_kwargs)
+
+ # Figure out what the path really is
+ path_type = kwargs.pop('optimize', False)
+ if path_type is True:
+ path_type = 'greedy'
+ if path_type is None:
+ path_type = False
+
+ memory_limit = None
+
+ # No optimization or a named path algorithm
+ if (path_type is False) or isinstance(path_type, str):
+ pass
+
+ # Given an explicit path
+ elif len(path_type) and (path_type[0] == 'einsum_path'):
+ pass
+
+ # Path tuple with memory limit
+ elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
+ isinstance(path_type[1], (int, float))):
+ memory_limit = int(path_type[1])
+ path_type = path_type[0]
+
+ else:
+ raise TypeError("Did not understand the path: %s" % str(path_type))
+
+ # Hidden option, only einsum should call this
+ einsum_call_arg = kwargs.pop("einsum_call", False)
+
+ # Python side parsing
+ input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
+ subscripts = input_subscripts + '->' + output_subscript
+
+ # Build a few useful lists and sets
+ input_list = input_subscripts.split(',')
+ input_sets = [set(x) for x in input_list]
+ output_set = set(output_subscript)
+ indices = set(input_subscripts.replace(',', ''))
+
+ # Get length of each unique dimension and ensure all dimensions are correct
+ dimension_dict = {}
+ for tnum, term in enumerate(input_list):
+ sh = operands[tnum].shape
+ if len(sh) != len(term):
+ raise ValueError("Einstein sum subscript %s does not contain the "
+ "correct number of indices for operand %d.",
+ input_subscripts[tnum], tnum)
+ for cnum, char in enumerate(term):
+ dim = sh[cnum]
+ if char in dimension_dict.keys():
+ if dimension_dict[char] != dim:
+ raise ValueError("Size of label '%s' for operand %d does "
+ "not match previous terms.", char, tnum)
+ else:
+ dimension_dict[char] = dim
+
+ # Compute size of each input array plus the output array
+ size_list = []
+ for term in input_list + [output_subscript]:
+ size_list.append(_compute_size_by_dict(term, dimension_dict))
+ max_size = max(size_list)
+
+ if memory_limit is None:
+ memory_arg = max_size
+ else:
+ memory_arg = memory_limit
+
+ # Compute naive cost
+ # This isn't quite right; need to look into exactly how einsum does this
+ naive_cost = _compute_size_by_dict(indices, dimension_dict)
+ indices_in_input = input_subscripts.replace(',', '')
+ mult = max(len(input_list) - 1, 1)
+ if (len(indices_in_input) - len(set(indices_in_input))):
+ mult *= 2
+ naive_cost *= mult
+
+ # Compute the path
+ if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
+ # Nothing to be optimized, leave it to einsum
+ path = [tuple(range(len(input_list)))]
+ elif path_type == "greedy":
+ # Maximum memory is capped at the largest input or output array size for this algorithm
+ memory_arg = min(memory_arg, max_size)
+ path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
+ elif path_type == "optimal":
+ path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
+ elif path_type[0] == 'einsum_path':
+ path = path_type[1:]
+ else:
+ raise KeyError("Path name %s not found" % path_type)
+
+ cost_list, scale_list, size_list, contraction_list = [], [], [], []
+
+ # Build contraction tuple (positions, gemm, einsum_str, remaining)
+ for cnum, contract_inds in enumerate(path):
+ # Make sure we remove inds from right to left
+ contract_inds = tuple(sorted(list(contract_inds), reverse=True))
+
+ contract = _find_contraction(contract_inds, input_sets, output_set)
+ out_inds, input_sets, idx_removed, idx_contract = contract
+
+ cost = _compute_size_by_dict(idx_contract, dimension_dict)
+ if idx_removed:
+ cost *= 2
+ cost_list.append(cost)
+ scale_list.append(len(idx_contract))
+ size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
+
+ tmp_inputs = []
+ for x in contract_inds:
+ tmp_inputs.append(input_list.pop(x))
+
+ # Last contraction
+ if (cnum - len(path)) == -1:
+ idx_result = output_subscript
+ else:
+ sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
+ idx_result = "".join([x[1] for x in sorted(sort_result)])
+
+ input_list.append(idx_result)
+ einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+ contraction = (contract_inds, idx_removed, einsum_str, input_list[:])
+ contraction_list.append(contraction)
+
+ opt_cost = sum(cost_list) + 1
+
+ if einsum_call_arg:
+ return (operands, contraction_list)
+
+ # Return the path along with a nice string representation
+ overall_contraction = input_subscripts + "->" + output_subscript
+ header = ("scaling", "current", "remaining")
+
+ speedup = naive_cost / opt_cost
+ max_i = max(size_list)
+
+ path_print = " Complete contraction: %s\n" % overall_contraction
+ path_print += " Naive scaling: %d\n" % len(indices)
+ path_print += " Optimized scaling: %d\n" % max(scale_list)
+ path_print += " Naive FLOP count: %.3e\n" % naive_cost
+ path_print += " Optimized FLOP count: %.3e\n" % opt_cost
+ path_print += " Theoretical speedup: %3.3f\n" % speedup
+ path_print += " Largest intermediate: %.3e elements\n" % max_i
+ path_print += "-" * 74 + "\n"
+ path_print += "%6s %24s %40s\n" % header
+ path_print += "-" * 74
+
+ for n, contraction in enumerate(contraction_list):
+ inds, idx_rm, einsum_str, remaining = contraction
+ remaining_str = ",".join(remaining) + "->" + output_subscript
+ path_run = (scale_list[n], einsum_str, remaining_str)
+ path_print += "\n%4d %24s %40s" % path_run
+
+ path = ['einsum_path'] + path
+ return (path, path_print)
+
+
+# Rewrite einsum to handle different cases
+def einsum(*operands, **kwargs):
+ """
+ einsum(subscripts, *operands, out=None, dtype=None, order='K',
+ casting='safe', optimize=False)
+
+ Evaluates the Einstein summation convention on the operands.
+
+ Using the Einstein summation convention, many common multi-dimensional
+ array operations can be represented in a simple fashion. This function
+ provides a way to compute such summations. The best way to understand this
+ function is to try the examples below, which show how many common NumPy
+ functions can be implemented as calls to `einsum`.
+
+ Parameters
+ ----------
+ subscripts : str
+ Specifies the subscripts for summation.
+ operands : list of array_like
+ These are the arrays for the operation.
+ out : {ndarray, None}, optional
+ If provided, the calculation is done into this array.
+ dtype : {data-type, None}, optional
+ If provided, forces the calculation to use the data type specified.
+ Note that you may have to also give a more liberal `casting`
+ parameter to allow the conversions. Default is None.
+ order : {'C', 'F', 'A', 'K'}, optional
+ Controls the memory layout of the output. 'C' means it should
+ be C contiguous. 'F' means it should be Fortran contiguous,
+ 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
+ 'K' means it should be as close to the layout of the inputs as
+ is possible, including arbitrarily permuted axes.
+ Default is 'K'.
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+ Controls what kind of data casting may occur. Setting this to
+ 'unsafe' is not recommended, as it can adversely affect accumulations.
+
+ * 'no' means the data types should not be cast at all.
+ * 'equiv' means only byte-order changes are allowed.
+ * 'safe' means only casts which can preserve values are allowed.
+ * 'same_kind' means only safe casts or casts within a kind,
+ like float64 to float32, are allowed.
+ * 'unsafe' means any data conversions may be done.
+
+ Default is 'safe'.
+ optimize : {False, True, 'greedy', 'optimal'}, optional
+ Controls if intermediate optimization should occur. If False, no
+ optimization is performed; if True, the 'greedy' algorithm is used.
+ Also accepts an explicit contraction list from the ``np.einsum_path``
+ function. See ``np.einsum_path`` for more details. Default is False.
+
+ Returns
+ -------
+ output : ndarray
+ The calculation based on the Einstein summation convention.
+
+ See Also
+ --------
+ einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
+
+ Notes
+ -----
+ .. versionadded:: 1.6.0
+
+ The subscripts string is a comma-separated list of subscript labels,
+ where each label refers to a dimension of the corresponding operand.
+ Repeated subscript labels in one operand take the diagonal. For example,
+ ``np.einsum('ii', a)`` is equivalent to ``np.trace(a)``.
+
+ Whenever a label is repeated, it is summed, so ``np.einsum('i,i', a, b)``
+ is equivalent to ``np.inner(a,b)``. If a label appears only once,
+ it is not summed, so ``np.einsum('i', a)`` produces a view of ``a``
+ with no changes.
+
+ The order of labels in the output is by default alphabetical. This
+ means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
+ ``np.einsum('ji', a)`` takes its transpose.
+
+ The output can be controlled by specifying output subscript labels
+ as well. This specifies the label order, and allows summing to
+ be disallowed or forced when desired. The call ``np.einsum('i->', a)``
+ is like ``np.sum(a, axis=-1)``, and ``np.einsum('ii->i', a)``
+ is like ``np.diag(a)``. The difference is that `einsum` does not
+ allow broadcasting by default.
+
+ To enable and control broadcasting, use an ellipsis. Default
+ NumPy-style broadcasting is done by adding an ellipsis
+ to the left of each term, like ``np.einsum('...ii->...i', a)``.
+ To take the trace along the first and last axes,
+ you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
+ product with the left-most indices instead of rightmost, you can do
+ ``np.einsum('ij...,jk...->ik...', a, b)``.
+
+ When there is only one operand, no axes are summed, and no output
+ parameter is provided, a view into the operand is returned instead
+ of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
+ produces a view.
+
+ An alternative way to provide the subscripts and operands is as
+ ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``. The examples
+ below have corresponding `einsum` calls with the two parameter methods.
+
+ .. versionadded:: 1.10.0
+
+ Views returned from einsum are now writeable whenever the input array
+ is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
+ have the same effect as ``np.swapaxes(a, 0, 2)`` and
+ ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
+ of a 2D array.
+
+ .. versionadded:: 1.12.0
+
+ Added the ``optimize`` argument which will optimize the contraction order
+ of an einsum expression. For a contraction with three or more operands this
+ can greatly increase the computational efficiency at the cost of a larger
+ memory footprint during computation.
+
+ See ``np.einsum_path`` for more details.
+
+ Examples
+ --------
+ >>> a = np.arange(25).reshape(5,5)
+ >>> b = np.arange(5)
+ >>> c = np.arange(6).reshape(2,3)
+
+ >>> np.einsum('ii', a)
+ 60
+ >>> np.einsum(a, [0,0])
+ 60
+ >>> np.trace(a)
+ 60
+
+ >>> np.einsum('ii->i', a)
+ array([ 0, 6, 12, 18, 24])
+ >>> np.einsum(a, [0,0], [0])
+ array([ 0, 6, 12, 18, 24])
+ >>> np.diag(a)
+ array([ 0, 6, 12, 18, 24])
+
+ >>> np.einsum('ij,j', a, b)
+ array([ 30, 80, 130, 180, 230])
+ >>> np.einsum(a, [0,1], b, [1])
+ array([ 30, 80, 130, 180, 230])
+ >>> np.dot(a, b)
+ array([ 30, 80, 130, 180, 230])
+ >>> np.einsum('...j,j', a, b)
+ array([ 30, 80, 130, 180, 230])
+
+ >>> np.einsum('ji', c)
+ array([[0, 3],
+ [1, 4],
+ [2, 5]])
+ >>> np.einsum(c, [1,0])
+ array([[0, 3],
+ [1, 4],
+ [2, 5]])
+ >>> c.T
+ array([[0, 3],
+ [1, 4],
+ [2, 5]])
+
+ >>> np.einsum('..., ...', 3, c)
+ array([[ 0, 3, 6],
+ [ 9, 12, 15]])
+ >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
+ array([[ 0, 3, 6],
+ [ 9, 12, 15]])
+ >>> np.multiply(3, c)
+ array([[ 0, 3, 6],
+ [ 9, 12, 15]])
+
+ >>> np.einsum('i,i', b, b)
+ 30
+ >>> np.einsum(b, [0], b, [0])
+ 30
+ >>> np.inner(b,b)
+ 30
+
+ >>> np.einsum('i,j', np.arange(2)+1, b)
+ array([[0, 1, 2, 3, 4],
+ [0, 2, 4, 6, 8]])
+ >>> np.einsum(np.arange(2)+1, [0], b, [1])
+ array([[0, 1, 2, 3, 4],
+ [0, 2, 4, 6, 8]])
+ >>> np.outer(np.arange(2)+1, b)
+ array([[0, 1, 2, 3, 4],
+ [0, 2, 4, 6, 8]])
+
+ >>> np.einsum('i...->...', a)
+ array([50, 55, 60, 65, 70])
+ >>> np.einsum(a, [0,Ellipsis], [Ellipsis])
+ array([50, 55, 60, 65, 70])
+ >>> np.sum(a, axis=0)
+ array([50, 55, 60, 65, 70])
+
+ >>> a = np.arange(60.).reshape(3,4,5)
+ >>> b = np.arange(24.).reshape(4,3,2)
+ >>> np.einsum('ijk,jil->kl', a, b)
+ array([[ 4400., 4730.],
+ [ 4532., 4874.],
+ [ 4664., 5018.],
+ [ 4796., 5162.],
+ [ 4928., 5306.]])
+ >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
+ array([[ 4400., 4730.],
+ [ 4532., 4874.],
+ [ 4664., 5018.],
+ [ 4796., 5162.],
+ [ 4928., 5306.]])
+ >>> np.tensordot(a,b, axes=([1,0],[0,1]))
+ array([[ 4400., 4730.],
+ [ 4532., 4874.],
+ [ 4664., 5018.],
+ [ 4796., 5162.],
+ [ 4928., 5306.]])
+
+ >>> a = np.arange(6).reshape((3,2))
+ >>> b = np.arange(12).reshape((4,3))
+ >>> np.einsum('ki,jk->ij', a, b)
+ array([[10, 28, 46, 64],
+ [13, 40, 67, 94]])
+ >>> np.einsum('ki,...k->i...', a, b)
+ array([[10, 28, 46, 64],
+ [13, 40, 67, 94]])
+ >>> np.einsum('k...,jk', a, b)
+ array([[10, 28, 46, 64],
+ [13, 40, 67, 94]])
+
+ >>> # since version 1.10.0
+ >>> a = np.zeros((3, 3))
+ >>> np.einsum('ii->i', a)[:] = 1
+ >>> a
+ array([[ 1., 0., 0.],
+ [ 0., 1., 0.],
+ [ 0., 0., 1.]])
+
+ """
+
+ # Grab non-einsum kwargs
+ optimize_arg = kwargs.pop('optimize', False)
+
+ # If no optimization, run pure einsum
+ if optimize_arg is False:
+ return c_einsum(*operands, **kwargs)
+
+ valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
+ einsum_kwargs = {k: v for (k, v) in kwargs.items() if
+ k in valid_einsum_kwargs}
+
+ # Make sure all keywords are valid
+ valid_contract_kwargs = ['optimize'] + valid_einsum_kwargs
+ unknown_kwargs = [k for (k, v) in kwargs.items() if
+ k not in valid_contract_kwargs]
+
+ if len(unknown_kwargs):
+ raise TypeError("Did not understand the following kwargs: %s"
+ % unknown_kwargs)
+
+ # Special handling if out is specified
+ specified_out = False
+ out_array = einsum_kwargs.pop('out', None)
+ if out_array is not None:
+ specified_out = True
+
+ # Build the contraction list and operands
+ operands, contraction_list = einsum_path(*operands, optimize=optimize_arg,
+ einsum_call=True)
+ # Start contraction loop
+ for num, contraction in enumerate(contraction_list):
+ inds, idx_rm, einsum_str, remaining = contraction
+ tmp_operands = []
+ for x in inds:
+ tmp_operands.append(operands.pop(x))
+
+ # If out was specified
+ if specified_out and ((num + 1) == len(contraction_list)):
+ einsum_kwargs["out"] = out_array
+
+ # Do the contraction
+ new_view = c_einsum(einsum_str, *tmp_operands, **einsum_kwargs)
+
+ # Append new items and dereference what we can
+ operands.append(new_view)
+ del tmp_operands, new_view
+
+ if specified_out:
+ return out_array
+ else:
+ return operands[0]
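The new Python-level ``einsum`` above can be exercised end to end with the chain-dot case from the ``einsum_path`` docstring. The following doctest-style sketch (illustrative only, not part of the patch; the shapes are arbitrary) feeds the precomputed path back through ``optimize`` and checks that the optimized result matches the plain C path:

>>> import numpy as np
>>> a = np.random.rand(2, 2)
>>> b = np.random.rand(2, 5)
>>> c = np.random.rand(5, 2)
>>> # Greedy search contracts b and c first, then folds in a.
>>> path, descr = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
>>> path
['einsum_path', (1, 2), (0, 1)]
>>> # The explicit path list is accepted directly by einsum's optimize kwarg.
>>> ref = np.einsum('ij,jk,kl->il', a, b, c, optimize=False)
>>> np.allclose(np.einsum('ij,jk,kl->il', a, b, c, optimize=path), ref)
True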
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 81d9d3697..82f7b081f 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -12,7 +12,7 @@ from .multiarray import (
_fastCopyAndTranspose as fastCopyAndTranspose, ALLOW_THREADS,
BUFSIZE, CLIP, MAXDIMS, MAY_SHARE_BOUNDS, MAY_SHARE_EXACT, RAISE,
WRAP, arange, array, broadcast, can_cast, compare_chararrays,
- concatenate, copyto, count_nonzero, dot, dtype, einsum, empty,
+ concatenate, copyto, count_nonzero, dot, dtype, empty,
empty_like, flatiter, frombuffer, fromfile, fromiter, fromstring,
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
@@ -53,7 +53,7 @@ __all__ = [
'min_scalar_type', 'result_type', 'asarray', 'asanyarray',
'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like',
'zeros_like', 'ones_like', 'correlate', 'convolve', 'inner', 'dot',
- 'einsum', 'outer', 'vdot', 'alterdot', 'restoredot', 'roll',
+ 'outer', 'vdot', 'alterdot', 'restoredot', 'roll',
'rollaxis', 'moveaxis', 'cross', 'tensordot', 'array2string',
'get_printoptions', 'set_printoptions', 'array_repr', 'array_str',
'set_string_function', 'little_endian', 'require', 'fromiter',
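Since ``einsum`` has moved out of ``numpy.core.numeric``, it is now re-exported through ``einsumfunc.__all__`` (see the ``__init__.py`` hunk above). A quick interactive check of that wiring, assuming a build of this branch:

>>> import numpy as np
>>> np.einsum is np.core.einsumfunc.einsum
True
>>> np.einsum_path is np.core.einsumfunc.einsum_path
True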
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 7c3c95b24..fb646b336 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4139,7 +4139,7 @@ static struct PyMethodDef array_module_methods[] = {
{"matmul",
(PyCFunction)array_matmul,
METH_VARARGS | METH_KEYWORDS, NULL},
- {"einsum",
+ {"c_einsum",
(PyCFunction)array_einsum,
METH_VARARGS|METH_KEYWORDS, NULL},
{"_fastCopyAndTranspose",
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index c31d281e9..3ecc829f4 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -3,282 +3,308 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_equal, assert_array_equal,
- assert_raises, suppress_warnings
+ assert_almost_equal, assert_raises, suppress_warnings
)
+# Setup for the optimized einsum tests
+chars = 'abcdefghij'
+sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3])
+global_size_dict = {}
+for size, char in zip(sizes, chars):
+ global_size_dict[char] = size
+
class TestEinSum(TestCase):
def test_einsum_errors(self):
- # Need enough arguments
- assert_raises(ValueError, np.einsum)
- assert_raises(ValueError, np.einsum, "")
-
- # subscripts must be a string
- assert_raises(TypeError, np.einsum, 0, 0)
-
- # out parameter must be an array
- assert_raises(TypeError, np.einsum, "", 0, out='test')
-
- # order parameter must be a valid order
- assert_raises(TypeError, np.einsum, "", 0, order='W')
-
- # casting parameter must be a valid casting
- assert_raises(ValueError, np.einsum, "", 0, casting='blah')
-
- # dtype parameter must be a valid dtype
- assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type')
-
- # other keyword arguments are rejected
- assert_raises(TypeError, np.einsum, "", 0, bad_arg=0)
-
- # issue 4528 revealed a segfault with this call
- assert_raises(TypeError, np.einsum, *(None,)*63)
-
- # number of operands must match count in subscripts string
- assert_raises(ValueError, np.einsum, "", 0, 0)
- assert_raises(ValueError, np.einsum, ",", 0, [0], [0])
- assert_raises(ValueError, np.einsum, ",", [0])
-
- # can't have more subscripts than dimensions in the operand
- assert_raises(ValueError, np.einsum, "i", 0)
- assert_raises(ValueError, np.einsum, "ij", [0, 0])
- assert_raises(ValueError, np.einsum, "...i", 0)
- assert_raises(ValueError, np.einsum, "i...j", [0, 0])
- assert_raises(ValueError, np.einsum, "i...", 0)
- assert_raises(ValueError, np.einsum, "ij...", [0, 0])
-
- # invalid ellipsis
- assert_raises(ValueError, np.einsum, "i..", [0, 0])
- assert_raises(ValueError, np.einsum, ".i...", [0, 0])
- assert_raises(ValueError, np.einsum, "j->..j", [0, 0])
- assert_raises(ValueError, np.einsum, "j->.j...", [0, 0])
-
- # invalid subscript character
- assert_raises(ValueError, np.einsum, "i%...", [0, 0])
- assert_raises(ValueError, np.einsum, "...j$", [0, 0])
- assert_raises(ValueError, np.einsum, "i->&", [0, 0])
-
- # output subscripts must appear in input
- assert_raises(ValueError, np.einsum, "i->ij", [0, 0])
-
- # output subscripts may only be specified once
- assert_raises(ValueError, np.einsum, "ij->jij", [[0, 0], [0, 0]])
-
- # dimensions much match when being collapsed
- assert_raises(ValueError, np.einsum, "ii", np.arange(6).reshape(2, 3))
- assert_raises(ValueError, np.einsum, "ii->i", np.arange(6).reshape(2, 3))
-
- # broadcasting to new dimensions must be enabled explicitly
- assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2, 3))
- assert_raises(ValueError, np.einsum, "i->i", [[0, 1], [0, 1]],
- out=np.arange(4).reshape(2, 2))
+ for do_opt in [True, False]:
+ # Need enough arguments
+ assert_raises(ValueError, np.einsum, optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "", optimize=do_opt)
+
+ # subscripts must be a string
+ assert_raises(TypeError, np.einsum, 0, 0, optimize=do_opt)
+
+ # out parameter must be an array
+ assert_raises(TypeError, np.einsum, "", 0, out='test',
+ optimize=do_opt)
+
+ # order parameter must be a valid order
+ assert_raises(TypeError, np.einsum, "", 0, order='W',
+ optimize=do_opt)
+
+ # casting parameter must be a valid casting
+ assert_raises(ValueError, np.einsum, "", 0, casting='blah',
+ optimize=do_opt)
+
+ # dtype parameter must be a valid dtype
+ assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type',
+ optimize=do_opt)
+
+ # other keyword arguments are rejected
+ assert_raises(TypeError, np.einsum, "", 0, bad_arg=0,
+ optimize=do_opt)
+
+ # issue 4528 revealed a segfault with this call
+ assert_raises(TypeError, np.einsum, *(None,)*63, optimize=do_opt)
+
+ # number of operands must match count in subscripts string
+ assert_raises(ValueError, np.einsum, "", 0, 0, optimize=do_opt)
+ assert_raises(ValueError, np.einsum, ",", 0, [0], [0],
+ optimize=do_opt)
+ assert_raises(ValueError, np.einsum, ",", [0], optimize=do_opt)
+
+ # can't have more subscripts than dimensions in the operand
+ assert_raises(ValueError, np.einsum, "i", 0, optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "ij", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "...i", 0, optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "i...j", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "i...", 0, optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "ij...", [0, 0], optimize=do_opt)
+
+ # invalid ellipsis
+ assert_raises(ValueError, np.einsum, "i..", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, ".i...", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "j->..j", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "j->.j...", [0, 0], optimize=do_opt)
+
+ # invalid subscript character
+ assert_raises(ValueError, np.einsum, "i%...", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "...j$", [0, 0], optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "i->&", [0, 0], optimize=do_opt)
+
+ # output subscripts must appear in input
+ assert_raises(ValueError, np.einsum, "i->ij", [0, 0], optimize=do_opt)
+
+ # output subscripts may only be specified once
+ assert_raises(ValueError, np.einsum, "ij->jij", [[0, 0], [0, 0]],
+ optimize=do_opt)
+
+ # dimensions must match when being collapsed
+ assert_raises(ValueError, np.einsum, "ii",
+ np.arange(6).reshape(2, 3), optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "ii->i",
+ np.arange(6).reshape(2, 3), optimize=do_opt)
+
+ # broadcasting to new dimensions must be enabled explicitly
+ assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2, 3),
+ optimize=do_opt)
+ assert_raises(ValueError, np.einsum, "i->i", [[0, 1], [0, 1]],
+ out=np.arange(4).reshape(2, 2), optimize=do_opt)
def test_einsum_views(self):
# pass-through
- a = np.arange(6)
- a.shape = (2, 3)
+ for do_opt in [True, False]:
+ a = np.arange(6)
+ a.shape = (2, 3)
+
+ b = np.einsum("...", a, optimize=do_opt)
+ assert_(b.base is a)
+
+ b = np.einsum(a, [Ellipsis], optimize=do_opt)
+ assert_(b.base is a)
+
+ b = np.einsum("ij", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, a)
+
+ b = np.einsum(a, [0, 1], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, a)
+
+ # output is writeable whenever input is writeable
+ b = np.einsum("...", a, optimize=do_opt)
+ assert_(b.flags['WRITEABLE'])
+ a.flags['WRITEABLE'] = False
+ b = np.einsum("...", a, optimize=do_opt)
+ assert_(not b.flags['WRITEABLE'])
+
+ # transpose
+ a = np.arange(6)
+ a.shape = (2, 3)
+
+ b = np.einsum("ji", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, a.T)
+
+ b = np.einsum(a, [1, 0], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, a.T)
+
+ # diagonal
+ a = np.arange(9)
+ a.shape = (3, 3)
+
+ b = np.einsum("ii->i", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[i, i] for i in range(3)])
+
+ b = np.einsum(a, [0, 0], [0], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[i, i] for i in range(3)])
+
+ # diagonal with various ways of broadcasting an additional dimension
+ a = np.arange(27)
+ a.shape = (3, 3, 3)
+
+ b = np.einsum("...ii->...i", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [[x[i, i] for i in range(3)] for x in a])
+
+ b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [[x[i, i] for i in range(3)] for x in a])
+
+ b = np.einsum("ii...->...i", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [[x[i, i] for i in range(3)]
+ for x in a.transpose(2, 0, 1)])
+
+ b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [[x[i, i] for i in range(3)]
+ for x in a.transpose(2, 0, 1)])
+
+ b = np.einsum("...ii->i...", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[:, i, i] for i in range(3)])
+
+ b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[:, i, i] for i in range(3)])
+
+ b = np.einsum("jii->ij", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[:, i, i] for i in range(3)])
+
+ b = np.einsum(a, [1, 0, 0], [0, 1], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[:, i, i] for i in range(3)])
+
+ b = np.einsum("ii...->i...", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])
+
+ b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])
+
+ b = np.einsum("i...i->i...", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])
+
+ b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])
+
+ b = np.einsum("i...i->...i", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [[x[i, i] for i in range(3)]
+ for x in a.transpose(1, 0, 2)])
+
+ b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [[x[i, i] for i in range(3)]
+ for x in a.transpose(1, 0, 2)])
+
+ # triple diagonal
+ a = np.arange(27)
+ a.shape = (3, 3, 3)
+
+ b = np.einsum("iii->i", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[i, i, i] for i in range(3)])
+
+ b = np.einsum(a, [0, 0, 0], [0], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, [a[i, i, i] for i in range(3)])
- b = np.einsum("...", a)
- assert_(b.base is a)
-
- b = np.einsum(a, [Ellipsis])
- assert_(b.base is a)
-
- b = np.einsum("ij", a)
- assert_(b.base is a)
- assert_equal(b, a)
-
- b = np.einsum(a, [0, 1])
- assert_(b.base is a)
- assert_equal(b, a)
-
- # output is writeable whenever input is writeable
- b = np.einsum("...", a)
- assert_(b.flags['WRITEABLE'])
- a.flags['WRITEABLE'] = False
- b = np.einsum("...", a)
- assert_(not b.flags['WRITEABLE'])
-
- # transpose
- a = np.arange(6)
- a.shape = (2, 3)
-
- b = np.einsum("ji", a)
- assert_(b.base is a)
- assert_equal(b, a.T)
-
- b = np.einsum(a, [1, 0])
- assert_(b.base is a)
- assert_equal(b, a.T)
-
- # diagonal
- a = np.arange(9)
- a.shape = (3, 3)
-
- b = np.einsum("ii->i", a)
- assert_(b.base is a)
- assert_equal(b, [a[i, i] for i in range(3)])
-
- b = np.einsum(a, [0, 0], [0])
- assert_(b.base is a)
- assert_equal(b, [a[i, i] for i in range(3)])
-
- # diagonal with various ways of broadcasting an additional dimension
- a = np.arange(27)
- a.shape = (3, 3, 3)
-
- b = np.einsum("...ii->...i", a)
- assert_(b.base is a)
- assert_equal(b, [[x[i, i] for i in range(3)] for x in a])
-
- b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0])
- assert_(b.base is a)
- assert_equal(b, [[x[i, i] for i in range(3)] for x in a])
-
- b = np.einsum("ii...->...i", a)
- assert_(b.base is a)
- assert_equal(b, [[x[i, i] for i in range(3)]
- for x in a.transpose(2, 0, 1)])
-
- b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0])
- assert_(b.base is a)
- assert_equal(b, [[x[i, i] for i in range(3)]
- for x in a.transpose(2, 0, 1)])
-
- b = np.einsum("...ii->i...", a)
- assert_(b.base is a)
- assert_equal(b, [a[:, i, i] for i in range(3)])
-
- b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis])
- assert_(b.base is a)
- assert_equal(b, [a[:, i, i] for i in range(3)])
-
- b = np.einsum("jii->ij", a)
- assert_(b.base is a)
- assert_equal(b, [a[:, i, i] for i in range(3)])
-
- b = np.einsum(a, [1, 0, 0], [0, 1])
- assert_(b.base is a)
- assert_equal(b, [a[:, i, i] for i in range(3)])
-
- b = np.einsum("ii...->i...", a)
- assert_(b.base is a)
- assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])
-
- b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis])
- assert_(b.base is a)
- assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])
-
- b = np.einsum("i...i->i...", a)
- assert_(b.base is a)
- assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])
-
- b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis])
- assert_(b.base is a)
- assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])
-
- b = np.einsum("i...i->...i", a)
- assert_(b.base is a)
- assert_equal(b, [[x[i, i] for i in range(3)]
- for x in a.transpose(1, 0, 2)])
-
- b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0])
- assert_(b.base is a)
- assert_equal(b, [[x[i, i] for i in range(3)]
- for x in a.transpose(1, 0, 2)])
-
- # triple diagonal
- a = np.arange(27)
- a.shape = (3, 3, 3)
-
- b = np.einsum("iii->i", a)
- assert_(b.base is a)
- assert_equal(b, [a[i, i, i] for i in range(3)])
-
- b = np.einsum(a, [0, 0, 0], [0])
- assert_(b.base is a)
- assert_equal(b, [a[i, i, i] for i in range(3)])
+ # swap axes
+ a = np.arange(24)
+ a.shape = (2, 3, 4)
- # swap axes
- a = np.arange(24)
- a.shape = (2, 3, 4)
+ b = np.einsum("ijk->jik", a, optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, a.swapaxes(0, 1))
- b = np.einsum("ijk->jik", a)
- assert_(b.base is a)
- assert_equal(b, a.swapaxes(0, 1))
+ b = np.einsum(a, [0, 1, 2], [1, 0, 2], optimize=do_opt)
+ assert_(b.base is a)
+ assert_equal(b, a.swapaxes(0, 1))
- b = np.einsum(a, [0, 1, 2], [1, 0, 2])
- assert_(b.base is a)
- assert_equal(b, a.swapaxes(0, 1))
-
- def check_einsum_sums(self, dtype):
+ def check_einsum_sums(self, dtype, do_opt=False):
# Check various sums. Does many sizes to exercise unrolled loops.
# sum(a, axis=-1)
for n in range(1, 17):
a = np.arange(n, dtype=dtype)
- assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype))
- assert_equal(np.einsum(a, [0], []),
+ assert_equal(np.einsum("i->", a, optimize=do_opt),
+ np.sum(a, axis=-1).astype(dtype))
+ assert_equal(np.einsum(a, [0], [], optimize=do_opt),
np.sum(a, axis=-1).astype(dtype))
for n in range(1, 17):
a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
- assert_equal(np.einsum("...i->...", a),
+ assert_equal(np.einsum("...i->...", a, optimize=do_opt),
np.sum(a, axis=-1).astype(dtype))
- assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis]),
+ assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt),
np.sum(a, axis=-1).astype(dtype))
# sum(a, axis=0)
for n in range(1, 17):
a = np.arange(2*n, dtype=dtype).reshape(2, n)
- assert_equal(np.einsum("i...->...", a),
+ assert_equal(np.einsum("i...->...", a, optimize=do_opt),
np.sum(a, axis=0).astype(dtype))
- assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis]),
+ assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
np.sum(a, axis=0).astype(dtype))
for n in range(1, 17):
a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
- assert_equal(np.einsum("i...->...", a),
+ assert_equal(np.einsum("i...->...", a, optimize=do_opt),
np.sum(a, axis=0).astype(dtype))
- assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis]),
+ assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
np.sum(a, axis=0).astype(dtype))
# trace(a)
for n in range(1, 17):
a = np.arange(n*n, dtype=dtype).reshape(n, n)
- assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype))
- assert_equal(np.einsum(a, [0, 0]), np.trace(a).astype(dtype))
+ assert_equal(np.einsum("ii", a, optimize=do_opt),
+ np.trace(a).astype(dtype))
+ assert_equal(np.einsum(a, [0, 0], optimize=do_opt),
+ np.trace(a).astype(dtype))
# multiply(a, b)
assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case
for n in range(1, 17):
- a = np.arange(3*n, dtype=dtype).reshape(3, n)
- b = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
- assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b))
- assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]),
+ a = np.arange(3 * n, dtype=dtype).reshape(3, n)
+ b = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
+ assert_equal(np.einsum("..., ...", a, b, optimize=do_opt),
+ np.multiply(a, b))
+ assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis], optimize=do_opt),
np.multiply(a, b))
# inner(a,b)
for n in range(1, 17):
- a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
+ a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
b = np.arange(n, dtype=dtype)
- assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b))
- assert_equal(np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0]),
+ assert_equal(np.einsum("...i, ...i", a, b, optimize=do_opt), np.inner(a, b))
+ assert_equal(np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0], optimize=do_opt),
np.inner(a, b))
for n in range(1, 11):
- a = np.arange(n*3*2, dtype=dtype).reshape(n, 3, 2)
+ a = np.arange(n * 3 * 2, dtype=dtype).reshape(n, 3, 2)
b = np.arange(n, dtype=dtype)
- assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T)
- assert_equal(np.einsum(a, [0, Ellipsis], b, [0, Ellipsis]),
+ assert_equal(np.einsum("i..., i...", a, b, optimize=do_opt),
+ np.inner(a.T, b.T).T)
+ assert_equal(np.einsum(a, [0, Ellipsis], b, [0, Ellipsis], optimize=do_opt),
np.inner(a.T, b.T).T)
# outer(a,b)
for n in range(1, 17):
a = np.arange(3, dtype=dtype)+1
b = np.arange(n, dtype=dtype)+1
- assert_equal(np.einsum("i,j", a, b), np.outer(a, b))
- assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b))
+ assert_equal(np.einsum("i,j", a, b, optimize=do_opt),
+ np.outer(a, b))
+ assert_equal(np.einsum(a, [0], b, [1], optimize=do_opt),
+ np.outer(a, b))
# Suppress the complex warnings for the 'as f8' tests
with suppress_warnings() as sup:
@@ -288,62 +314,70 @@ class TestEinSum(TestCase):
for n in range(1, 17):
a = np.arange(4*n, dtype=dtype).reshape(4, n)
b = np.arange(n, dtype=dtype)
- assert_equal(np.einsum("ij, j", a, b), np.dot(a, b))
- assert_equal(np.einsum(a, [0, 1], b, [1]), np.dot(a, b))
+ assert_equal(np.einsum("ij, j", a, b, optimize=do_opt),
+ np.dot(a, b))
+ assert_equal(np.einsum(a, [0, 1], b, [1], optimize=do_opt),
+ np.dot(a, b))
c = np.arange(4, dtype=dtype)
np.einsum("ij,j", a, b, out=c,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c,
- np.dot(a.astype('f8'),
- b.astype('f8')).astype(dtype))
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
c[...] = 0
np.einsum(a, [0, 1], b, [1], out=c,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c,
- np.dot(a.astype('f8'),
- b.astype('f8')).astype(dtype))
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
for n in range(1, 17):
a = np.arange(4*n, dtype=dtype).reshape(4, n)
b = np.arange(n, dtype=dtype)
- assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T))
- assert_equal(np.einsum(a.T, [1, 0], b.T, [1]), np.dot(b.T, a.T))
+ assert_equal(np.einsum("ji,j", a.T, b.T, optimize=do_opt),
+ np.dot(b.T, a.T))
+ assert_equal(np.einsum(a.T, [1, 0], b.T, [1], optimize=do_opt),
+ np.dot(b.T, a.T))
c = np.arange(4, dtype=dtype)
- np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe')
+ np.einsum("ji,j", a.T, b.T, out=c,
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c,
- np.dot(b.T.astype('f8'),
- a.T.astype('f8')).astype(dtype))
+ np.dot(b.T.astype('f8'),
+ a.T.astype('f8')).astype(dtype))
c[...] = 0
np.einsum(a.T, [1, 0], b.T, [1], out=c,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c,
- np.dot(b.T.astype('f8'),
- a.T.astype('f8')).astype(dtype))
+ np.dot(b.T.astype('f8'),
+ a.T.astype('f8')).astype(dtype))
# matmat(a,b) / a.dot(b) where a is matrix, b is matrix
for n in range(1, 17):
if n < 8 or dtype != 'f2':
a = np.arange(4*n, dtype=dtype).reshape(4, n)
b = np.arange(n*6, dtype=dtype).reshape(n, 6)
- assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b))
- assert_equal(np.einsum(a, [0, 1], b, [1, 2]), np.dot(a, b))
+ assert_equal(np.einsum("ij,jk", a, b, optimize=do_opt),
+ np.dot(a, b))
+ assert_equal(np.einsum(a, [0, 1], b, [1, 2], optimize=do_opt),
+ np.dot(a, b))
for n in range(1, 17):
a = np.arange(4*n, dtype=dtype).reshape(4, n)
b = np.arange(n*6, dtype=dtype).reshape(n, 6)
c = np.arange(24, dtype=dtype).reshape(4, 6)
- np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe')
+ np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe',
+ optimize=do_opt)
assert_equal(c,
- np.dot(a.astype('f8'),
- b.astype('f8')).astype(dtype))
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
c[...] = 0
np.einsum(a, [0, 1], b, [1, 2], out=c,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c,
- np.dot(a.astype('f8'),
- b.astype('f8')).astype(dtype))
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
# matrix triple product (note this is not currently an efficient
# way to multiply 3 matrices)
@@ -351,21 +385,21 @@ class TestEinSum(TestCase):
b = np.arange(20, dtype=dtype).reshape(4, 5)
c = np.arange(30, dtype=dtype).reshape(5, 6)
if dtype != 'f2':
- assert_equal(np.einsum("ij,jk,kl", a, b, c),
- a.dot(b).dot(c))
- assert_equal(np.einsum(a, [0, 1], b, [1, 2], c, [2, 3]),
- a.dot(b).dot(c))
+ assert_equal(np.einsum("ij,jk,kl", a, b, c, optimize=do_opt),
+ a.dot(b).dot(c))
+ assert_equal(np.einsum(a, [0, 1], b, [1, 2], c, [2, 3],
+ optimize=do_opt), a.dot(b).dot(c))
d = np.arange(18, dtype=dtype).reshape(3, 6)
np.einsum("ij,jk,kl", a, b, c, out=d,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
tgt = a.astype('f8').dot(b.astype('f8'))
tgt = tgt.dot(c.astype('f8')).astype(dtype)
assert_equal(d, tgt)
d[...] = 0
np.einsum(a, [0, 1], b, [1, 2], c, [2, 3], out=d,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
tgt = a.astype('f8').dot(b.astype('f8'))
tgt = tgt.dot(c.astype('f8')).astype(dtype)
assert_equal(d, tgt)
@@ -375,31 +409,31 @@ class TestEinSum(TestCase):
a = np.arange(60, dtype=dtype).reshape(3, 4, 5)
b = np.arange(24, dtype=dtype).reshape(4, 3, 2)
assert_equal(np.einsum("ijk, jil -> kl", a, b),
- np.tensordot(a, b, axes=([1, 0], [0, 1])))
+ np.tensordot(a, b, axes=([1, 0], [0, 1])))
assert_equal(np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3]),
- np.tensordot(a, b, axes=([1, 0], [0, 1])))
+ np.tensordot(a, b, axes=([1, 0], [0, 1])))
c = np.arange(10, dtype=dtype).reshape(5, 2)
np.einsum("ijk,jil->kl", a, b, out=c,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'),
- axes=([1, 0], [0, 1])).astype(dtype))
+ axes=([1, 0], [0, 1])).astype(dtype))
c[...] = 0
np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3], out=c,
- dtype='f8', casting='unsafe')
+ dtype='f8', casting='unsafe', optimize=do_opt)
assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'),
- axes=([1, 0], [0, 1])).astype(dtype))
+ axes=([1, 0], [0, 1])).astype(dtype))
# logical_and(logical_and(a!=0, b!=0), c!=0)
a = np.array([1, 3, -2, 0, 12, 13, 0, 1], dtype=dtype)
b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12], dtype=dtype)
c = np.array([True, True, False, True, True, False, True, True])
assert_equal(np.einsum("i,i,i->i", a, b, c,
- dtype='?', casting='unsafe'),
- np.logical_and(np.logical_and(a != 0, b != 0), c != 0))
+ dtype='?', casting='unsafe', optimize=do_opt),
+ np.logical_and(np.logical_and(a != 0, b != 0), c != 0))
assert_equal(np.einsum(a, [0], b, [0], c, [0], [0],
- dtype='?', casting='unsafe'),
- np.logical_and(np.logical_and(a != 0, b != 0), c != 0))
+ dtype='?', casting='unsafe'),
+ np.logical_and(np.logical_and(a != 0, b != 0), c != 0))
a = np.arange(9, dtype=dtype)
assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a))
@@ -411,21 +445,24 @@ class TestEinSum(TestCase):
for n in range(1, 25):
a = np.arange(n, dtype=dtype)
if np.dtype(dtype).itemsize > 1:
- assert_equal(np.einsum("...,...", a, a), np.multiply(a, a))
- assert_equal(np.einsum("i,i", a, a), np.dot(a, a))
- assert_equal(np.einsum("i,->i", a, 2), 2*a)
- assert_equal(np.einsum(",i->i", 2, a), 2*a)
- assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a))
- assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a))
-
- assert_equal(np.einsum("...,...", a[1:], a[:-1]),
+ assert_equal(np.einsum("...,...", a, a, optimize=do_opt),
+ np.multiply(a, a))
+ assert_equal(np.einsum("i,i", a, a, optimize=do_opt), np.dot(a, a))
+ assert_equal(np.einsum("i,->i", a, 2, optimize=do_opt), 2*a)
+ assert_equal(np.einsum(",i->i", 2, a, optimize=do_opt), 2*a)
+ assert_equal(np.einsum("i,->", a, 2, optimize=do_opt), 2*np.sum(a))
+ assert_equal(np.einsum(",i->", 2, a, optimize=do_opt), 2*np.sum(a))
+
+ assert_equal(np.einsum("...,...", a[1:], a[:-1], optimize=do_opt),
np.multiply(a[1:], a[:-1]))
- assert_equal(np.einsum("i,i", a[1:], a[:-1]),
+ assert_equal(np.einsum("i,i", a[1:], a[:-1], optimize=do_opt),
np.dot(a[1:], a[:-1]))
- assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:])
- assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:])
- assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:]))
- assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:]))
+ assert_equal(np.einsum("i,->i", a[1:], 2, optimize=do_opt), 2*a[1:])
+ assert_equal(np.einsum(",i->i", 2, a[1:], optimize=do_opt), 2*a[1:])
+ assert_equal(np.einsum("i,->", a[1:], 2, optimize=do_opt),
+ 2*np.sum(a[1:]))
+ assert_equal(np.einsum(",i->", 2, a[1:], optimize=do_opt),
+ 2*np.sum(a[1:]))
# An object array, summed as the data type
a = np.arange(9, dtype=object)
@@ -458,9 +495,11 @@ class TestEinSum(TestCase):
def test_einsum_sums_int32(self):
self.check_einsum_sums('i4')
+ self.check_einsum_sums('i4', True)
def test_einsum_sums_uint32(self):
self.check_einsum_sums('u4')
+ self.check_einsum_sums('u4', True)
def test_einsum_sums_int64(self):
self.check_einsum_sums('i8')
@@ -476,12 +515,14 @@ class TestEinSum(TestCase):
def test_einsum_sums_float64(self):
self.check_einsum_sums('f8')
+ self.check_einsum_sums('f8', True)
def test_einsum_sums_longdouble(self):
self.check_einsum_sums(np.longdouble)
def test_einsum_sums_cfloat64(self):
self.check_einsum_sums('c8')
+ self.check_einsum_sums('c8', True)
def test_einsum_sums_cfloat128(self):
self.check_einsum_sums('c16')
@@ -495,12 +536,15 @@ class TestEinSum(TestCase):
a = np.ones((1, 2))
b = np.ones((2, 2, 1))
assert_equal(np.einsum('ij...,j...->i...', a, b), [[[2], [2]]])
+ assert_equal(np.einsum('ij...,j...->i...', a, b, optimize=True), [[[2], [2]]])
# The iterator had an issue with buffering this reduction
a = np.ones((5, 12, 4, 2, 3), np.int64)
b = np.ones((5, 12, 11), np.int64)
assert_equal(np.einsum('ijklm,ijn,ijn->', a, b, b),
- np.einsum('ijklm,ijn->', a, b))
+ np.einsum('ijklm,ijn->', a, b))
+ assert_equal(np.einsum('ijklm,ijn,ijn->', a, b, b, optimize=True),
+ np.einsum('ijklm,ijn->', a, b, optimize=True))
# Issue #2027, was a problem in the contiguous 3-argument
# inner loop implementation
@@ -508,8 +552,11 @@ class TestEinSum(TestCase):
b = np.arange(1, 5).reshape(2, 2)
c = np.arange(1, 9).reshape(4, 2)
assert_equal(np.einsum('x,yx,zx->xzy', a, b, c),
- [[[1, 3], [3, 9], [5, 15], [7, 21]],
- [[8, 16], [16, 32], [24, 48], [32, 64]]])
+ [[[1, 3], [3, 9], [5, 15], [7, 21]],
+ [[8, 16], [16, 32], [24, 48], [32, 64]]])
+ assert_equal(np.einsum('x,yx,zx->xzy', a, b, c, optimize=True),
+ [[[1, 3], [3, 9], [5, 15], [7, 21]],
+ [[8, 16], [16, 32], [24, 48], [32, 64]]])
def test_einsum_broadcast(self):
# Issue #2455 change in handling ellipsis
@@ -517,23 +564,33 @@ class TestEinSum(TestCase):
# only use the 'RIGHT' iteration in prepare_op_axes
# adds auto broadcast on left where it belongs
# broadcast on right has to be explicit
+ # We need to test the optimized parsing as well
- A = np.arange(2*3*4).reshape(2,3,4)
+ A = np.arange(2 * 3 * 4).reshape(2, 3, 4)
B = np.arange(3)
- ref = np.einsum('ijk,j->ijk',A, B)
- assert_equal(np.einsum('ij...,j...->ij...',A, B), ref)
- assert_equal(np.einsum('ij...,...j->ij...',A, B), ref)
- assert_equal(np.einsum('ij...,j->ij...',A, B), ref) # used to raise error
+ ref = np.einsum('ijk,j->ijk', A, B)
+ assert_equal(np.einsum('ij...,j...->ij...', A, B), ref)
+ assert_equal(np.einsum('ij...,...j->ij...', A, B), ref)
+ assert_equal(np.einsum('ij...,j->ij...', A, B), ref) # used to raise error
+
+ assert_equal(np.einsum('ij...,j...->ij...', A, B, optimize=True), ref)
+ assert_equal(np.einsum('ij...,...j->ij...', A, B, optimize=True), ref)
+ assert_equal(np.einsum('ij...,j->ij...', A, B, optimize=True), ref) # used to raise error
- A = np.arange(12).reshape((4,3))
- B = np.arange(6).reshape((3,2))
+ A = np.arange(12).reshape((4, 3))
+ B = np.arange(6).reshape((3, 2))
ref = np.einsum('ik,kj->ij', A, B)
assert_equal(np.einsum('ik...,k...->i...', A, B), ref)
assert_equal(np.einsum('ik...,...kj->i...j', A, B), ref)
assert_equal(np.einsum('...k,kj', A, B), ref) # used to raise error
assert_equal(np.einsum('ik,k...->i...', A, B), ref) # used to raise error
- dims = [2,3,4,5]
+ assert_equal(np.einsum('ik...,k...->i...', A, B, optimize=True), ref)
+ assert_equal(np.einsum('ik...,...kj->i...j', A, B, optimize=True), ref)
+ assert_equal(np.einsum('...k,kj', A, B, optimize=True), ref) # used to raise error
+ assert_equal(np.einsum('ik,k...->i...', A, B, optimize=True), ref) # used to raise error
+
+ dims = [2, 3, 4, 5]
a = np.arange(np.prod(dims)).reshape(dims)
v = np.arange(dims[2])
ref = np.einsum('ijkl,k->ijl', a, v)
@@ -542,11 +599,17 @@ class TestEinSum(TestCase):
assert_equal(np.einsum('...kl,k...', a, v), ref)
# no real diff from 1st
- J,K,M = 160,160,120
- A = np.arange(J*K*M).reshape(1,1,1,J,K,M)
- B = np.arange(J*K*M*3).reshape(J,K,M,3)
+ assert_equal(np.einsum('ijkl,k', a, v, optimize=True), ref)
+ assert_equal(np.einsum('...kl,k', a, v, optimize=True), ref) # used to raise error
+ assert_equal(np.einsum('...kl,k...', a, v, optimize=True), ref)
+
+ J, K, M = 160, 160, 120
+ A = np.arange(J * K * M).reshape(1, 1, 1, J, K, M)
+ B = np.arange(J * K * M * 3).reshape(J, K, M, 3)
ref = np.einsum('...lmn,...lmno->...o', A, B)
assert_equal(np.einsum('...lmn,lmno->...o', A, B), ref) # used to raise error
+ assert_equal(np.einsum('...lmn,lmno->...o', A, B,
+ optimize=True), ref) # used to raise error
def test_einsum_fixedstridebug(self):
# Issue #4485 obscure einsum bug
@@ -565,22 +628,22 @@ class TestEinSum(TestCase):
# used by einsum, is 8192, and 3*2731 = 8193, is larger than that
# and results in a mismatch between the buffering and the
# striding for operand A.
- A = np.arange(2*3).reshape(2,3).astype(np.float32)
- B = np.arange(2*3*2731).reshape(2,3,2731).astype(np.int16)
- es = np.einsum('cl,cpx->lpx', A, B)
- tp = np.tensordot(A, B, axes=(0, 0))
- assert_equal(es, tp)
+ A = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
+ B = np.arange(2 * 3 * 2731).reshape(2, 3, 2731).astype(np.int16)
+ es = np.einsum('cl, cpx->lpx', A, B)
+ tp = np.tensordot(A, B, axes=(0, 0))
+ assert_equal(es, tp)
# The following is the original test case from the bug report,
# made repeatable by changing random arrays to aranges.
- A = np.arange(3*3).reshape(3,3).astype(np.float64)
- B = np.arange(3*3*64*64).reshape(3,3,64,64).astype(np.float32)
- es = np.einsum('cl,cpxy->lpxy', A,B)
- tp = np.tensordot(A,B, axes=(0,0))
+ A = np.arange(3 * 3).reshape(3, 3).astype(np.float64)
+ B = np.arange(3 * 3 * 64 * 64).reshape(3, 3, 64, 64).astype(np.float32)
+ es = np.einsum('cl, cpxy->lpxy', A, B)
+ tp = np.tensordot(A, B, axes=(0, 0))
assert_equal(es, tp)
def test_einsum_fixed_collapsingbug(self):
# Issue #5147.
- # The bug only occurred when output argument of einssum was used.
+ # The bug only occurred when the output argument of einsum was used.
x = np.random.normal(0, 1, (5, 5, 5, 5))
y1 = np.zeros((5, 5))
np.einsum('aabb->ab', x, out=y1)
@@ -617,10 +680,219 @@ class TestEinSum(TestCase):
a = np.zeros((16, 1, 1), dtype=np.bool_)[:2]
a[...] = True
out = np.zeros((16, 1, 1), dtype=np.bool_)[:2]
- tgt = np.ones((2,1,1), dtype=np.bool_)
+ tgt = np.ones((2, 1, 1), dtype=np.bool_)
res = np.einsum('...ij,...jk->...ik', a, a, out=out)
assert_equal(res, tgt)
+ def optimize_compare(self, string):
+ # Tests all paths of the optimization function against
+ # conventional einsum
+ operands = [string]
+ terms = string.split('->')[0].split(',')
+ for term in terms:
+ dims = [global_size_dict[x] for x in term]
+ operands.append(np.random.rand(*dims))
+
+ noopt = np.einsum(*operands, optimize=False)
+ opt = np.einsum(*operands, optimize='greedy')
+ assert_almost_equal(opt, noopt)
+ opt = np.einsum(*operands, optimize='optimal')
+ assert_almost_equal(opt, noopt)
+
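The optimize_compare helper above relies on global_size_dict, presumably defined earlier in the test module (not shown in this hunk) to map index letters to dimension sizes. A minimal standalone sketch of the same comparison pattern, using an assumed size mapping, might look like:

    import numpy as np

    # Assumed size mapping; the real global_size_dict lives elsewhere in the test module.
    size_dict = {'a': 2, 'b': 3, 'c': 4, 'd': 5}

    def compare(subscripts):
        # Build one random operand per comma-separated input term.
        terms = subscripts.split('->')[0].split(',')
        operands = [np.random.rand(*[size_dict[c] for c in t]) for t in terms]
        # The unoptimized result is the reference for both path strategies.
        noopt = np.einsum(subscripts, *operands, optimize=False)
        for strategy in ('greedy', 'optimal'):
            assert np.allclose(noopt, np.einsum(subscripts, *operands, optimize=strategy))

    compare('ab,bcd,cd->abd')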
+ def test_hadamard_like_products(self):
+ # Hadamard outer products
+ self.optimize_compare('a,ab,abc->abc')
+ self.optimize_compare('a,b,ab->ab')
+
+ def test_index_transformations(self):
+ # Simple index transformation cases
+ self.optimize_compare('ea,fb,gc,hd,abcd->efgh')
+ self.optimize_compare('ea,fb,abcd,gc,hd->efgh')
+ self.optimize_compare('abcd,ea,fb,gc,hd->efgh')
+
+ def test_complex(self):
+ # Long test cases
+ self.optimize_compare('acdf,jbje,gihb,hfac,gfac,gifabc,hfac')
+ self.optimize_compare('acdf,jbje,gihb,hfac,gfac,gifabc,hfac')
+ self.optimize_compare('cd,bdhe,aidb,hgca,gc,hgibcd,hgac')
+ self.optimize_compare('abhe,hidj,jgba,hiab,gab')
+ self.optimize_compare('bde,cdh,agdb,hica,ibd,hgicd,hiac')
+ self.optimize_compare('chd,bde,agbc,hiad,hgc,hgi,hiad')
+ self.optimize_compare('chd,bde,agbc,hiad,bdi,cgh,agdb')
+ self.optimize_compare('bdhe,acad,hiab,agac,hibd')
+
+ def test_collapse(self):
+ # Inner products
+ self.optimize_compare('ab,ab,c->')
+ self.optimize_compare('ab,ab,c->c')
+ self.optimize_compare('ab,ab,cd,cd->')
+ self.optimize_compare('ab,ab,cd,cd->ac')
+ self.optimize_compare('ab,ab,cd,cd->cd')
+ self.optimize_compare('ab,ab,cd,cd,ef,ef->')
+
+ def test_expand(self):
+ # Outer products
+ self.optimize_compare('ab,cd,ef->abcdef')
+ self.optimize_compare('ab,cd,ef->acdf')
+ self.optimize_compare('ab,cd,de->abcde')
+ self.optimize_compare('ab,cd,de->be')
+ self.optimize_compare('ab,bcd,cd->abcd')
+ self.optimize_compare('ab,bcd,cd->abd')
+
+ def test_edge_cases(self):
+ # Difficult edge cases for optimization
+ self.optimize_compare('eb,cb,fb->cef')
+ self.optimize_compare('dd,fb,be,cdb->cef')
+ self.optimize_compare('bca,cdb,dbf,afc->')
+ self.optimize_compare('dcc,fce,ea,dbf->ab')
+ self.optimize_compare('fdf,cdd,ccd,afe->ae')
+ self.optimize_compare('abcd,ad')
+ self.optimize_compare('ed,fcd,ff,bcf->be')
+ self.optimize_compare('baa,dcf,af,cde->be')
+ self.optimize_compare('bd,db,eac->ace')
+ self.optimize_compare('fff,fae,bef,def->abd')
+ self.optimize_compare('efc,dbc,acf,fd->abe')
+ self.optimize_compare('ba,ac,da->bcd')
+
+ def test_inner_product(self):
+ # Inner products
+ self.optimize_compare('ab,ab')
+ self.optimize_compare('ab,ba')
+ self.optimize_compare('abc,abc')
+ self.optimize_compare('abc,bac')
+ self.optimize_compare('abc,cba')
+
+ def test_random_cases(self):
+ # Randomly built test cases
+ self.optimize_compare('aab,fa,df,ecc->bde')
+ self.optimize_compare('ecb,fef,bad,ed->ac')
+ self.optimize_compare('bcf,bbb,fbf,fc->')
+ self.optimize_compare('bb,ff,be->e')
+ self.optimize_compare('bcb,bb,fc,fff->')
+ self.optimize_compare('fbb,dfd,fc,fc->')
+ self.optimize_compare('afd,ba,cc,dc->bf')
+ self.optimize_compare('adb,bc,fa,cfc->d')
+ self.optimize_compare('bbd,bda,fc,db->acf')
+ self.optimize_compare('dba,ead,cad->bce')
+ self.optimize_compare('aef,fbc,dca->bde')
+
+
+class TestEinSumPath(TestCase):
+ def build_operands(self, string):
+
+ # Builds random operands from the subscript string
+ operands = [string]
+ terms = string.split('->')[0].split(',')
+ for term in terms:
+ dims = [global_size_dict[x] for x in term]
+ operands.append(np.random.rand(*dims))
+
+ return operands
+
+ def assert_path_equal(self, comp, benchmark):
+ # Checks that two lists of contraction tuples are equivalent
+ ret = (len(comp) == len(benchmark))
+ assert_(ret)
+ for pos in range(len(comp) - 1):
+ ret &= isinstance(comp[pos + 1], tuple)
+ ret &= (comp[pos + 1] == benchmark[pos + 1])
+ assert_(ret)
+
+ def test_memory_constraints(self):
+ # Ensure memory constraints are satisfied
+
+ outer_test = self.build_operands('a,b,c->abc')
+
+ path, path_str = np.einsum_path(*outer_test, optimize=('greedy', 0))
+ self.assert_path_equal(path, ['einsum_path', (0, 1, 2)])
+
+ path, path_str = np.einsum_path(*outer_test, optimize=('optimal', 0))
+ self.assert_path_equal(path, ['einsum_path', (0, 1, 2)])
+
+ long_test = self.build_operands('acdf,jbje,gihb,hfac')
+ path, path_str = np.einsum_path(*long_test, optimize=('greedy', 0))
+ self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)])
+
+ path, path_str = np.einsum_path(*long_test, optimize=('optimal', 0))
+ self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)])
+
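For context on the tuple form of optimize used just above: its second element caps the number of elements allowed in any intermediate array, so with a limit of 0 no intermediate fits and the path collapses to a single contraction of all operands, which is what these assertions check. A small illustrative sketch with arbitrary shapes:

    import numpy as np

    a = np.random.rand(4, 5)
    b = np.random.rand(5, 6)
    c = np.random.rand(6, 7)

    # Unrestricted memory: greedy is free to contract pairwise.
    path, report = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')

    # Memory limit of 0: no intermediate fits, so everything is contracted at once.
    capped, _ = np.einsum_path('ij,jk,kl->il', a, b, c, optimize=('greedy', 0))
    print(capped)  # expected: ['einsum_path', (0, 1, 2)]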
+ def test_long_paths(self):
+ # Long complex cases
+
+ # Long test 1
+ long_test1 = self.build_operands('acdf,jbje,gihb,hfac,gfac,gifabc,hfac')
+ path, path_str = np.einsum_path(*long_test1, optimize='greedy')
+ self.assert_path_equal(path, ['einsum_path',
+ (1, 4), (2, 4), (1, 4), (1, 3), (1, 2), (0, 1)])
+
+ path, path_str = np.einsum_path(*long_test1, optimize='optimal')
+ self.assert_path_equal(path, ['einsum_path',
+ (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)])
+
+ # Long test 2
+ long_test2 = self.build_operands('chd,bde,agbc,hiad,bdi,cgh,agdb')
+ path, path_str = np.einsum_path(*long_test2, optimize='greedy')
+ self.assert_path_equal(path, ['einsum_path',
+ (3, 4), (0, 3), (3, 4), (1, 3), (1, 2), (0, 1)])
+
+ path, path_str = np.einsum_path(*long_test2, optimize='optimal')
+ self.assert_path_equal(path, ['einsum_path',
+ (0, 5), (1, 4), (3, 4), (1, 3), (1, 2), (0, 1)])
+
+ def test_edge_paths(self):
+ # Difficult edge cases
+
+ # Edge test1
+ edge_test1 = self.build_operands('eb,cb,fb->cef')
+ path, path_str = np.einsum_path(*edge_test1, optimize='greedy')
+ self.assert_path_equal(path, ['einsum_path', (0, 2), (0, 1)])
+
+ path, path_str = np.einsum_path(*edge_test1, optimize='optimal')
+ self.assert_path_equal(path, ['einsum_path', (0, 2), (0, 1)])
+
+ # Edge test2
+ edge_test2 = self.build_operands('dd,fb,be,cdb->cef')
+ path, path_str = np.einsum_path(*edge_test2, optimize='greedy')
+ self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 1), (0, 1)])
+
+ path, path_str = np.einsum_path(*edge_test2, optimize='optimal')
+ self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 1), (0, 1)])
+
+ # Edge test3
+ edge_test3 = self.build_operands('bca,cdb,dbf,afc->')
+ path, path_str = np.einsum_path(*edge_test3, optimize='greedy')
+ self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)])
+
+ path, path_str = np.einsum_path(*edge_test3, optimize='optimal')
+ self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)])
+
+ # Edge test4
+ edge_test4 = self.build_operands('dcc,fce,ea,dbf->ab')
+ path, path_str = np.einsum_path(*edge_test4, optimize='greedy')
+ self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)])
+
+ path, path_str = np.einsum_path(*edge_test4, optimize='optimal')
+ self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)])
+
+ def test_path_type_input(self):
+ # Test explicit path handling
+ path_test = self.build_operands('dcc,fce,ea,dbf->ab')
+
+ path, path_str = np.einsum_path(*path_test, optimize=False)
+ self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)])
+
+ path, path_str = np.einsum_path(*path_test, optimize=True)
+ self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)])
+
+ exp_path = ['einsum_path', (0, 2), (0, 2), (0, 1)]
+ path, path_str = np.einsum_path(*path_test, optimize=exp_path)
+ self.assert_path_equal(path, exp_path)
+
+ # Double check einsum works on the input path
+ noopt = np.einsum(*path_test, optimize=False)
+ opt = np.einsum(*path_test, optimize=exp_path)
+ assert_almost_equal(noopt, opt)
+
if __name__ == "__main__":
run_module_suite()
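As a usage note on the explicit-path form exercised by test_path_type_input: a path returned by np.einsum_path can be fed back through optimize= so the path search runs once and is reused across repeated evaluations over identically shaped operands. A minimal sketch with made-up shapes:

    import numpy as np

    a = np.random.rand(10, 10)
    b = np.random.rand(10, 10)
    c = np.random.rand(10, 10)

    # Compute the contraction order once...
    path, report = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='optimal')

    # ...then reuse it for repeated evaluations with the same operand shapes.
    for _ in range(3):
        result = np.einsum('ij,jk,kl->il', a, b, c, optimize=path)

    # Sanity check against the unoptimized result.
    assert np.allclose(result, np.einsum('ij,jk,kl->il', a, b, c, optimize=False))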