diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2022-09-22 15:37:17 +0200 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2022-10-05 13:37:01 -0600 |
commit | ad81a80cc73e0e1ccb6b13289b75bb50df448dd9 (patch) | |
tree | 504c38888157f148bea47da2187599e31b1529cb | |
parent | dbbf06aab54f803fb4560048d305ea182311f9b2 (diff) | |
download | numpy-ad81a80cc73e0e1ccb6b13289b75bb50df448dd9.tar.gz |
BUG: Fix complex vector dot with more than NPY_CBLAS_CHUNK elements
The iteration was simply using the wrong value: the larger value
might even work sometimes, but then we do another iteration, counting
the remaining elements twice.
Closes gh-22262
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 3 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 11 |
2 files changed, 12 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index cb87433ef..8597d80e8 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -3581,7 +3581,8 @@ NPY_NO_EXPORT void
         CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
         @type@ tmp[2];
 
-        CBLAS_FUNC(cblas_@prefix@dotu_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
+        CBLAS_FUNC(cblas_@prefix@dotu_sub)(
+                (CBLAS_INT)chunk, ip1, is1b, ip2, is2b, tmp);
         sum[0] += (double)tmp[0];
         sum[1] += (double)tmp[1];
         /* use char strides here */
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a7bdf335b..828e7f033 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -29,7 +29,7 @@ from numpy.testing import (
     assert_allclose, IS_PYPY, IS_PYSTON, HAS_REFCOUNT, assert_array_less,
     runstring, temppath, suppress_warnings, break_cycles,
     )
-from numpy.testing._private.utils import _no_tracing
+from numpy.testing._private.utils import requires_memory, _no_tracing
 from numpy.core.tests._locales import CommaDecimalPointLocale
 from numpy.lib.recfunctions import repack_fields
 
@@ -6691,6 +6691,15 @@ class TestDot:
         # Strides in A cols and X
         assert_dot_close(A_f_12, X_f_2, desired)
 
+    @pytest.mark.slow
+    @pytest.mark.parametrize("dtype", [np.float64, np.complex128])
+    @requires_memory(free_bytes=9*10**9)  # complex case needs 8GiB+
+    def test_huge_vectordot(self, dtype):
+        # Large vector multiplications are chunked with 32bit BLAS
+        # Test that the chunking does the right thing, see also gh-22262
+        data = np.ones(2**30+100, dtype=dtype)
+        res = np.dot(data, data)
+        assert res == 2**30+100
 
 
 class MatmulCommon: |