diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-06-08 22:28:41 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-08 22:28:41 -0500 |
commit | d647ef2d98852e322d487d6045d5860223dcda79 (patch) | |
tree | 3e6c4d39c34a047244631e49130c26b5be991a17 | |
parent | 8f8603dd18c9c70191fd75d98793dc94c8ac30fa (diff) | |
parent | 5f6be40bf009aa4438ec89310c4b55ce2eb9ee07 (diff) | |
download | numpy-d647ef2d98852e322d487d6045d5860223dcda79.tar.gz |
Merge pull request #16446 from dgasmith/einsum_order
BUG: fixes einsum output order with optimization (#14615)
-rw-r--r-- | numpy/core/einsumfunc.py | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 39 |
2 files changed, 48 insertions, 2 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index c46ae173d..f65f4015c 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -1358,11 +1358,18 @@ def einsum(*operands, out=None, optimize=False, **kwargs): raise TypeError("Did not understand the following kwargs: %s" % unknown_kwargs) - # Build the contraction list and operand operands, contraction_list = einsum_path(*operands, optimize=optimize, einsum_call=True) + # Handle order kwarg for output array, c_einsum allows mixed case + output_order = kwargs.pop('order', 'K') + if output_order.upper() == 'A': + if all(arr.flags.f_contiguous for arr in operands): + output_order = 'F' + else: + output_order = 'C' + # Start contraction loop for num, contraction in enumerate(contraction_list): inds, idx_rm, einsum_str, remaining, blas = contraction @@ -1412,4 +1419,4 @@ def einsum(*operands, out=None, optimize=False, **kwargs): if specified_out: return out else: - return operands[0] + return asanyarray(operands[0], order=output_order) diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index da84735a0..c697d0c2d 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -94,6 +94,10 @@ class TestEinsum: b = np.ones((3, 4, 5)) np.einsum('aabcb,abc', a, b) + # Check order kwarg, asanyarray allows 1d to pass through + assert_raises(ValueError, np.einsum, "i->i", np.arange(6).reshape(-1, 1), + optimize=do_opt, order='d') + def test_einsum_views(self): # pass-through for do_opt in [True, False]: @@ -876,6 +880,41 @@ class TestEinsum: g = np.arange(64).reshape(2, 4, 8) self.optimize_compare('obk,ijk->ioj', operands=[g, g]) + def test_output_order(self): + # Ensure output order is respected for optimize cases, the below + # conraction should yield a reshaped tensor view + # gh-16415 + + a = np.ones((2, 3, 5), order='F') + b = np.ones((4, 3), order='F') + + for opt in [True, False]: + tmp = np.einsum('...ft,mf->...mt', a, b, order='a', optimize=opt) + assert_(tmp.flags.f_contiguous) + + tmp = np.einsum('...ft,mf->...mt', a, b, order='f', optimize=opt) + assert_(tmp.flags.f_contiguous) + + tmp = np.einsum('...ft,mf->...mt', a, b, order='c', optimize=opt) + assert_(tmp.flags.c_contiguous) + + tmp = np.einsum('...ft,mf->...mt', a, b, order='k', optimize=opt) + assert_(tmp.flags.c_contiguous is False) + assert_(tmp.flags.f_contiguous is False) + + tmp = np.einsum('...ft,mf->...mt', a, b, optimize=opt) + assert_(tmp.flags.c_contiguous is False) + assert_(tmp.flags.f_contiguous is False) + + c = np.ones((4, 3), order='C') + for opt in [True, False]: + tmp = np.einsum('...ft,mf->...mt', a, c, order='a', optimize=opt) + assert_(tmp.flags.c_contiguous) + + d = np.ones((2, 3, 5), order='C') + for opt in [True, False]: + tmp = np.einsum('...ft,mf->...mt', d, c, order='a', optimize=opt) + assert_(tmp.flags.c_contiguous) class TestEinsumPath: def build_operands(self, string, size_dict=global_size_dict): |