summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/common/cblasfuncs.c4
-rw-r--r--numpy/core/src/umath/matmul.c.src6
-rw-r--r--numpy/core/tests/test_multiarray.py30
3 files changed, 38 insertions, 2 deletions
diff --git a/numpy/core/src/common/cblasfuncs.c b/numpy/core/src/common/cblasfuncs.c
index 64a0569ad..39572fed4 100644
--- a/numpy/core/src/common/cblasfuncs.c
+++ b/numpy/core/src/common/cblasfuncs.c
@@ -188,6 +188,7 @@ _bad_strides(PyArrayObject *ap)
int itemsize = PyArray_ITEMSIZE(ap);
int i, N=PyArray_NDIM(ap);
npy_intp *strides = PyArray_STRIDES(ap);
+ npy_intp *dims = PyArray_DIMS(ap);
if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) {
return 1;
@@ -196,6 +197,9 @@ _bad_strides(PyArrayObject *ap)
if ((strides[i] < 0) || (strides[i] % itemsize) != 0) {
return 1;
}
+ if ((strides[i] == 0 && dims[i] > 1)) {
+ return 1;
+ }
}
return 0;
diff --git a/numpy/core/src/umath/matmul.c.src b/numpy/core/src/umath/matmul.c.src
index 3ee0ec4f2..0cb3c82ad 100644
--- a/numpy/core/src/umath/matmul.c.src
+++ b/numpy/core/src/umath/matmul.c.src
@@ -312,8 +312,10 @@ NPY_NO_EXPORT void
npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
- npy_bool vector_matrix = ((dm == 1) && i2blasable);
- npy_bool matrix_vector = ((dp == 1) && i1blasable);
+ npy_bool vector_matrix = ((dm == 1) && i2blasable &&
+ is_blasable2d(is1_n, sz, dn, 1, sz));
+ npy_bool matrix_vector = ((dp == 1) && i1blasable &&
+ is_blasable2d(is2_n, sz, dn, 1, sz));
#endif
for (iOuter = 0; iOuter < dOuter; iOuter++,
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 311b477f5..cdacdabbe 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2713,6 +2713,36 @@ class TestMethods(object):
assert_equal(func(edf, edf.T), eddtf)
assert_equal(func(edf.T, edf), edtdf)
+ @pytest.mark.parametrize('func', (np.dot, np.matmul))
+ @pytest.mark.parametrize('dtype', 'ifdFD')
+ def test_no_dgemv(self, func, dtype):
+ # check vector arg for contiguous before gemv
+ # gh-12156
+ a = np.arange(8.0, dtype=dtype).reshape(2, 4)
+ b = np.broadcast_to(1., (4, 1))
+ ret1 = func(a, b)
+ ret2 = func(a, b.copy())
+ assert_equal(ret1, ret2)
+
+ ret1 = func(b.T, a.T)
+ ret2 = func(b.T.copy(), a.T)
+ assert_equal(ret1, ret2)
+
+ # check for unaligned data
+ dt = np.dtype(dtype)
+ a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype)
+ a = a.reshape(2, 4)
+ b = a[0]
+ # make sure it is not aligned
+ assert_(a.__array_interface__['data'][0] % dt.itemsize != 0)
+ ret1 = func(a, b)
+ ret2 = func(a.copy(), b.copy())
+ assert_equal(ret1, ret2)
+
+ ret1 = func(b.T, a.T)
+ ret2 = func(b.T.copy(), a.T.copy())
+ assert_equal(ret1, ret2)
+
def test_dot(self):
a = np.array([[1, 0], [0, 1]])
b = np.array([[0, 1], [1, 0]])