diff options
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r-- | numpy/core/einsumfunc.py | 123 |
1 files changed, 75 insertions, 48 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index 1281b3c98..3412c3fd5 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -9,6 +9,7 @@ import itertools from numpy.compat import basestring from numpy.core.multiarray import c_einsum from numpy.core.numeric import asanyarray, tensordot +from numpy.core.overrides import array_function_dispatch __all__ = ['einsum', 'einsum_path'] @@ -40,10 +41,10 @@ def _flop_count(idx_contraction, inner, num_terms, size_dictionary): -------- >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5}) - 90 + 30 >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5}) - 270 + 60 """ @@ -168,9 +169,9 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit): Examples -------- >>> isets = [set('abd'), set('ac'), set('bdc')] - >>> oset = set('') + >>> oset = set() >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4} - >>> _path__optimal_path(isets, oset, idx_sizes, 5000) + >>> _optimal_path(isets, oset, idx_sizes, 5000) [(0, 2), (0, 1)] """ @@ -286,7 +287,7 @@ def _update_other_results(results, best): Returns ------- mod_results : list - The list of modifed results, updated with outcome of ``best`` contraction. + The list of modified results, updated with outcome of ``best`` contraction. """ best_con = best[1] @@ -339,9 +340,9 @@ def _greedy_path(input_sets, output_set, idx_dict, memory_limit): Examples -------- >>> isets = [set('abd'), set('ac'), set('bdc')] - >>> oset = set('') + >>> oset = set() >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4} - >>> _path__greedy_path(isets, oset, idx_sizes, 5000) + >>> _greedy_path(isets, oset, idx_sizes, 5000) [(0, 2), (0, 1)] """ @@ -538,13 +539,14 @@ def _parse_einsum_input(operands): -------- The operand list is simplified to reduce printing: + >>> np.random.seed(123) >>> a = np.random.rand(4, 4) >>> b = np.random.rand(4, 4, 4) - >>> __parse_einsum_input(('...a,...a->...', a, b)) - ('za,xza', 'xz', [a, b]) + >>> _parse_einsum_input(('...a,...a->...', a, b)) + ('za,xza', 'xz', [a, b]) # may vary - >>> __parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) - ('za,xza', 'xz', [a, b]) + >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) + ('za,xza', 'xz', [a, b]) # may vary """ if len(operands) == 0: @@ -689,6 +691,17 @@ def _parse_einsum_input(operands): return (input_subscripts, output_subscript, operands) +def _einsum_path_dispatcher(*operands, **kwargs): + # NOTE: technically, we should only dispatch on array-like arguments, not + # subscripts (given as strings). But separating operands into + # arrays/subscripts is a little tricky/slow (given einsum's two supported + # signatures), so as a practical shortcut we dispatch on everything. + # Strings will be ignored for dispatching since they don't define + # __array_function__. + return operands + + +@array_function_dispatch(_einsum_path_dispatcher, module='numpy') def einsum_path(*operands, **kwargs): """ einsum_path(subscripts, *operands, optimize='greedy') @@ -751,6 +764,7 @@ def einsum_path(*operands, **kwargs): of the contraction and the remaining contraction ``(0, 1)`` is then completed. + >>> np.random.seed(123) >>> a = np.random.rand(2, 2) >>> b = np.random.rand(2, 5) >>> c = np.random.rand(5, 2) @@ -758,7 +772,7 @@ def einsum_path(*operands, **kwargs): >>> print(path_info[0]) ['einsum_path', (1, 2), (0, 1)] >>> print(path_info[1]) - Complete contraction: ij,jk,kl->il + Complete contraction: ij,jk,kl->il # may vary Naive scaling: 4 Optimized scaling: 3 Naive FLOP count: 1.600e+02 @@ -777,12 +791,12 @@ def einsum_path(*operands, **kwargs): >>> I = np.random.rand(10, 10, 10, 10) >>> C = np.random.rand(10, 10) >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, - optimize='greedy') + ... optimize='greedy') >>> print(path_info[0]) ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)] - >>> print(path_info[1]) - Complete contraction: ea,fb,abcd,gc,hd->efgh + >>> print(path_info[1]) + Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary Naive scaling: 8 Optimized scaling: 5 Naive FLOP count: 8.000e+08 @@ -837,7 +851,6 @@ def einsum_path(*operands, **kwargs): # Python side parsing input_subscripts, output_subscript, operands = _parse_einsum_input(operands) - subscripts = input_subscripts + '->' + output_subscript # Build a few useful list and sets input_list = input_subscripts.split(',') @@ -876,9 +889,8 @@ def einsum_path(*operands, **kwargs): broadcast_indices = [set(x) for x in broadcast_indices] # Compute size of each input array plus the output array - size_list = [] - for term in input_list + [output_subscript]: - size_list.append(_compute_size_by_dict(term, dimension_dict)) + size_list = [_compute_size_by_dict(term, dimension_dict) + for term in input_list + [output_subscript]] max_size = max(size_list) if memory_limit is None: @@ -980,7 +992,16 @@ def einsum_path(*operands, **kwargs): return (path, path_print) +def _einsum_dispatcher(*operands, **kwargs): + # Arguably we dispatch on more arguments that we really should; see note in + # _einsum_path_dispatcher for why. + for op in operands: + yield op + yield kwargs.get('out') + + # Rewrite einsum to handle different cases +@array_function_dispatch(_einsum_dispatcher, module='numpy') def einsum(*operands, **kwargs): """ einsum(subscripts, *operands, out=None, dtype=None, order='K', @@ -1255,32 +1276,32 @@ def einsum(*operands, **kwargs): >>> a = np.arange(60.).reshape(3,4,5) >>> b = np.arange(24.).reshape(4,3,2) >>> np.einsum('ijk,jil->kl', a, b) - array([[ 4400., 4730.], - [ 4532., 4874.], - [ 4664., 5018.], - [ 4796., 5162.], - [ 4928., 5306.]]) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3]) - array([[ 4400., 4730.], - [ 4532., 4874.], - [ 4664., 5018.], - [ 4796., 5162.], - [ 4928., 5306.]]) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) >>> np.tensordot(a,b, axes=([1,0],[0,1])) - array([[ 4400., 4730.], - [ 4532., 4874.], - [ 4664., 5018.], - [ 4796., 5162.], - [ 4928., 5306.]]) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) Writeable returned arrays (since version 1.10.0): >>> a = np.zeros((3, 3)) >>> np.einsum('ii->i', a)[:] = 1 >>> a - array([[ 1., 0., 0.], - [ 0., 1., 0.], - [ 0., 0., 1.]]) + array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) Example of ellipsis use: @@ -1303,19 +1324,27 @@ def einsum(*operands, **kwargs): particularly significant with larger arrays: >>> a = np.ones(64).reshape(2,4,8) - # Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.) + + Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.) + >>> for iteration in range(500): - ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a) - # Sub-optimal `einsum` (due to repeated path calculation time): ~330ms + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a) + + Sub-optimal `einsum` (due to repeated path calculation time): ~330ms + >>> for iteration in range(500): - ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal') - # Greedy `einsum` (faster optimal path approximation): ~160ms + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal') + + Greedy `einsum` (faster optimal path approximation): ~160ms + >>> for iteration in range(500): - ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy') - # Optimal `einsum` (best usage pattern in some use cases): ~110ms + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy') + + Optimal `einsum` (best usage pattern in some use cases): ~110ms + >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0] >>> for iteration in range(500): - ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path) + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path) """ @@ -1354,9 +1383,7 @@ def einsum(*operands, **kwargs): # Start contraction loop for num, contraction in enumerate(contraction_list): inds, idx_rm, einsum_str, remaining, blas = contraction - tmp_operands = [] - for x in inds: - tmp_operands.append(operands.pop(x)) + tmp_operands = [operands.pop(x) for x in inds] # Do we need to deal with the output? handle_out = specified_out and ((num + 1) == len(contraction_list)) |