summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2021-11-09 18:30:45 -0600
committerSebastian Berg <sebastian@sipsolutions.net>2021-11-12 11:34:05 -0600
commitd9dae76fce6fe3e87ee01670d3905f6fbdd04569 (patch)
tree887bd210aa1f3de338a45c053c1c6e1d5e25820b /numpy
parenta4daaf5a3924248f71caddbcbd3cf28afef802a6 (diff)
downloadnumpy-d9dae76fce6fe3e87ee01670d3905f6fbdd04569.tar.gz
TST: Add exhaustive test for einsum specialized loops
This hopefully tests things well enough, at least some/most of the paths get triggered and led to errors without the previous float16 typing fixes. I manually confirmed that all paths that were *modified* in the previous commit actually get hit with float16 specialized loops. NOTE: This test may be a bit fragile with floating point roundoff errors, and can in parts be relaxed if this happens.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/tests/test_einsum.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 78c5e527b..172311624 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -1,5 +1,7 @@
import itertools
+import pytest
+
import numpy as np
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_almost_equal,
@@ -744,6 +746,52 @@ class TestEinsum:
np.einsum('ij,jk->ik', x, x, out=out)
assert_array_equal(out.base, correct_base)
+ @pytest.mark.parametrize("dtype",
+ np.typecodes["AllFloat"] + np.typecodes["AllInteger"])
+ def test_different_paths(self, dtype):
+ # Test originally added to cover broken float16 path: gh-20305
+ # Likely most are covered elsewhere, at least partially.
+ dtype = np.dtype(dtype)
+ # Simple test, designed to excersize most specialized code paths,
+ # note the +0.5 for floats. This makes sure we use a float value
+ # where the results must be exact.
+ arr = (np.arange(7) + 0.5).astype(dtype)
+ scalar = np.array(2, dtype=dtype)
+
+ # contig -> scalar:
+ res = np.einsum('i->', arr)
+ assert res == arr.sum()
+ # contig, contig -> contig:
+ res = np.einsum('i,i->i', arr, arr)
+ assert_array_equal(res, arr * arr)
+ # noncontig, noncontig -> contig:
+ res = np.einsum('i,i->i', arr.repeat(2)[::2], arr.repeat(2)[::2])
+ assert_array_equal(res, arr * arr)
+ # contig + contig -> scalar
+ assert np.einsum('i,i->', arr, arr) == (arr * arr).sum()
+ # contig + scalar -> contig (with out)
+ out = np.ones(7, dtype=dtype)
+ res = np.einsum('i,->i', arr, dtype.type(2), out=out)
+ assert_array_equal(res, arr * dtype.type(2))
+ # scalar + contig -> contig (with out)
+ res = np.einsum(',i->i', scalar, arr)
+ assert_array_equal(res, arr * dtype.type(2))
+ # scalar + contig -> scalar
+ res = np.einsum(',i->', scalar, arr)
+ # Use einsum to compare to not have difference due to sum round-offs:
+ assert res == np.einsum('i->', scalar * arr)
+ # contig + scalar -> scalar
+ res = np.einsum('i,->', arr, scalar)
+ # Use einsum to compare to not have difference due to sum round-offs:
+ assert res == np.einsum('i->', scalar * arr)
+ # contig + contig + contig -> scalar
+ arr = np.array([0.5, 0.5, 0.25, 4.5, 3.], dtype=dtype)
+ res = np.einsum('i,i,i->', arr, arr, arr)
+ assert_array_equal(res, (arr * arr * arr).sum())
+ # four arrays:
+ res = np.einsum('i,i,i,i->', arr, arr, arr, arr)
+ assert_array_equal(res, (arr * arr * arr * arr).sum())
+
def test_small_boolean_arrays(self):
# See gh-5946.
# Use array of True embedded in False.