diff options
-rw-r--r-- | numpy/core/src/common/cblasfuncs.c | 4 | ||||
-rw-r--r-- | numpy/core/src/umath/matmul.c.src | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 30 |
3 files changed, 38 insertions, 2 deletions
diff --git a/numpy/core/src/common/cblasfuncs.c b/numpy/core/src/common/cblasfuncs.c index 64a0569ad..39572fed4 100644 --- a/numpy/core/src/common/cblasfuncs.c +++ b/numpy/core/src/common/cblasfuncs.c @@ -188,6 +188,7 @@ _bad_strides(PyArrayObject *ap) int itemsize = PyArray_ITEMSIZE(ap); int i, N=PyArray_NDIM(ap); npy_intp *strides = PyArray_STRIDES(ap); + npy_intp *dims = PyArray_DIMS(ap); if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) { return 1; @@ -196,6 +197,9 @@ _bad_strides(PyArrayObject *ap) if ((strides[i] < 0) || (strides[i] % itemsize) != 0) { return 1; } + if ((strides[i] == 0 && dims[i] > 1)) { + return 1; + } } return 0; diff --git a/numpy/core/src/umath/matmul.c.src b/numpy/core/src/umath/matmul.c.src index 3ee0ec4f2..0cb3c82ad 100644 --- a/numpy/core/src/umath/matmul.c.src +++ b/numpy/core/src/umath/matmul.c.src @@ -312,8 +312,10 @@ NPY_NO_EXPORT void npy_bool i2blasable = i2_c_blasable || i2_f_blasable; npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz); npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz); - npy_bool vector_matrix = ((dm == 1) && i2blasable); - npy_bool matrix_vector = ((dp == 1) && i1blasable); + npy_bool vector_matrix = ((dm == 1) && i2blasable && + is_blasable2d(is1_n, sz, dn, 1, sz)); + npy_bool matrix_vector = ((dp == 1) && i1blasable && + is_blasable2d(is2_n, sz, dn, 1, sz)); #endif for (iOuter = 0; iOuter < dOuter; iOuter++, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 311b477f5..cdacdabbe 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2713,6 +2713,36 @@ class TestMethods(object): assert_equal(func(edf, edf.T), eddtf) assert_equal(func(edf.T, edf), edtdf) + @pytest.mark.parametrize('func', (np.dot, np.matmul)) + @pytest.mark.parametrize('dtype', 'ifdFD') + def test_no_dgemv(self, func, dtype): + # check vector arg for contiguous before gemv + # gh-12156 + a = np.arange(8.0, dtype=dtype).reshape(2, 4) + b = np.broadcast_to(1., (4, 1)) + ret1 = func(a, b) + ret2 = func(a, b.copy()) + assert_equal(ret1, ret2) + + ret1 = func(b.T, a.T) + ret2 = func(b.T.copy(), a.T) + assert_equal(ret1, ret2) + + # check for unaligned data + dt = np.dtype(dtype) + a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype) + a = a.reshape(2, 4) + b = a[0] + # make sure it is not aligned + assert_(a.__array_interface__['data'][0] % dt.itemsize != 0) + ret1 = func(a, b) + ret2 = func(a.copy(), b.copy()) + assert_equal(ret1, ret2) + + ret1 = func(b.T, a.T) + ret2 = func(b.T.copy(), a.T.copy()) + assert_equal(ret1, ret2) + def test_dot(self): a = np.array([[1, 0], [0, 1]]) b = np.array([[0, 1], [1, 0]]) |