summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/cblasfuncs.c68
1 files changed, 68 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 67f325ba1..1789b2caf 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
}
+/*
+ * Helper: dispatch to appropriate cblas_?syrk for typenum.
+ */
+static void
+syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
+ int n, int k,
+ PyArrayObject *A, int lda, PyArrayObject *R)
+{
+ const void *Adata = PyArray_DATA(A);
+ void *Rdata = PyArray_DATA(R);
+ int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
+
+ npy_intp i;
+ npy_intp j;
+
+ switch (typenum) {
+ case NPY_DOUBLE:
+ cblas_dsyrk(order, CblasUpper, trans, n, k, 1.,
+ Adata, lda, 0., Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_FLOAT:
+ cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f,
+ Adata, lda, 0.f, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CDOUBLE:
+ cblas_zsyrk(order, CblasUpper, trans, n, k, oneD,
+ Adata, lda, zeroD, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CFLOAT:
+ cblas_csyrk(order, CblasUpper, trans, n, k, oneF,
+ Adata, lda, zeroF, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ }
+}
+
+
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;