diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-10-12 20:57:00 -0700 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-10-12 21:21:25 -0700 |
commit | dfab760b4a328d9fa29cef123e0fe8e2926b0c8c (patch) | |
tree | d2ace5c607aec928a72382c6fef8a6d6e3f504fc /numpy/core/einsumfunc.py | |
parent | 18c12106b1c59052ab9c94a5a67513120580a10b (diff) | |
download | numpy-dfab760b4a328d9fa29cef123e0fe8e2926b0c8c.tar.gz |
ENH: __array_function__ for np.einsum and np.block
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r-- | numpy/core/einsumfunc.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index 1281b3c98..3ffb152e1 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -9,6 +9,7 @@ import itertools from numpy.compat import basestring from numpy.core.multiarray import c_einsum from numpy.core.numeric import asanyarray, tensordot +from numpy.core.overrides import array_function_dispatch __all__ = ['einsum', 'einsum_path'] @@ -689,6 +690,17 @@ def _parse_einsum_input(operands): return (input_subscripts, output_subscript, operands) +def _einsum_path_dispatcher(*operands, **kwargs): + # NOTE: technically, we should only dispatch on array-like arguments, not + # subscripts (given as strings). But separating operands into + # arrays/subscripts is a little tricky/slow (given einsum's two supported + # signatures), so as a practical shortcut we dispatch on everything. + # Strings will be ignored for dispatching since they don't define + # __array_function__. + return operands + + +@array_function_dispatch(_einsum_path_dispatcher) def einsum_path(*operands, **kwargs): """ einsum_path(subscripts, *operands, optimize='greedy') @@ -980,7 +992,16 @@ def einsum_path(*operands, **kwargs): return (path, path_print) +def _einsum_dispatcher(*operands, **kwargs): + # Arguably we dispatch on more arguments that we really should; see note in + # _einsum_path_dispatcher for why. + for op in operands: + yield op + yield kwargs.get('out') + + # Rewrite einsum to handle different cases +@array_function_dispatch(_einsum_dispatcher) def einsum(*operands, **kwargs): """ einsum(subscripts, *operands, out=None, dtype=None, order='K', |