summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2022-10-05 21:10:17 -0500
committerGitHub <noreply@github.com>2022-10-05 21:10:17 -0500
commit611fadc179cf749be0571dff4ce040aebb665700 (patch)
tree37cc570a1b84ceb8d6e5c8dce4af327b7162663e
parent93ab5d641af16f5df1a008cb97e573d70999c0a2 (diff)
parentad81a80cc73e0e1ccb6b13289b75bb50df448dd9 (diff)
downloadnumpy-611fadc179cf749be0571dff4ce040aebb665700.tar.gz
Merge pull request #22384 from charris/backport-22327
BUG: Fix complex vector dot with more than NPY_CBLAS_CHUNK elements
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src3
-rw-r--r--numpy/core/tests/test_multiarray.py11
2 files changed, 12 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index cb87433ef..8597d80e8 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -3581,7 +3581,8 @@ NPY_NO_EXPORT void
CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
@type@ tmp[2];
- CBLAS_FUNC(cblas_@prefix@dotu_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
+ CBLAS_FUNC(cblas_@prefix@dotu_sub)(
+ (CBLAS_INT)chunk, ip1, is1b, ip2, is2b, tmp);
sum[0] += (double)tmp[0];
sum[1] += (double)tmp[1];
/* use char strides here */
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a7bdf335b..828e7f033 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -29,7 +29,7 @@ from numpy.testing import (
assert_allclose, IS_PYPY, IS_PYSTON, HAS_REFCOUNT, assert_array_less,
runstring, temppath, suppress_warnings, break_cycles,
)
-from numpy.testing._private.utils import _no_tracing
+from numpy.testing._private.utils import requires_memory, _no_tracing
from numpy.core.tests._locales import CommaDecimalPointLocale
from numpy.lib.recfunctions import repack_fields
@@ -6691,6 +6691,15 @@ class TestDot:
# Strides in A cols and X
assert_dot_close(A_f_12, X_f_2, desired)
+ @pytest.mark.slow
+ @pytest.mark.parametrize("dtype", [np.float64, np.complex128])
+ @requires_memory(free_bytes=9*10**9) # complex case needs 8GiB+
+ def test_huge_vectordot(self, dtype):
+ # Large vector multiplications are chunked with 32bit BLAS
+ # Test that the chunking does the right thing, see also gh-22262
+ data = np.ones(2**30+100, dtype=dtype)
+ res = np.dot(data, data)
+ assert res == 2**30+100
class MatmulCommon: