summaryrefslogtreecommitdiff
path: root/numpy/core/blasdot
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2014-08-15 13:33:23 -0600
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-09-04 21:45:48 +0200
commita3919717c47070bc8a245ffa61b3a75dd7d5d879 (patch)
tree90a9fb972dfb5027f7975c853b04ae0017fb07f9 /numpy/core/blasdot
parent175cea4dc0590ae520f32476ffb9129b8524bcad (diff)
downloadnumpy-a3919717c47070bc8a245ffa61b3a75dd7d5d879.tar.gz
MAINT, STY: Remove use of alterdot, restoredot in _dotblas.c.
These are no longer needed. Also do C style cleanups.
Diffstat (limited to 'numpy/core/blasdot')
-rw-r--r--numpy/core/blasdot/_dotblas.c259
1 files changed, 108 insertions, 151 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index c1b4e7b05..812d947b7 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -19,8 +19,11 @@
#include <limits.h>
#include <stdio.h>
+
static char module_doc[] =
-"This module provides a BLAS optimized\nmatrix multiply, inner product and dot for numpy arrays";
+ "This module provides a BLAS optimized\n"
+ "matrix multiply, inner product and dot for numpy arrays";
+
static PyArray_DotFunc *oldFunctions[NPY_NTYPES];
@@ -33,7 +36,7 @@ static void
blas_dot(int typenum, npy_intp n,
void *a, npy_intp stridea, void *b, npy_intp strideb, void *res)
{
- PyArray_DotFunc *dot = NULL;
+ PyArray_DotFunc *dot;
dot = oldFunctions[typenum];
assert(dot != NULL);
@@ -44,6 +47,7 @@ blas_dot(int typenum, npy_intp n,
static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0};
static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0};
+
/*
* Helper: dispatch to appropriate cblas_?gemm for typenum.
*/
@@ -113,80 +117,37 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
}
-static npy_bool altered=NPY_FALSE;
-
/*
- * alterdot() changes all dot functions to use blas.
+ * Initialize oldFunctions table.
*/
-static PyObject *
-dotblas_alterdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
+static void
+init_oldFunctions(void)
{
PyArray_Descr *descr;
+ int i;
- if (!PyArg_ParseTuple(args, "")) return NULL;
-
- /* Replace the dot functions to the ones using blas */
-
- if (!altered) {
- descr = PyArray_DescrFromType(NPY_FLOAT);
- oldFunctions[NPY_FLOAT] = descr->f->dotfunc;
-
- descr = PyArray_DescrFromType(NPY_DOUBLE);
- oldFunctions[NPY_DOUBLE] = descr->f->dotfunc;
+ /* Initialise the array of dot functions */
+ for (i = 0; i < NPY_NTYPES; i++)
+ oldFunctions[i] = NULL;
- descr = PyArray_DescrFromType(NPY_CFLOAT);
- oldFunctions[NPY_CFLOAT] = descr->f->dotfunc;
+ /* index dot functions we want to use here */
+ descr = PyArray_DescrFromType(NPY_FLOAT);
+ oldFunctions[NPY_FLOAT] = descr->f->dotfunc;
- descr = PyArray_DescrFromType(NPY_CDOUBLE);
- oldFunctions[NPY_CDOUBLE] = descr->f->dotfunc;
+ descr = PyArray_DescrFromType(NPY_DOUBLE);
+ oldFunctions[NPY_DOUBLE] = descr->f->dotfunc;
- altered = NPY_TRUE;
- }
+ descr = PyArray_DescrFromType(NPY_CFLOAT);
+ oldFunctions[NPY_CFLOAT] = descr->f->dotfunc;
- Py_INCREF(Py_None);
- return Py_None;
+ descr = PyArray_DescrFromType(NPY_CDOUBLE);
+ oldFunctions[NPY_CDOUBLE] = descr->f->dotfunc;
}
-/*
- * restoredot() restores dots to defaults.
- */
-static PyObject *
-dotblas_restoredot(PyObject *NPY_UNUSED(dummy), PyObject *args)
-{
- PyArray_Descr *descr;
-
- if (!PyArg_ParseTuple(args, "")) return NULL;
-
- if (altered) {
- descr = PyArray_DescrFromType(NPY_FLOAT);
- descr->f->dotfunc = oldFunctions[NPY_FLOAT];
- oldFunctions[NPY_FLOAT] = NULL;
- Py_XDECREF(descr);
-
- descr = PyArray_DescrFromType(NPY_DOUBLE);
- descr->f->dotfunc = oldFunctions[NPY_DOUBLE];
- oldFunctions[NPY_DOUBLE] = NULL;
- Py_XDECREF(descr);
-
- descr = PyArray_DescrFromType(NPY_CFLOAT);
- descr->f->dotfunc = oldFunctions[NPY_CFLOAT];
- oldFunctions[NPY_CFLOAT] = NULL;
- Py_XDECREF(descr);
-
- descr = PyArray_DescrFromType(NPY_CDOUBLE);
- descr->f->dotfunc = oldFunctions[NPY_CDOUBLE];
- oldFunctions[NPY_CDOUBLE] = NULL;
- Py_XDECREF(descr);
-
- altered = NPY_FALSE;
- }
-
- Py_INCREF(Py_None);
- return Py_None;
-}
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;
+
static MatrixShape
_select_matrix_shape(PyArrayObject *array)
{
@@ -212,21 +173,24 @@ _select_matrix_shape(PyArrayObject *array)
}
-/* This also makes sure that the data segment is aligned with
- an itemsize address as well by returning one if not true.
-*/
+/*
+ * This also makes sure that the data segment is aligned with
+ * an itemsize address as well by returning one if not true.
+ */
static int
_bad_strides(PyArrayObject *ap)
{
- register int itemsize = PyArray_ITEMSIZE(ap);
- register int i, N=PyArray_NDIM(ap);
- register npy_intp *strides = PyArray_STRIDES(ap);
+ int itemsize = PyArray_ITEMSIZE(ap);
+ int i, N=PyArray_NDIM(ap);
+ npy_intp *strides = PyArray_STRIDES(ap);
- if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0)
+ if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) {
return 1;
- for (i=0; i<N; i++) {
- if ((strides[i] < 0) || (strides[i] % itemsize) != 0)
+ }
+ for (i = 0; i < N; i++) {
+ if ((strides[i] < 0) || (strides[i] % itemsize) != 0) {
return 1;
+ }
}
return 0;
@@ -268,7 +232,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
Py_DECREF(module);
}
- errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwargs,
+ errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwargs,
&override, 2);
if (errval) {
return NULL;
@@ -295,7 +259,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
/* This function doesn't handle other types */
if ((typenum != NPY_DOUBLE && typenum != NPY_CDOUBLE &&
- typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) {
+ typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) {
return PyArray_Return((PyArrayObject *)PyArray_MatrixProduct2(
(PyObject *)op1,
(PyObject *)op2,
@@ -319,19 +283,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
}
if ((PyArray_NDIM(ap1) > 2) || (PyArray_NDIM(ap2) > 2)) {
- /*
- * This function doesn't handle dimensions greater than 2
- * (or negative striding) -- other
- * than to ensure the dot function is altered
- */
- if (!altered) {
- /* need to alter dot product */
- PyObject *tmp1, *tmp2;
- tmp1 = PyTuple_New(0);
- tmp2 = dotblas_alterdot(NULL, tmp1);
- Py_DECREF(tmp1);
- Py_DECREF(tmp2);
- }
+ /* This function doesn't handle dimensions greater than 2 */
ret = (PyArrayObject *)PyArray_MatrixProduct2((PyObject *)ap1,
(PyObject *)ap2,
out);
@@ -363,7 +315,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
PyArrayObject *oap1, *oap2;
oap1 = ap1; oap2 = ap2;
/* One of ap1 or ap2 is a scalar */
- if (ap1shape == _scalar) { /* Make ap2 the scalar */
+ if (ap1shape == _scalar) {
+ /* Make ap2 the scalar */
PyArrayObject *t = ap1;
ap1 = ap2;
ap2 = t;
@@ -453,8 +406,9 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
}
nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
- if (nd == 1)
+ if (nd == 1) {
dimensions[0] = (PyArray_NDIM(ap1) == 2) ? PyArray_DIM(ap1, 0) : PyArray_DIM(ap2, 1);
+ }
else if (nd == 2) {
dimensions[0] = PyArray_DIM(ap1, 0);
dimensions[1] = PyArray_DIM(ap2, 1);
@@ -474,6 +428,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
if (out) {
int d;
+
/* verify that out is usable */
if (Py_TYPE(out) != subtype ||
PyArray_NDIM(out) != nd ||
@@ -506,7 +461,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
}
numbytes = PyArray_NBYTES(ret);
memset(PyArray_DATA(ret), 0, numbytes);
- if (numbytes==0 || l == 0) {
+ if (numbytes == 0 || l == 0) {
Py_DECREF(ap1);
Py_DECREF(ap2);
return PyArray_Return(ret);
@@ -523,7 +478,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
if (typenum == NPY_DOUBLE) {
if (l == 1) {
*((double *)PyArray_DATA(ret)) = *((double *)PyArray_DATA(ap2)) *
- *((double *)PyArray_DATA(ap1));
+ *((double *)PyArray_DATA(ap1));
}
else if (ap1shape != _matrix) {
cblas_daxpy(l, *((double *)PyArray_DATA(ap2)), (double *)PyArray_DATA(ap1),
@@ -815,7 +770,9 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
PyTypeObject *subtype;
double prior1, prior2;
- if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) return NULL;
+ if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) {
+ return NULL;
+ }
/*
* Inner product using the BLAS. The product sum is taken along the last
@@ -829,28 +786,22 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
/* This function doesn't handle other types */
if ((typenum != NPY_DOUBLE && typenum != NPY_CDOUBLE &&
- typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) {
- return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(op1, op2));
+ typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) {
+ return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(op1, op2));
}
ret = NULL;
ap1 = (PyArrayObject *)PyArray_ContiguousFromObject(op1, typenum, 0, 0);
- if (ap1 == NULL) return NULL;
+ if (ap1 == NULL) {
+ return NULL;
+ }
ap2 = (PyArrayObject *)PyArray_ContiguousFromObject(op2, typenum, 0, 0);
- if (ap2 == NULL) goto fail;
+ if (ap2 == NULL) {
+ goto fail;
+ }
if ((PyArray_NDIM(ap1) > 2) || (PyArray_NDIM(ap2) > 2)) {
- /* This function doesn't handle dimensions greater than 2 -- other
- than to ensure the dot function is altered
- */
- if (!altered) {
- /* need to alter dot product */
- PyObject *tmp1, *tmp2;
- tmp1 = PyTuple_New(0);
- tmp2 = dotblas_alterdot(NULL, tmp1);
- Py_DECREF(tmp1);
- Py_DECREF(tmp2);
- }
+ /* This function doesn't handle dimensions greater than 2 */
ret = (PyArrayObject *)PyArray_InnerProduct((PyObject *)ap1,
(PyObject *)ap2);
Py_DECREF(ap1);
@@ -860,7 +811,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
/* One of ap1 or ap2 is a scalar */
- if (PyArray_NDIM(ap1) == 0) { /* Make ap2 the scalar */
+ if (PyArray_NDIM(ap1) == 0) {
+ /* Make ap2 the scalar */
PyArrayObject *t = ap1;
ap1 = ap2;
ap2 = t;
@@ -871,8 +823,11 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
}
nd = PyArray_NDIM(ap1);
}
- else { /* (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2) */
- /* Both ap1 and ap2 are vectors or matrices */
+ else {
+ /*
+ * (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2)
+ * Both ap1 and ap2 are vectors or matrices
+ */
l = PyArray_DIM(ap1, PyArray_NDIM(ap1)-1);
if (PyArray_DIM(ap2, PyArray_NDIM(ap2)-1) != l) {
@@ -899,7 +854,9 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
(PyObject *)\
(prior2 > prior1 ? ap2 : ap1));
- if (ret == NULL) goto fail;
+ if (ret == NULL) {
+ goto fail;
+ }
NPY_BEGIN_ALLOW_THREADS
memset(PyArray_DATA(ret), 0, PyArray_NBYTES(ret));
@@ -937,8 +894,11 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
gemv(typenum, CblasRowMajor, CblasNoTrans, ap2, lda, ap1, 1, ret);
}
- else { /* (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2) */
- /* Matrix matrix multiplication -- Level 3 BLAS */
+ else {
+ /*
+ * (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2)
+ * Matrix matrix multiplication -- Level 3 BLAS
+ */
lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
gemm(typenum, CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -966,51 +926,56 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
*/
static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
PyObject *op1, *op2;
- PyArrayObject *ap1=NULL, *ap2=NULL, *ret=NULL;
+ PyArrayObject *ap1 = NULL, *ap2 = NULL, *ret = NULL;
int l;
int typenum;
npy_intp dimensions[NPY_MAXDIMS];
PyArray_Descr *type;
- if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) return NULL;
+ if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) {
+ return NULL;
+ }
/*
* Conjugating dot product using the BLAS for vectors.
* Multiplies op1 and op2, each of which must be vector.
*/
-
typenum = PyArray_ObjectType(op1, 0);
typenum = PyArray_ObjectType(op2, typenum);
type = PyArray_DescrFromType(typenum);
Py_INCREF(type);
ap1 = (PyArrayObject *)PyArray_FromAny(op1, type, 0, 0, 0, NULL);
- if (ap1==NULL) {Py_DECREF(type); goto fail;}
+ if (ap1 == NULL) {
+ Py_DECREF(type);
+ goto fail;
+ }
op1 = PyArray_Flatten(ap1, 0);
- if (op1==NULL) {Py_DECREF(type); goto fail;}
+ if (op1 == NULL) {
+ Py_DECREF(type);
+ goto fail;
+ }
Py_DECREF(ap1);
ap1 = (PyArrayObject *)op1;
ap2 = (PyArrayObject *)PyArray_FromAny(op2, type, 0, 0, 0, NULL);
- if (ap2==NULL) goto fail;
+ if (ap2 == NULL) {
+ goto fail;
+ }
op2 = PyArray_Flatten(ap2, 0);
- if (op2 == NULL) goto fail;
+ if (op2 == NULL) {
+ goto fail;
+ }
Py_DECREF(ap2);
ap2 = (PyArrayObject *)op2;
if (typenum != NPY_FLOAT && typenum != NPY_DOUBLE &&
typenum != NPY_CFLOAT && typenum != NPY_CDOUBLE) {
- if (!altered) {
- /* need to alter dot product */
- PyObject *tmp1, *tmp2;
- tmp1 = PyTuple_New(0);
- tmp2 = dotblas_alterdot(NULL, tmp1);
- Py_DECREF(tmp1);
- Py_DECREF(tmp2);
- }
if (PyTypeNum_ISCOMPLEX(typenum)) {
op1 = PyArray_Conjugate(ap1, NULL);
- if (op1==NULL) goto fail;
+ if (op1 == NULL) {
+ goto fail;
+ }
Py_DECREF(ap1);
ap1 = (PyArrayObject *)op1;
}
@@ -1028,7 +993,9 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
l = PyArray_DIM(ap1, PyArray_NDIM(ap1)-1);
ret = (PyArrayObject *)PyArray_SimpleNew(0, dimensions, typenum);
- if (ret == NULL) goto fail;
+ if (ret == NULL) {
+ goto fail;
+ }
NPY_BEGIN_ALLOW_THREADS;
@@ -1060,11 +1027,15 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
}
static struct PyMethodDef dotblas_module_methods[] = {
- {"dot", (PyCFunction)dotblas_matrixproduct, METH_VARARGS|METH_KEYWORDS, NULL},
- {"inner", (PyCFunction)dotblas_innerproduct, 1, NULL},
- {"vdot", (PyCFunction)dotblas_vdot, 1, NULL},
- {"alterdot", (PyCFunction)dotblas_alterdot, 1, NULL},
- {"restoredot", (PyCFunction)dotblas_restoredot, 1, NULL},
+ {"dot",
+ (PyCFunction)dotblas_matrixproduct,
+ METH_VARARGS|METH_KEYWORDS, NULL},
+ {"inner",
+ (PyCFunction)dotblas_innerproduct,
+ 1, NULL},
+ {"vdot",
+ (PyCFunction)dotblas_vdot,
+ 1, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
@@ -1092,31 +1063,17 @@ PyMODINIT_FUNC init_dotblas(void)
#endif
{
#if defined(NPY_PY3K)
- int i;
-
- PyObject *d, *s, *m;
+ PyObject *m;
m = PyModule_Create(&moduledef);
#else
- int i;
-
- PyObject *d, *s;
Py_InitModule3("_dotblas", dotblas_module_methods, module_doc);
#endif
- /* add the functions */
-
/* Import the array object */
import_array();
- /* Initialise the array of dot functions */
- for (i = 0; i < NPY_NTYPES; i++)
- oldFunctions[i] = NULL;
-
- /* alterdot at load */
- d = PyTuple_New(0);
- s = dotblas_alterdot(NULL, d);
- Py_DECREF(d);
- Py_DECREF(s);
+ /* initialize oldFunctions table */
+ init_oldFunctions();
return RETVAL;
}