diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2014-08-15 13:33:23 -0600 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-09-04 21:45:48 +0200 |
commit | a3919717c47070bc8a245ffa61b3a75dd7d5d879 (patch) | |
tree | 90a9fb972dfb5027f7975c853b04ae0017fb07f9 /numpy/core/blasdot | |
parent | 175cea4dc0590ae520f32476ffb9129b8524bcad (diff) | |
download | numpy-a3919717c47070bc8a245ffa61b3a75dd7d5d879.tar.gz |
MAINT, STY: Remove use of alterdot, restoredot in _dotblas.c.
These are no longer needed. Also do C style cleanups.
Diffstat (limited to 'numpy/core/blasdot')
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 259 |
1 files changed, 108 insertions, 151 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index c1b4e7b05..812d947b7 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -19,8 +19,11 @@ #include <limits.h> #include <stdio.h> + static char module_doc[] = -"This module provides a BLAS optimized\nmatrix multiply, inner product and dot for numpy arrays"; + "This module provides a BLAS optimized\n" + "matrix multiply, inner product and dot for numpy arrays"; + static PyArray_DotFunc *oldFunctions[NPY_NTYPES]; @@ -33,7 +36,7 @@ static void blas_dot(int typenum, npy_intp n, void *a, npy_intp stridea, void *b, npy_intp strideb, void *res) { - PyArray_DotFunc *dot = NULL; + PyArray_DotFunc *dot; dot = oldFunctions[typenum]; assert(dot != NULL); @@ -44,6 +47,7 @@ blas_dot(int typenum, npy_intp n, static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0}; static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0}; + /* * Helper: dispatch to appropriate cblas_?gemm for typenum. */ @@ -113,80 +117,37 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans, } -static npy_bool altered=NPY_FALSE; - /* - * alterdot() changes all dot functions to use blas. + * Initialize oldFunctions table. */ -static PyObject * -dotblas_alterdot(PyObject *NPY_UNUSED(dummy), PyObject *args) +static void +init_oldFunctions(void) { PyArray_Descr *descr; + int i; - if (!PyArg_ParseTuple(args, "")) return NULL; - - /* Replace the dot functions to the ones using blas */ - - if (!altered) { - descr = PyArray_DescrFromType(NPY_FLOAT); - oldFunctions[NPY_FLOAT] = descr->f->dotfunc; - - descr = PyArray_DescrFromType(NPY_DOUBLE); - oldFunctions[NPY_DOUBLE] = descr->f->dotfunc; + /* Initialise the array of dot functions */ + for (i = 0; i < NPY_NTYPES; i++) + oldFunctions[i] = NULL; - descr = PyArray_DescrFromType(NPY_CFLOAT); - oldFunctions[NPY_CFLOAT] = descr->f->dotfunc; + /* index dot functions we want to use here */ + descr = PyArray_DescrFromType(NPY_FLOAT); + oldFunctions[NPY_FLOAT] = descr->f->dotfunc; - descr = PyArray_DescrFromType(NPY_CDOUBLE); - oldFunctions[NPY_CDOUBLE] = descr->f->dotfunc; + descr = PyArray_DescrFromType(NPY_DOUBLE); + oldFunctions[NPY_DOUBLE] = descr->f->dotfunc; - altered = NPY_TRUE; - } + descr = PyArray_DescrFromType(NPY_CFLOAT); + oldFunctions[NPY_CFLOAT] = descr->f->dotfunc; - Py_INCREF(Py_None); - return Py_None; + descr = PyArray_DescrFromType(NPY_CDOUBLE); + oldFunctions[NPY_CDOUBLE] = descr->f->dotfunc; } -/* - * restoredot() restores dots to defaults. - */ -static PyObject * -dotblas_restoredot(PyObject *NPY_UNUSED(dummy), PyObject *args) -{ - PyArray_Descr *descr; - - if (!PyArg_ParseTuple(args, "")) return NULL; - - if (altered) { - descr = PyArray_DescrFromType(NPY_FLOAT); - descr->f->dotfunc = oldFunctions[NPY_FLOAT]; - oldFunctions[NPY_FLOAT] = NULL; - Py_XDECREF(descr); - - descr = PyArray_DescrFromType(NPY_DOUBLE); - descr->f->dotfunc = oldFunctions[NPY_DOUBLE]; - oldFunctions[NPY_DOUBLE] = NULL; - Py_XDECREF(descr); - - descr = PyArray_DescrFromType(NPY_CFLOAT); - descr->f->dotfunc = oldFunctions[NPY_CFLOAT]; - oldFunctions[NPY_CFLOAT] = NULL; - Py_XDECREF(descr); - - descr = PyArray_DescrFromType(NPY_CDOUBLE); - descr->f->dotfunc = oldFunctions[NPY_CDOUBLE]; - oldFunctions[NPY_CDOUBLE] = NULL; - Py_XDECREF(descr); - - altered = NPY_FALSE; - } - - Py_INCREF(Py_None); - return Py_None; -} typedef enum {_scalar, _column, _row, _matrix} MatrixShape; + static MatrixShape _select_matrix_shape(PyArrayObject *array) { @@ -212,21 +173,24 @@ _select_matrix_shape(PyArrayObject *array) } -/* This also makes sure that the data segment is aligned with - an itemsize address as well by returning one if not true. -*/ +/* + * This also makes sure that the data segment is aligned with + * an itemsize address as well by returning one if not true. + */ static int _bad_strides(PyArrayObject *ap) { - register int itemsize = PyArray_ITEMSIZE(ap); - register int i, N=PyArray_NDIM(ap); - register npy_intp *strides = PyArray_STRIDES(ap); + int itemsize = PyArray_ITEMSIZE(ap); + int i, N=PyArray_NDIM(ap); + npy_intp *strides = PyArray_STRIDES(ap); - if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) + if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) { return 1; - for (i=0; i<N; i++) { - if ((strides[i] < 0) || (strides[i] % itemsize) != 0) + } + for (i = 0; i < N; i++) { + if ((strides[i] < 0) || (strides[i] % itemsize) != 0) { return 1; + } } return 0; @@ -268,7 +232,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa Py_DECREF(module); } - errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwargs, + errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwargs, &override, 2); if (errval) { return NULL; @@ -295,7 +259,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa /* This function doesn't handle other types */ if ((typenum != NPY_DOUBLE && typenum != NPY_CDOUBLE && - typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) { + typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) { return PyArray_Return((PyArrayObject *)PyArray_MatrixProduct2( (PyObject *)op1, (PyObject *)op2, @@ -319,19 +283,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa } if ((PyArray_NDIM(ap1) > 2) || (PyArray_NDIM(ap2) > 2)) { - /* - * This function doesn't handle dimensions greater than 2 - * (or negative striding) -- other - * than to ensure the dot function is altered - */ - if (!altered) { - /* need to alter dot product */ - PyObject *tmp1, *tmp2; - tmp1 = PyTuple_New(0); - tmp2 = dotblas_alterdot(NULL, tmp1); - Py_DECREF(tmp1); - Py_DECREF(tmp2); - } + /* This function doesn't handle dimensions greater than 2 */ ret = (PyArrayObject *)PyArray_MatrixProduct2((PyObject *)ap1, (PyObject *)ap2, out); @@ -363,7 +315,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa PyArrayObject *oap1, *oap2; oap1 = ap1; oap2 = ap2; /* One of ap1 or ap2 is a scalar */ - if (ap1shape == _scalar) { /* Make ap2 the scalar */ + if (ap1shape == _scalar) { + /* Make ap2 the scalar */ PyArrayObject *t = ap1; ap1 = ap2; ap2 = t; @@ -453,8 +406,9 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa } nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2; - if (nd == 1) + if (nd == 1) { dimensions[0] = (PyArray_NDIM(ap1) == 2) ? PyArray_DIM(ap1, 0) : PyArray_DIM(ap2, 1); + } else if (nd == 2) { dimensions[0] = PyArray_DIM(ap1, 0); dimensions[1] = PyArray_DIM(ap2, 1); @@ -474,6 +428,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa if (out) { int d; + /* verify that out is usable */ if (Py_TYPE(out) != subtype || PyArray_NDIM(out) != nd || @@ -506,7 +461,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa } numbytes = PyArray_NBYTES(ret); memset(PyArray_DATA(ret), 0, numbytes); - if (numbytes==0 || l == 0) { + if (numbytes == 0 || l == 0) { Py_DECREF(ap1); Py_DECREF(ap2); return PyArray_Return(ret); @@ -523,7 +478,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa if (typenum == NPY_DOUBLE) { if (l == 1) { *((double *)PyArray_DATA(ret)) = *((double *)PyArray_DATA(ap2)) * - *((double *)PyArray_DATA(ap1)); + *((double *)PyArray_DATA(ap1)); } else if (ap1shape != _matrix) { cblas_daxpy(l, *((double *)PyArray_DATA(ap2)), (double *)PyArray_DATA(ap1), @@ -815,7 +770,9 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) PyTypeObject *subtype; double prior1, prior2; - if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) return NULL; + if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) { + return NULL; + } /* * Inner product using the BLAS. The product sum is taken along the last @@ -829,28 +786,22 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) /* This function doesn't handle other types */ if ((typenum != NPY_DOUBLE && typenum != NPY_CDOUBLE && - typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) { - return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(op1, op2)); + typenum != NPY_FLOAT && typenum != NPY_CFLOAT)) { + return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(op1, op2)); } ret = NULL; ap1 = (PyArrayObject *)PyArray_ContiguousFromObject(op1, typenum, 0, 0); - if (ap1 == NULL) return NULL; + if (ap1 == NULL) { + return NULL; + } ap2 = (PyArrayObject *)PyArray_ContiguousFromObject(op2, typenum, 0, 0); - if (ap2 == NULL) goto fail; + if (ap2 == NULL) { + goto fail; + } if ((PyArray_NDIM(ap1) > 2) || (PyArray_NDIM(ap2) > 2)) { - /* This function doesn't handle dimensions greater than 2 -- other - than to ensure the dot function is altered - */ - if (!altered) { - /* need to alter dot product */ - PyObject *tmp1, *tmp2; - tmp1 = PyTuple_New(0); - tmp2 = dotblas_alterdot(NULL, tmp1); - Py_DECREF(tmp1); - Py_DECREF(tmp2); - } + /* This function doesn't handle dimensions greater than 2 */ ret = (PyArrayObject *)PyArray_InnerProduct((PyObject *)ap1, (PyObject *)ap2); Py_DECREF(ap1); @@ -860,7 +811,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) { /* One of ap1 or ap2 is a scalar */ - if (PyArray_NDIM(ap1) == 0) { /* Make ap2 the scalar */ + if (PyArray_NDIM(ap1) == 0) { + /* Make ap2 the scalar */ PyArrayObject *t = ap1; ap1 = ap2; ap2 = t; @@ -871,8 +823,11 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) } nd = PyArray_NDIM(ap1); } - else { /* (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2) */ - /* Both ap1 and ap2 are vectors or matrices */ + else { + /* + * (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2) + * Both ap1 and ap2 are vectors or matrices + */ l = PyArray_DIM(ap1, PyArray_NDIM(ap1)-1); if (PyArray_DIM(ap2, PyArray_NDIM(ap2)-1) != l) { @@ -899,7 +854,9 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) (PyObject *)\ (prior2 > prior1 ? ap2 : ap1)); - if (ret == NULL) goto fail; + if (ret == NULL) { + goto fail; + } NPY_BEGIN_ALLOW_THREADS memset(PyArray_DATA(ret), 0, PyArray_NBYTES(ret)); @@ -937,8 +894,11 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); gemv(typenum, CblasRowMajor, CblasNoTrans, ap2, lda, ap1, 1, ret); } - else { /* (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2) */ - /* Matrix matrix multiplication -- Level 3 BLAS */ + else { + /* + * (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2) + * Matrix matrix multiplication -- Level 3 BLAS + */ lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1); ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1); gemm(typenum, CblasRowMajor, CblasNoTrans, CblasTrans, @@ -966,51 +926,56 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) */ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { PyObject *op1, *op2; - PyArrayObject *ap1=NULL, *ap2=NULL, *ret=NULL; + PyArrayObject *ap1 = NULL, *ap2 = NULL, *ret = NULL; int l; int typenum; npy_intp dimensions[NPY_MAXDIMS]; PyArray_Descr *type; - if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) return NULL; + if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) { + return NULL; + } /* * Conjugating dot product using the BLAS for vectors. * Multiplies op1 and op2, each of which must be vector. */ - typenum = PyArray_ObjectType(op1, 0); typenum = PyArray_ObjectType(op2, typenum); type = PyArray_DescrFromType(typenum); Py_INCREF(type); ap1 = (PyArrayObject *)PyArray_FromAny(op1, type, 0, 0, 0, NULL); - if (ap1==NULL) {Py_DECREF(type); goto fail;} + if (ap1 == NULL) { + Py_DECREF(type); + goto fail; + } op1 = PyArray_Flatten(ap1, 0); - if (op1==NULL) {Py_DECREF(type); goto fail;} + if (op1 == NULL) { + Py_DECREF(type); + goto fail; + } Py_DECREF(ap1); ap1 = (PyArrayObject *)op1; ap2 = (PyArrayObject *)PyArray_FromAny(op2, type, 0, 0, 0, NULL); - if (ap2==NULL) goto fail; + if (ap2 == NULL) { + goto fail; + } op2 = PyArray_Flatten(ap2, 0); - if (op2 == NULL) goto fail; + if (op2 == NULL) { + goto fail; + } Py_DECREF(ap2); ap2 = (PyArrayObject *)op2; if (typenum != NPY_FLOAT && typenum != NPY_DOUBLE && typenum != NPY_CFLOAT && typenum != NPY_CDOUBLE) { - if (!altered) { - /* need to alter dot product */ - PyObject *tmp1, *tmp2; - tmp1 = PyTuple_New(0); - tmp2 = dotblas_alterdot(NULL, tmp1); - Py_DECREF(tmp1); - Py_DECREF(tmp2); - } if (PyTypeNum_ISCOMPLEX(typenum)) { op1 = PyArray_Conjugate(ap1, NULL); - if (op1==NULL) goto fail; + if (op1 == NULL) { + goto fail; + } Py_DECREF(ap1); ap1 = (PyArrayObject *)op1; } @@ -1028,7 +993,9 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { l = PyArray_DIM(ap1, PyArray_NDIM(ap1)-1); ret = (PyArrayObject *)PyArray_SimpleNew(0, dimensions, typenum); - if (ret == NULL) goto fail; + if (ret == NULL) { + goto fail; + } NPY_BEGIN_ALLOW_THREADS; @@ -1060,11 +1027,15 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) { } static struct PyMethodDef dotblas_module_methods[] = { - {"dot", (PyCFunction)dotblas_matrixproduct, METH_VARARGS|METH_KEYWORDS, NULL}, - {"inner", (PyCFunction)dotblas_innerproduct, 1, NULL}, - {"vdot", (PyCFunction)dotblas_vdot, 1, NULL}, - {"alterdot", (PyCFunction)dotblas_alterdot, 1, NULL}, - {"restoredot", (PyCFunction)dotblas_restoredot, 1, NULL}, + {"dot", + (PyCFunction)dotblas_matrixproduct, + METH_VARARGS|METH_KEYWORDS, NULL}, + {"inner", + (PyCFunction)dotblas_innerproduct, + 1, NULL}, + {"vdot", + (PyCFunction)dotblas_vdot, + 1, NULL}, {NULL, NULL, 0, NULL} /* sentinel */ }; @@ -1092,31 +1063,17 @@ PyMODINIT_FUNC init_dotblas(void) #endif { #if defined(NPY_PY3K) - int i; - - PyObject *d, *s, *m; + PyObject *m; m = PyModule_Create(&moduledef); #else - int i; - - PyObject *d, *s; Py_InitModule3("_dotblas", dotblas_module_methods, module_doc); #endif - /* add the functions */ - /* Import the array object */ import_array(); - /* Initialise the array of dot functions */ - for (i = 0; i < NPY_NTYPES; i++) - oldFunctions[i] = NULL; - - /* alterdot at load */ - d = PyTuple_New(0); - s = dotblas_alterdot(NULL, d); - Py_DECREF(d); - Py_DECREF(s); + /* initialize oldFunctions table */ + init_oldFunctions(); return RETVAL; } |