summaryrefslogtreecommitdiff
path: root/numpy/core/einsumfunc.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-10-12 20:57:00 -0700
committerStephan Hoyer <shoyer@google.com>2018-10-12 21:21:25 -0700
commitdfab760b4a328d9fa29cef123e0fe8e2926b0c8c (patch)
treed2ace5c607aec928a72382c6fef8a6d6e3f504fc /numpy/core/einsumfunc.py
parent18c12106b1c59052ab9c94a5a67513120580a10b (diff)
downloadnumpy-dfab760b4a328d9fa29cef123e0fe8e2926b0c8c.tar.gz
ENH: __array_function__ for np.einsum and np.block
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r--numpy/core/einsumfunc.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index 1281b3c98..3ffb152e1 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -9,6 +9,7 @@ import itertools
from numpy.compat import basestring
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
+from numpy.core.overrides import array_function_dispatch
__all__ = ['einsum', 'einsum_path']
@@ -689,6 +690,17 @@ def _parse_einsum_input(operands):
return (input_subscripts, output_subscript, operands)
+def _einsum_path_dispatcher(*operands, **kwargs):
+ # NOTE: technically, we should only dispatch on array-like arguments, not
+ # subscripts (given as strings). But separating operands into
+ # arrays/subscripts is a little tricky/slow (given einsum's two supported
+ # signatures), so as a practical shortcut we dispatch on everything.
+ # Strings will be ignored for dispatching since they don't define
+ # __array_function__.
+ return operands
+
+
+@array_function_dispatch(_einsum_path_dispatcher)
def einsum_path(*operands, **kwargs):
"""
einsum_path(subscripts, *operands, optimize='greedy')
@@ -980,7 +992,16 @@ def einsum_path(*operands, **kwargs):
return (path, path_print)
+def _einsum_dispatcher(*operands, **kwargs):
+ # Arguably we dispatch on more arguments that we really should; see note in
+ # _einsum_path_dispatcher for why.
+ for op in operands:
+ yield op
+ yield kwargs.get('out')
+
+
# Rewrite einsum to handle different cases
+@array_function_dispatch(_einsum_dispatcher)
def einsum(*operands, **kwargs):
"""
einsum(subscripts, *operands, out=None, dtype=None, order='K',