diff options
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 101 |
1 files changed, 50 insertions, 51 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 295924fea..0dc920c7c 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -6,7 +6,7 @@ Original file Copyright (c) 1995, 1996, 1997 Jim Hugunin, hugunin@mit.edu - Modified for numpy_core in 2005 + Modified extensively for numpy in 2005 Travis E. Oliphant Assistant Professor at @@ -2249,21 +2249,25 @@ static PyObject * PyArray_InnerProduct(PyObject *op1, PyObject *op2) { PyArrayObject *ap1, *ap2, *ret=NULL; - intp i, j, l, i1, i2, n1, n2; + PyArrayIterObject *it1, *it2; + intp i, j, l; int typenum, nd; intp is1, is2, os; - char *ip1, *ip2, *op; + char *op; intp dimensions[MAX_DIMS]; PyArray_DotFunc *dot; + PyArray_Descr *typec; typenum = PyArray_ObjectType(op1, 0); typenum = PyArray_ObjectType(op2, typenum); - - ap1 = (PyArrayObject *)PyArray_ContiguousFromAny(op1, typenum, - 0, 0); - if (ap1 == NULL) return NULL; - ap2 = (PyArrayObject *)PyArray_ContiguousFromAny(op2, typenum, - 0, 0); + + typec = PyArray_DescrFromType(typenum); + Py_INCREF(typec); + ap1 = (PyArrayObject *)PyArray_FromAny(op1, typec, 0, 0, + BEHAVED_FLAGS, NULL); + if (ap1 == NULL) {Py_DECREF(typec); return NULL;} + ap2 = (PyArrayObject *)PyArray_FromAny(op2, typec, 0, 0, + BEHAVED_FLAGS, NULL); if (ap2 == NULL) goto fail; if (ap1->nd == 0 || ap2->nd == 0) { @@ -2282,12 +2286,6 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2) goto fail; } - if (l == 0) n1 = n2 = 0; - else { - n1 = PyArray_SIZE(ap1)/l; - n2 = PyArray_SIZE(ap2)/l; - } - nd = ap1->nd+ap2->nd-2; j = 0; for(i=0; i<ap1->nd-1; i++) { @@ -2311,22 +2309,29 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2) "dot not available for this type"); goto fail; } - is1 = ap1->strides[ap1->nd-1]; is2 = ap2->strides[ap2->nd-1]; op = ret->data; os = ret->descr->elsize; - ip1 = ap1->data; - for(i1=0; i1<n1; i1++) { - ip2 = ap2->data; - for(i2=0; i2<n2; i2++) { - dot(ip1, is1, ip2, is2, op, l, ret); - ip2 += is2*l; + it1 = (PyArrayIterObject *)\ + PyArray_IterAllButAxis((PyObject *)ap1, ap1->nd-1); + it2 = (PyArrayIterObject *)\ + PyArray_IterAllButAxis((PyObject *)ap2, ap2->nd-1); + + while(1) { + while(it2->index < it2->size) { + dot(it1->dataptr, is1, it2->dataptr, is2, op, l, ret); op += os; + PyArray_ITER_NEXT(it2); } - ip1 += is1*l; + PyArray_ITER_NEXT(it1); + if (it1->index >= it1->size) break; + PyArray_ITER_RESET(it2); } + Py_DECREF(it1); + Py_DECREF(it2); + if (PyErr_Occurred()) goto fail; @@ -2350,13 +2355,14 @@ static PyObject * PyArray_MatrixProduct(PyObject *op1, PyObject *op2) { PyArrayObject *ap1, *ap2, *ret=NULL; - intp i, j, l, i1, i2, n1, n2; + PyArrayIterObject *it1, *it2; + intp i, j, l; int typenum, nd; intp is1, is2, os; - char *ip1, *ip2, *op; + char *op; intp dimensions[MAX_DIMS]; PyArray_DotFunc *dot; - intp matchDim, otherDim, is2r, is1r; + intp matchDim; PyArray_Descr *typec; typenum = PyArray_ObjectType(op1, 0); @@ -2383,11 +2389,9 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2) l = ap1->dimensions[ap1->nd-1]; if (ap2->nd > 1) { matchDim = ap2->nd - 2; - otherDim = ap2->nd - 1; } else { matchDim = 0; - otherDim = 0; } if (ap2->dimensions[matchDim] != l) { @@ -2395,12 +2399,6 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2) goto fail; } - if (l == 0) n1 = n2 = 0; - else { - n1 = PyArray_SIZE(ap1)/l; - n2 = PyArray_SIZE(ap2)/l; - } - nd = ap1->nd+ap2->nd-2; j = 0; for(i=0; i<ap1->nd-1; i++) { @@ -2419,6 +2417,8 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2) fprintf(stderr, "\n"); */ + is1 = ap1->strides[ap1->nd-1]; is2 = ap2->strides[matchDim]; + /* Choose which subtype to return */ ret = new_array_for_sum(ap1, ap2, nd, dimensions, typenum); if (ret == NULL) goto fail; @@ -2433,28 +2433,27 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2) goto fail; } - is1 = ap1->strides[ap1->nd-1]; is2 = ap2->strides[matchDim]; - if(ap1->nd > 1) - is1r = ap1->strides[ap1->nd-2]; - else - is1r = ap1->strides[ap1->nd-1]; - is2r = ap2->strides[otherDim]; - op = ret->data; os = ret->descr->elsize; - ip1 = ap1->data; - for(i1=0; i1<n1; i1++) { - ip2 = ap2->data; - for(i2=0; i2<n2; i2++) { - dot(ip1, is1, ip2, is2, op, l, ret); - ip2 += is2r; + it1 = (PyArrayIterObject *)\ + PyArray_IterAllButAxis((PyObject *)ap1, ap1->nd-1); + it2 = (PyArrayIterObject *)\ + PyArray_IterAllButAxis((PyObject *)ap2, matchDim); + + while(1) { + while(it2->index < it2->size) { + dot(it1->dataptr, is1, it2->dataptr, is2, op, l, ret); op += os; + PyArray_ITER_NEXT(it2); } - ip1 += is1r; + PyArray_ITER_NEXT(it1); + if (it1->index >= it1->size) break; + PyArray_ITER_RESET(it2); } - if (PyErr_Occurred()) goto fail; - - + Py_DECREF(it1); + Py_DECREF(it2); + if (PyErr_Occurred()) goto fail; /* only for OBJECT arrays */ + Py_DECREF(ap1); Py_DECREF(ap2); return (PyObject *)ret; |