diff options
author | Lars Buitinck <larsmans@gmail.com> | 2013-10-25 18:57:19 +0200 |
---|---|---|
committer | Lars Buitinck <larsmans@gmail.com> | 2013-10-26 12:38:58 +0200 |
commit | 14153dc41baae421e429ce6295267763a571172a (patch) | |
tree | f8777d67d5cee4297ff0c1abdb84aecb436b2b74 /numpy/core/blasdot | |
parent | a3e8c12ed88c6db2aa89cfbb7a69fc863e8a40dc (diff) | |
download | numpy-14153dc41baae421e429ce6295267763a571172a.tar.gz |
MAINT: dotblas: factor out all gemm and gemv calls
Diffstat (limited to 'numpy/core/blasdot')
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 250 |
1 files changed, 83 insertions, 167 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index 9119ade14..a8b34593e 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -1,5 +1,5 @@ /* - * This module provides a BLAS optimized\nmatrix multiply, + * This module provides a BLAS optimized matrix multiply, * inner product and dot for numpy arrays */ @@ -85,6 +85,78 @@ CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, } +static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0}; +static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0}; + +/* + * Helper: dispatch to appropriate cblas_?gemm for typenum. + */ +static void +gemm(int typenum, enum CBLAS_ORDER order, + enum CBLAS_TRANSPOSE transA, enum CBLAS_TRANSPOSE transB, + int m, int n, int k, + PyArrayObject *A, int lda, PyArrayObject *B, int ldb, PyArrayObject *R) +{ + const void *Adata = PyArray_DATA(A), *Bdata = PyArray_DATA(B); + void *Rdata = PyArray_DATA(R); + + int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1; + + switch (typenum) { + case NPY_DOUBLE: + cblas_dgemm(order, transA, transB, m, n, k, 1., + Adata, lda, Bdata, ldb, 0., Rdata, ldc); + break; + case NPY_FLOAT: + cblas_sgemm(order, transA, transB, m, n, k, 1.f, + Adata, lda, Bdata, ldb, 0.f, Rdata, ldc); + break; + case NPY_CDOUBLE: + cblas_zgemm(order, transA, transB, m, n, k, oneD, + Adata, lda, Bdata, ldb, zeroD, Rdata, ldc); + break; + case NPY_CFLOAT: + cblas_cgemm(order, transA, transB, m, n, k, oneF, + Adata, lda, Bdata, ldb, zeroF, Rdata, ldc); + break; + } +} + + +/* + * Helper: dispatch to appropriate cblas_?gemv for typenum. + */ +static void +gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans, + PyArrayObject *A, int lda, PyArrayObject *X, int incX, + PyArrayObject *R) +{ + const void *Adata = PyArray_DATA(A), *Xdata = PyArray_DATA(X); + void *Rdata = PyArray_DATA(R); + + int m = PyArray_DIM(A, 0), n = PyArray_DIM(A, 1); + + switch (typenum) { + case NPY_DOUBLE: + cblas_dgemv(order, trans, m, n, 1., Adata, lda, Xdata, incX, + 0., Rdata, 1); + break; + case NPY_FLOAT: + cblas_sgemv(order, trans, m, n, 1.f, Adata, lda, Xdata, incX, + 0.f, Rdata, 1); + break; + case NPY_CDOUBLE: + cblas_zgemv(order, trans, m, n, oneD, Adata, lda, Xdata, incX, + zeroD, Rdata, 1); + break; + case NPY_CFLOAT: + cblas_cgemv(order, trans, m, n, oneF, Adata, lda, Xdata, incX, + zeroF, Rdata, 1); + break; + } +} + + static npy_bool altered=NPY_FALSE; /* @@ -224,15 +296,11 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa PyObject *op1, *op2; PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL; int errval; - int j, l, lda, ldb, ldc; + int j, l, lda, ldb; int typenum, nd; npy_intp ap1stride = 0; npy_intp dimensions[NPY_MAXDIMS]; npy_intp numbytes; - static const float oneF[2] = {1.0, 0.0}; - static const float zeroF[2] = {0.0, 0.0}; - static const double oneD[2] = {1.0, 0.0}; - static const double zeroD[2] = {0.0, 0.0}; double prior1, prior2; PyTypeObject *subtype; PyArray_Descr *dtype; @@ -689,32 +757,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1); } ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2); - if (typenum == NPY_DOUBLE) { - cblas_dgemv(Order, CblasNoTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - 1.0, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), ap2s, 0.0, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_FLOAT) { - cblas_sgemv(Order, CblasNoTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - 1.0, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), ap2s, 0.0, (float *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zgemv(Order, - CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - oneD, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), ap2s, zeroD, - (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CFLOAT) { - cblas_cgemv(Order, - CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - oneF, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), ap2s, zeroF, - (float *)PyArray_DATA(ret), 1); - } + gemv(typenum, Order, CblasNoTrans, ap1, lda, ap2, ap2s, ret); NPY_END_ALLOW_THREADS; } else if (ap1shape != _matrix && ap2shape == _matrix) { @@ -746,30 +789,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa else { ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1); } - if (typenum == NPY_DOUBLE) { - cblas_dgemv(Order, - CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - 1.0, (double *)PyArray_DATA(ap2), lda, - (double *)PyArray_DATA(ap1), ap1s, 0.0, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_FLOAT) { - cblas_sgemv(Order, - CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - 1.0, (float *)PyArray_DATA(ap2), lda, - (float *)PyArray_DATA(ap1), ap1s, 0.0, (float *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zgemv(Order, - CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - oneD, (double *)PyArray_DATA(ap2), lda, - (double *)PyArray_DATA(ap1), ap1s, zeroD, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CFLOAT) { - cblas_cgemv(Order, - CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - oneF, (float *)PyArray_DATA(ap2), lda, - (float *)PyArray_DATA(ap1), ap1s, zeroF, (float *)PyArray_DATA(ret), 1); - } + gemv(typenum, Order, CblasTrans, ap2, lda, ap1, ap1s, ret); NPY_END_ALLOW_THREADS; } else { @@ -816,7 +836,6 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa M = PyArray_DIM(ap2, 0); lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1); ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); - ldc = (PyArray_DIM(ret, 1) > 1 ? PyArray_DIM(ret, 1) : 1); /* * Avoid temporary copies for arrays in Fortran order @@ -829,34 +848,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa Trans2 = CblasTrans; ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1); } - if (typenum == NPY_DOUBLE) { - cblas_dgemm(Order, Trans1, Trans2, - L, N, M, - 1.0, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), ldb, - 0.0, (double *)PyArray_DATA(ret), ldc); - } - else if (typenum == NPY_FLOAT) { - cblas_sgemm(Order, Trans1, Trans2, - L, N, M, - 1.0, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), ldb, - 0.0, (float *)PyArray_DATA(ret), ldc); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zgemm(Order, Trans1, Trans2, - L, N, M, - oneD, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), ldb, - zeroD, (double *)PyArray_DATA(ret), ldc); - } - else if (typenum == NPY_CFLOAT) { - cblas_cgemm(Order, Trans1, Trans2, - L, N, M, - oneF, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), ldb, - zeroF, (float *)PyArray_DATA(ret), ldc); - } + gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret); NPY_END_ALLOW_THREADS; } @@ -887,13 +879,9 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) { PyObject *op1, *op2; PyArrayObject *ap1, *ap2, *ret; - int j, l, lda, ldb, ldc; + int j, l, lda, ldb; int typenum, nd; npy_intp dimensions[NPY_MAXDIMS]; - static const float oneF[2] = {1.0, 0.0}; - static const float zeroF[2] = {0.0, 0.0}; - static const double oneD[2] = {1.0, 0.0}; - static const double zeroD[2] = {0.0, 0.0}; PyTypeObject *subtype; double prior1, prior2; @@ -1028,92 +1016,20 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) { /* Matrix-vector multiplication -- Level 2 BLAS */ lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1); - if (typenum == NPY_DOUBLE) { - cblas_dgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - 1.0, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), 1, 0.0, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - oneD, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), 1, zeroD, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_FLOAT) { - cblas_sgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - 1.0, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), 1, 0.0, (float *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CFLOAT) { - cblas_cgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1), - oneF, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), 1, zeroF, (float *)PyArray_DATA(ret), 1); - } + gemv(typenum, CblasRowMajor, CblasNoTrans, ap1, lda, ap2, 1, ret); } else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 2) { /* Vector matrix multiplication -- Level 2 BLAS */ lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); - if (typenum == NPY_DOUBLE) { - cblas_dgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - 1.0, (double *)PyArray_DATA(ap2), lda, - (double *)PyArray_DATA(ap1), 1, 0.0, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - oneD, (double *)PyArray_DATA(ap2), lda, - (double *)PyArray_DATA(ap1), 1, zeroD, (double *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_FLOAT) { - cblas_sgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - 1.0, (float *)PyArray_DATA(ap2), lda, - (float *)PyArray_DATA(ap1), 1, 0.0, (float *)PyArray_DATA(ret), 1); - } - else if (typenum == NPY_CFLOAT) { - cblas_cgemv(CblasRowMajor, - CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1), - oneF, (float *)PyArray_DATA(ap2), lda, - (float *)PyArray_DATA(ap1), 1, zeroF, (float *)PyArray_DATA(ret), 1); - } + gemv(typenum, CblasRowMajor, CblasNoTrans, ap2, lda, ap1, 1, ret); } else { /* (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2) */ /* Matrix matrix multiplication -- Level 3 BLAS */ lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1); ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); - ldc = (PyArray_DIM(ret, 1) > 1 ? PyArray_DIM(ret, 1) : 1); - if (typenum == NPY_DOUBLE) { - cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1), - 1.0, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), ldb, - 0.0, (double *)PyArray_DATA(ret), ldc); - } - else if (typenum == NPY_FLOAT) { - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1), - 1.0, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), ldb, - 0.0, (float *)PyArray_DATA(ret), ldc); - } - else if (typenum == NPY_CDOUBLE) { - cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1), - oneD, (double *)PyArray_DATA(ap1), lda, - (double *)PyArray_DATA(ap2), ldb, - zeroD, (double *)PyArray_DATA(ret), ldc); - } - else if (typenum == NPY_CFLOAT) { - cblas_cgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1), - oneF, (float *)PyArray_DATA(ap1), lda, - (float *)PyArray_DATA(ap2), ldb, - zeroF, (float *)PyArray_DATA(ret), ldc); - } + gemm(typenum, CblasRowMajor, CblasNoTrans, CblasTrans, + PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1), + ap1, lda, ap2, ldb, ret); } NPY_END_ALLOW_THREADS Py_DECREF(ap1); |