summaryrefslogtreecommitdiff
path: root/numpy/core/einsumfunc.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r--numpy/core/einsumfunc.py123
1 files changed, 75 insertions, 48 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index 1281b3c98..3412c3fd5 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -9,6 +9,7 @@ import itertools
from numpy.compat import basestring
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
+from numpy.core.overrides import array_function_dispatch
__all__ = ['einsum', 'einsum_path']
@@ -40,10 +41,10 @@ def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
--------
>>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
- 90
+ 30
>>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
- 270
+ 60
"""
@@ -168,9 +169,9 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
Examples
--------
>>> isets = [set('abd'), set('ac'), set('bdc')]
- >>> oset = set('')
+ >>> oset = set()
>>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
- >>> _path__optimal_path(isets, oset, idx_sizes, 5000)
+ >>> _optimal_path(isets, oset, idx_sizes, 5000)
[(0, 2), (0, 1)]
"""
@@ -286,7 +287,7 @@ def _update_other_results(results, best):
Returns
-------
mod_results : list
- The list of modifed results, updated with outcome of ``best`` contraction.
+ The list of modified results, updated with outcome of ``best`` contraction.
"""
best_con = best[1]
@@ -339,9 +340,9 @@ def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
Examples
--------
>>> isets = [set('abd'), set('ac'), set('bdc')]
- >>> oset = set('')
+ >>> oset = set()
>>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
- >>> _path__greedy_path(isets, oset, idx_sizes, 5000)
+ >>> _greedy_path(isets, oset, idx_sizes, 5000)
[(0, 2), (0, 1)]
"""
@@ -538,13 +539,14 @@ def _parse_einsum_input(operands):
--------
The operand list is simplified to reduce printing:
+ >>> np.random.seed(123)
>>> a = np.random.rand(4, 4)
>>> b = np.random.rand(4, 4, 4)
- >>> __parse_einsum_input(('...a,...a->...', a, b))
- ('za,xza', 'xz', [a, b])
+ >>> _parse_einsum_input(('...a,...a->...', a, b))
+ ('za,xza', 'xz', [a, b]) # may vary
- >>> __parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
- ('za,xza', 'xz', [a, b])
+ >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
+ ('za,xza', 'xz', [a, b]) # may vary
"""
if len(operands) == 0:
@@ -689,6 +691,17 @@ def _parse_einsum_input(operands):
return (input_subscripts, output_subscript, operands)
+def _einsum_path_dispatcher(*operands, **kwargs):
+ # NOTE: technically, we should only dispatch on array-like arguments, not
+ # subscripts (given as strings). But separating operands into
+ # arrays/subscripts is a little tricky/slow (given einsum's two supported
+ # signatures), so as a practical shortcut we dispatch on everything.
+ # Strings will be ignored for dispatching since they don't define
+ # __array_function__.
+ return operands
+
+
+@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, **kwargs):
"""
einsum_path(subscripts, *operands, optimize='greedy')
@@ -751,6 +764,7 @@ def einsum_path(*operands, **kwargs):
of the contraction and the remaining contraction ``(0, 1)`` is then
completed.
+ >>> np.random.seed(123)
>>> a = np.random.rand(2, 2)
>>> b = np.random.rand(2, 5)
>>> c = np.random.rand(5, 2)
@@ -758,7 +772,7 @@ def einsum_path(*operands, **kwargs):
>>> print(path_info[0])
['einsum_path', (1, 2), (0, 1)]
>>> print(path_info[1])
- Complete contraction: ij,jk,kl->il
+ Complete contraction: ij,jk,kl->il # may vary
Naive scaling: 4
Optimized scaling: 3
Naive FLOP count: 1.600e+02
@@ -777,12 +791,12 @@ def einsum_path(*operands, **kwargs):
>>> I = np.random.rand(10, 10, 10, 10)
>>> C = np.random.rand(10, 10)
>>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
- optimize='greedy')
+ ... optimize='greedy')
>>> print(path_info[0])
['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
- >>> print(path_info[1])
- Complete contraction: ea,fb,abcd,gc,hd->efgh
+ >>> print(path_info[1])
+ Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
Naive scaling: 8
Optimized scaling: 5
Naive FLOP count: 8.000e+08
@@ -837,7 +851,6 @@ def einsum_path(*operands, **kwargs):
# Python side parsing
input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
- subscripts = input_subscripts + '->' + output_subscript
# Build a few useful list and sets
input_list = input_subscripts.split(',')
@@ -876,9 +889,8 @@ def einsum_path(*operands, **kwargs):
broadcast_indices = [set(x) for x in broadcast_indices]
# Compute size of each input array plus the output array
- size_list = []
- for term in input_list + [output_subscript]:
- size_list.append(_compute_size_by_dict(term, dimension_dict))
+ size_list = [_compute_size_by_dict(term, dimension_dict)
+ for term in input_list + [output_subscript]]
max_size = max(size_list)
if memory_limit is None:
@@ -980,7 +992,16 @@ def einsum_path(*operands, **kwargs):
return (path, path_print)
+def _einsum_dispatcher(*operands, **kwargs):
+ # Arguably we dispatch on more arguments that we really should; see note in
+ # _einsum_path_dispatcher for why.
+ for op in operands:
+ yield op
+ yield kwargs.get('out')
+
+
# Rewrite einsum to handle different cases
+@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, **kwargs):
"""
einsum(subscripts, *operands, out=None, dtype=None, order='K',
@@ -1255,32 +1276,32 @@ def einsum(*operands, **kwargs):
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> np.einsum('ijk,jil->kl', a, b)
- array([[ 4400., 4730.],
- [ 4532., 4874.],
- [ 4664., 5018.],
- [ 4796., 5162.],
- [ 4928., 5306.]])
+ array([[4400., 4730.],
+ [4532., 4874.],
+ [4664., 5018.],
+ [4796., 5162.],
+ [4928., 5306.]])
>>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
- array([[ 4400., 4730.],
- [ 4532., 4874.],
- [ 4664., 5018.],
- [ 4796., 5162.],
- [ 4928., 5306.]])
+ array([[4400., 4730.],
+ [4532., 4874.],
+ [4664., 5018.],
+ [4796., 5162.],
+ [4928., 5306.]])
>>> np.tensordot(a,b, axes=([1,0],[0,1]))
- array([[ 4400., 4730.],
- [ 4532., 4874.],
- [ 4664., 5018.],
- [ 4796., 5162.],
- [ 4928., 5306.]])
+ array([[4400., 4730.],
+ [4532., 4874.],
+ [4664., 5018.],
+ [4796., 5162.],
+ [4928., 5306.]])
Writeable returned arrays (since version 1.10.0):
>>> a = np.zeros((3, 3))
>>> np.einsum('ii->i', a)[:] = 1
>>> a
- array([[ 1., 0., 0.],
- [ 0., 1., 0.],
- [ 0., 0., 1.]])
+ array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]])
Example of ellipsis use:
@@ -1303,19 +1324,27 @@ def einsum(*operands, **kwargs):
particularly significant with larger arrays:
>>> a = np.ones(64).reshape(2,4,8)
- # Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
+
+ Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
+
>>> for iteration in range(500):
- ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
- # Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
+
+ Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
+
>>> for iteration in range(500):
- ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
- # Greedy `einsum` (faster optimal path approximation): ~160ms
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
+
+ Greedy `einsum` (faster optimal path approximation): ~160ms
+
>>> for iteration in range(500):
- ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
- # Optimal `einsum` (best usage pattern in some use cases): ~110ms
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
+
+ Optimal `einsum` (best usage pattern in some use cases): ~110ms
+
>>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
>>> for iteration in range(500):
- ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
"""
@@ -1354,9 +1383,7 @@ def einsum(*operands, **kwargs):
# Start contraction loop
for num, contraction in enumerate(contraction_list):
inds, idx_rm, einsum_str, remaining, blas = contraction
- tmp_operands = []
- for x in inds:
- tmp_operands.append(operands.pop(x))
+ tmp_operands = [operands.pop(x) for x in inds]
# Do we need to deal with the output?
handle_out = specified_out and ((num + 1) == len(contraction_list))