summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-06-08 22:28:41 -0500
committerGitHub <noreply@github.com>2020-06-08 22:28:41 -0500
commitd647ef2d98852e322d487d6045d5860223dcda79 (patch)
tree3e6c4d39c34a047244631e49130c26b5be991a17
parent8f8603dd18c9c70191fd75d98793dc94c8ac30fa (diff)
parent5f6be40bf009aa4438ec89310c4b55ce2eb9ee07 (diff)
downloadnumpy-d647ef2d98852e322d487d6045d5860223dcda79.tar.gz
Merge pull request #16446 from dgasmith/einsum_order
BUG: fixes einsum output order with optimization (#14615)
-rw-r--r--numpy/core/einsumfunc.py11
-rw-r--r--numpy/core/tests/test_einsum.py39
2 files changed, 48 insertions, 2 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index c46ae173d..f65f4015c 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -1358,11 +1358,18 @@ def einsum(*operands, out=None, optimize=False, **kwargs):
raise TypeError("Did not understand the following kwargs: %s"
% unknown_kwargs)
-
# Build the contraction list and operand
operands, contraction_list = einsum_path(*operands, optimize=optimize,
einsum_call=True)
+ # Handle order kwarg for output array, c_einsum allows mixed case
+ output_order = kwargs.pop('order', 'K')
+ if output_order.upper() == 'A':
+ if all(arr.flags.f_contiguous for arr in operands):
+ output_order = 'F'
+ else:
+ output_order = 'C'
+
# Start contraction loop
for num, contraction in enumerate(contraction_list):
inds, idx_rm, einsum_str, remaining, blas = contraction
@@ -1412,4 +1419,4 @@ def einsum(*operands, out=None, optimize=False, **kwargs):
if specified_out:
return out
else:
- return operands[0]
+ return asanyarray(operands[0], order=output_order)
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index da84735a0..c697d0c2d 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -94,6 +94,10 @@ class TestEinsum:
b = np.ones((3, 4, 5))
np.einsum('aabcb,abc', a, b)
+ # Check order kwarg, asanyarray allows 1d to pass through
+ assert_raises(ValueError, np.einsum, "i->i", np.arange(6).reshape(-1, 1),
+ optimize=do_opt, order='d')
+
def test_einsum_views(self):
# pass-through
for do_opt in [True, False]:
@@ -876,6 +880,41 @@ class TestEinsum:
g = np.arange(64).reshape(2, 4, 8)
self.optimize_compare('obk,ijk->ioj', operands=[g, g])
+ def test_output_order(self):
+ # Ensure output order is respected for optimize cases, the below
+ # conraction should yield a reshaped tensor view
+ # gh-16415
+
+ a = np.ones((2, 3, 5), order='F')
+ b = np.ones((4, 3), order='F')
+
+ for opt in [True, False]:
+ tmp = np.einsum('...ft,mf->...mt', a, b, order='a', optimize=opt)
+ assert_(tmp.flags.f_contiguous)
+
+ tmp = np.einsum('...ft,mf->...mt', a, b, order='f', optimize=opt)
+ assert_(tmp.flags.f_contiguous)
+
+ tmp = np.einsum('...ft,mf->...mt', a, b, order='c', optimize=opt)
+ assert_(tmp.flags.c_contiguous)
+
+ tmp = np.einsum('...ft,mf->...mt', a, b, order='k', optimize=opt)
+ assert_(tmp.flags.c_contiguous is False)
+ assert_(tmp.flags.f_contiguous is False)
+
+ tmp = np.einsum('...ft,mf->...mt', a, b, optimize=opt)
+ assert_(tmp.flags.c_contiguous is False)
+ assert_(tmp.flags.f_contiguous is False)
+
+ c = np.ones((4, 3), order='C')
+ for opt in [True, False]:
+ tmp = np.einsum('...ft,mf->...mt', a, c, order='a', optimize=opt)
+ assert_(tmp.flags.c_contiguous)
+
+ d = np.ones((2, 3, 5), order='C')
+ for opt in [True, False]:
+ tmp = np.einsum('...ft,mf->...mt', d, c, order='a', optimize=opt)
+ assert_(tmp.flags.c_contiguous)
class TestEinsumPath:
def build_operands(self, string, size_dict=global_size_dict):