summaryrefslogtreecommitdiff
path: root/numpy/core/blasdot
diff options
context:
space:
mode:
author: Lars Buitinck <larsmans@gmail.com> 2013-10-25 18:57:19 +0200
committer: Lars Buitinck <larsmans@gmail.com> 2013-10-26 12:38:58 +0200
commit: 14153dc41baae421e429ce6295267763a571172a (patch)
tree: f8777d67d5cee4297ff0c1abdb84aecb436b2b74 /numpy/core/blasdot
parent: a3e8c12ed88c6db2aa89cfbb7a69fc863e8a40dc (diff)
download: numpy-14153dc41baae421e429ce6295267763a571172a.tar.gz
MAINT: dotblas: factor out all gemm and gemv calls
Diffstat (limited to 'numpy/core/blasdot')
-rw-r--r-- numpy/core/blasdot/_dotblas.c 250
1 files changed, 83 insertions, 167 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index 9119ade14..a8b34593e 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -1,5 +1,5 @@
/*
- * This module provides a BLAS optimized\nmatrix multiply,
+ * This module provides a BLAS optimized matrix multiply,
* inner product and dot for numpy arrays
*/
@@ -85,6 +85,78 @@ CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
}
+static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0}; /* scaling arguments handed to cblas_zgemm and cblas_zgemv; presumably {real, imag} pairs for 1 and 0 -- confirm */
+static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0}; /* single-precision counterparts, handed to cblas_cgemm and cblas_cgemv */
+
+/*
+ * Helper: dispatch to appropriate cblas_?gemm for typenum.
+ *
+ * typenum selects the routine:
+ *   NPY_DOUBLE  -> cblas_dgemm
+ *   NPY_FLOAT   -> cblas_sgemm
+ *   NPY_CDOUBLE -> cblas_zgemm
+ *   NPY_CFLOAT  -> cblas_cgemm
+ * The switch has no default case, so for any other typenum no BLAS call
+ * is made and this function does not write to R.
+ *
+ * order, transA, transB, m, n, k, lda and ldb are forwarded to the BLAS
+ * routine unchanged.  The operand and result buffers are the
+ * PyArray_DATA pointers of A, B and R.  The two scaling arguments are
+ * always 1 and 0 (oneD/zeroD or oneF/zeroF for the complex routines).
+ * The final leading-dimension argument is not a parameter; it is derived
+ * here from PyArray_DIM(R, 1).
+ *
+ * Returns nothing and reports no error to the caller.
+ */
+static void
+gemm(int typenum, enum CBLAS_ORDER order,
+ enum CBLAS_TRANSPOSE transA, enum CBLAS_TRANSPOSE transB,
+ int m, int n, int k,
+ PyArrayObject *A, int lda, PyArrayObject *B, int ldb, PyArrayObject *R)
+{
+ const void *Adata = PyArray_DATA(A), *Bdata = PyArray_DATA(B);
+ void *Rdata = PyArray_DATA(R);
+
+ int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1; /* clamped so it is never below 1; NOTE(review): PyArray_DIM presumably yields npy_intp, which is narrowed to int here -- confirm */
+
+ switch (typenum) {
+ case NPY_DOUBLE:
+ cblas_dgemm(order, transA, transB, m, n, k, 1.,
+ Adata, lda, Bdata, ldb, 0., Rdata, ldc);
+ break;
+ case NPY_FLOAT:
+ cblas_sgemm(order, transA, transB, m, n, k, 1.f,
+ Adata, lda, Bdata, ldb, 0.f, Rdata, ldc);
+ break;
+ case NPY_CDOUBLE:
+ cblas_zgemm(order, transA, transB, m, n, k, oneD,
+ Adata, lda, Bdata, ldb, zeroD, Rdata, ldc);
+ break;
+ case NPY_CFLOAT:
+ cblas_cgemm(order, transA, transB, m, n, k, oneF,
+ Adata, lda, Bdata, ldb, zeroF, Rdata, ldc);
+ break;
+ } /* any other typenum falls through here without a BLAS call */
+}
+
+
+/*
+ * Helper: dispatch to appropriate cblas_?gemv for typenum.
+ *
+ * typenum selects the routine:
+ *   NPY_DOUBLE  -> cblas_dgemv
+ *   NPY_FLOAT   -> cblas_sgemv
+ *   NPY_CDOUBLE -> cblas_zgemv
+ *   NPY_CFLOAT  -> cblas_cgemv
+ * The switch has no default case, so for any other typenum no BLAS call
+ * is made and this function does not write to R.
+ *
+ * order, trans, lda and incX are forwarded to the BLAS routine unchanged.
+ * The matrix, vector and result buffers are the PyArray_DATA pointers of
+ * A, X and R.  The row and column counts are not parameters; they are
+ * read from PyArray_DIM(A, 0) and PyArray_DIM(A, 1) whatever the value
+ * of trans.  The two scaling arguments are always 1 and 0 (oneD/zeroD or
+ * oneF/zeroF for the complex routines), and the result increment is
+ * fixed at 1.
+ *
+ * Returns nothing and reports no error to the caller.
+ */
+static void
+gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
+ PyArrayObject *A, int lda, PyArrayObject *X, int incX,
+ PyArrayObject *R)
+{
+ const void *Adata = PyArray_DATA(A), *Xdata = PyArray_DATA(X);
+ void *Rdata = PyArray_DATA(R);
+
+ int m = PyArray_DIM(A, 0), n = PyArray_DIM(A, 1); /* NOTE(review): PyArray_DIM presumably yields npy_intp, which is narrowed to int here -- confirm */
+
+ switch (typenum) {
+ case NPY_DOUBLE:
+ cblas_dgemv(order, trans, m, n, 1., Adata, lda, Xdata, incX,
+ 0., Rdata, 1);
+ break;
+ case NPY_FLOAT:
+ cblas_sgemv(order, trans, m, n, 1.f, Adata, lda, Xdata, incX,
+ 0.f, Rdata, 1);
+ break;
+ case NPY_CDOUBLE:
+ cblas_zgemv(order, trans, m, n, oneD, Adata, lda, Xdata, incX,
+ zeroD, Rdata, 1);
+ break;
+ case NPY_CFLOAT:
+ cblas_cgemv(order, trans, m, n, oneF, Adata, lda, Xdata, incX,
+ zeroF, Rdata, 1);
+ break;
+ } /* any other typenum falls through here without a BLAS call */
+}
+
+
static npy_bool altered=NPY_FALSE;
/*
@@ -224,15 +296,11 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
PyObject *op1, *op2;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL;
int errval;
- int j, l, lda, ldb, ldc;
+ int j, l, lda, ldb;
int typenum, nd;
npy_intp ap1stride = 0;
npy_intp dimensions[NPY_MAXDIMS];
npy_intp numbytes;
- static const float oneF[2] = {1.0, 0.0};
- static const float zeroF[2] = {0.0, 0.0};
- static const double oneD[2] = {1.0, 0.0};
- static const double zeroD[2] = {0.0, 0.0};
double prior1, prior2;
PyTypeObject *subtype;
PyArray_Descr *dtype;
@@ -689,32 +757,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1);
}
ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2);
- if (typenum == NPY_DOUBLE) {
- cblas_dgemv(Order, CblasNoTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- 1.0, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), ap2s, 0.0, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_sgemv(Order, CblasNoTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- 1.0, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), ap2s, 0.0, (float *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zgemv(Order,
- CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- oneD, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), ap2s, zeroD,
- (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cgemv(Order,
- CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- oneF, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), ap2s, zeroF,
- (float *)PyArray_DATA(ret), 1);
- }
+ gemv(typenum, Order, CblasNoTrans, ap1, lda, ap2, ap2s, ret);
NPY_END_ALLOW_THREADS;
}
else if (ap1shape != _matrix && ap2shape == _matrix) {
@@ -746,30 +789,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
else {
ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1);
}
- if (typenum == NPY_DOUBLE) {
- cblas_dgemv(Order,
- CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- 1.0, (double *)PyArray_DATA(ap2), lda,
- (double *)PyArray_DATA(ap1), ap1s, 0.0, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_sgemv(Order,
- CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- 1.0, (float *)PyArray_DATA(ap2), lda,
- (float *)PyArray_DATA(ap1), ap1s, 0.0, (float *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zgemv(Order,
- CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- oneD, (double *)PyArray_DATA(ap2), lda,
- (double *)PyArray_DATA(ap1), ap1s, zeroD, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cgemv(Order,
- CblasTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- oneF, (float *)PyArray_DATA(ap2), lda,
- (float *)PyArray_DATA(ap1), ap1s, zeroF, (float *)PyArray_DATA(ret), 1);
- }
+ gemv(typenum, Order, CblasTrans, ap2, lda, ap1, ap1s, ret);
NPY_END_ALLOW_THREADS;
}
else {
@@ -816,7 +836,6 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
M = PyArray_DIM(ap2, 0);
lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
- ldc = (PyArray_DIM(ret, 1) > 1 ? PyArray_DIM(ret, 1) : 1);
/*
* Avoid temporary copies for arrays in Fortran order
@@ -829,34 +848,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
Trans2 = CblasTrans;
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
}
- if (typenum == NPY_DOUBLE) {
- cblas_dgemm(Order, Trans1, Trans2,
- L, N, M,
- 1.0, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), ldb,
- 0.0, (double *)PyArray_DATA(ret), ldc);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_sgemm(Order, Trans1, Trans2,
- L, N, M,
- 1.0, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), ldb,
- 0.0, (float *)PyArray_DATA(ret), ldc);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zgemm(Order, Trans1, Trans2,
- L, N, M,
- oneD, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), ldb,
- zeroD, (double *)PyArray_DATA(ret), ldc);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cgemm(Order, Trans1, Trans2,
- L, N, M,
- oneF, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), ldb,
- zeroF, (float *)PyArray_DATA(ret), ldc);
- }
+ gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
NPY_END_ALLOW_THREADS;
}
@@ -887,13 +879,9 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
{
PyObject *op1, *op2;
PyArrayObject *ap1, *ap2, *ret;
- int j, l, lda, ldb, ldc;
+ int j, l, lda, ldb;
int typenum, nd;
npy_intp dimensions[NPY_MAXDIMS];
- static const float oneF[2] = {1.0, 0.0};
- static const float zeroF[2] = {0.0, 0.0};
- static const double oneD[2] = {1.0, 0.0};
- static const double zeroD[2] = {0.0, 0.0};
PyTypeObject *subtype;
double prior1, prior2;
@@ -1028,92 +1016,20 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
else if (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 1) {
/* Matrix-vector multiplication -- Level 2 BLAS */
lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
- if (typenum == NPY_DOUBLE) {
- cblas_dgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- 1.0, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), 1, 0.0, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- oneD, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), 1, zeroD, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_sgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- 1.0, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), 1, 0.0, (float *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap1, 0), PyArray_DIM(ap1, 1),
- oneF, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), 1, zeroF, (float *)PyArray_DATA(ret), 1);
- }
+ gemv(typenum, CblasRowMajor, CblasNoTrans, ap1, lda, ap2, 1, ret);
}
else if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 2) {
/* Vector matrix multiplication -- Level 2 BLAS */
lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
- if (typenum == NPY_DOUBLE) {
- cblas_dgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- 1.0, (double *)PyArray_DATA(ap2), lda,
- (double *)PyArray_DATA(ap1), 1, 0.0, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- oneD, (double *)PyArray_DATA(ap2), lda,
- (double *)PyArray_DATA(ap1), 1, zeroD, (double *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_sgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- 1.0, (float *)PyArray_DATA(ap2), lda,
- (float *)PyArray_DATA(ap1), 1, 0.0, (float *)PyArray_DATA(ret), 1);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cgemv(CblasRowMajor,
- CblasNoTrans, PyArray_DIM(ap2, 0), PyArray_DIM(ap2, 1),
- oneF, (float *)PyArray_DATA(ap2), lda,
- (float *)PyArray_DATA(ap1), 1, zeroF, (float *)PyArray_DATA(ret), 1);
- }
+ gemv(typenum, CblasRowMajor, CblasNoTrans, ap2, lda, ap1, 1, ret);
}
else { /* (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2) */
/* Matrix matrix multiplication -- Level 3 BLAS */
lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
- ldc = (PyArray_DIM(ret, 1) > 1 ? PyArray_DIM(ret, 1) : 1);
- if (typenum == NPY_DOUBLE) {
- cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1),
- 1.0, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), ldb,
- 0.0, (double *)PyArray_DATA(ret), ldc);
- }
- else if (typenum == NPY_FLOAT) {
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1),
- 1.0, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), ldb,
- 0.0, (float *)PyArray_DATA(ret), ldc);
- }
- else if (typenum == NPY_CDOUBLE) {
- cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1),
- oneD, (double *)PyArray_DATA(ap1), lda,
- (double *)PyArray_DATA(ap2), ldb,
- zeroD, (double *)PyArray_DATA(ret), ldc);
- }
- else if (typenum == NPY_CFLOAT) {
- cblas_cgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1),
- oneF, (float *)PyArray_DATA(ap1), lda,
- (float *)PyArray_DATA(ap2), ldb,
- zeroF, (float *)PyArray_DATA(ret), ldc);
- }
+ gemm(typenum, CblasRowMajor, CblasNoTrans, CblasTrans,
+ PyArray_DIM(ap1, 0), PyArray_DIM(ap2, 0), PyArray_DIM(ap1, 1),
+ ap1, lda, ap2, ldb, ret);
}
NPY_END_ALLOW_THREADS
Py_DECREF(ap1);