summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarraymodule.c101
1 files changed, 50 insertions, 51 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index 295924fea..0dc920c7c 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -6,7 +6,7 @@
Original file
Copyright (c) 1995, 1996, 1997 Jim Hugunin, hugunin@mit.edu
- Modified for numpy_core in 2005
+ Modified extensively for numpy in 2005
Travis E. Oliphant
Assistant Professor at
@@ -2249,21 +2249,25 @@ static PyObject *
PyArray_InnerProduct(PyObject *op1, PyObject *op2)
{
PyArrayObject *ap1, *ap2, *ret=NULL;
- intp i, j, l, i1, i2, n1, n2;
+ PyArrayIterObject *it1, *it2;
+ intp i, j, l;
int typenum, nd;
intp is1, is2, os;
- char *ip1, *ip2, *op;
+ char *op;
intp dimensions[MAX_DIMS];
PyArray_DotFunc *dot;
+ PyArray_Descr *typec;
typenum = PyArray_ObjectType(op1, 0);
typenum = PyArray_ObjectType(op2, typenum);
-
- ap1 = (PyArrayObject *)PyArray_ContiguousFromAny(op1, typenum,
- 0, 0);
- if (ap1 == NULL) return NULL;
- ap2 = (PyArrayObject *)PyArray_ContiguousFromAny(op2, typenum,
- 0, 0);
+
+ typec = PyArray_DescrFromType(typenum);
+ Py_INCREF(typec);
+ ap1 = (PyArrayObject *)PyArray_FromAny(op1, typec, 0, 0,
+ BEHAVED_FLAGS, NULL);
+ if (ap1 == NULL) {Py_DECREF(typec); return NULL;}
+ ap2 = (PyArrayObject *)PyArray_FromAny(op2, typec, 0, 0,
+ BEHAVED_FLAGS, NULL);
if (ap2 == NULL) goto fail;
if (ap1->nd == 0 || ap2->nd == 0) {
@@ -2282,12 +2286,6 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
goto fail;
}
- if (l == 0) n1 = n2 = 0;
- else {
- n1 = PyArray_SIZE(ap1)/l;
- n2 = PyArray_SIZE(ap2)/l;
- }
-
nd = ap1->nd+ap2->nd-2;
j = 0;
for(i=0; i<ap1->nd-1; i++) {
@@ -2311,22 +2309,29 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
"dot not available for this type");
goto fail;
}
-
is1 = ap1->strides[ap1->nd-1];
is2 = ap2->strides[ap2->nd-1];
op = ret->data; os = ret->descr->elsize;
- ip1 = ap1->data;
- for(i1=0; i1<n1; i1++) {
- ip2 = ap2->data;
- for(i2=0; i2<n2; i2++) {
- dot(ip1, is1, ip2, is2, op, l, ret);
- ip2 += is2*l;
+ it1 = (PyArrayIterObject *)\
+ PyArray_IterAllButAxis((PyObject *)ap1, ap1->nd-1);
+ it2 = (PyArrayIterObject *)\
+ PyArray_IterAllButAxis((PyObject *)ap2, ap2->nd-1);
+
+ while(1) {
+ while(it2->index < it2->size) {
+ dot(it1->dataptr, is1, it2->dataptr, is2, op, l, ret);
op += os;
+ PyArray_ITER_NEXT(it2);
}
- ip1 += is1*l;
+ PyArray_ITER_NEXT(it1);
+ if (it1->index >= it1->size) break;
+ PyArray_ITER_RESET(it2);
}
+ Py_DECREF(it1);
+ Py_DECREF(it2);
+
if (PyErr_Occurred()) goto fail;
@@ -2350,13 +2355,14 @@ static PyObject *
PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
{
PyArrayObject *ap1, *ap2, *ret=NULL;
- intp i, j, l, i1, i2, n1, n2;
+ PyArrayIterObject *it1, *it2;
+ intp i, j, l;
int typenum, nd;
intp is1, is2, os;
- char *ip1, *ip2, *op;
+ char *op;
intp dimensions[MAX_DIMS];
PyArray_DotFunc *dot;
- intp matchDim, otherDim, is2r, is1r;
+ intp matchDim;
PyArray_Descr *typec;
typenum = PyArray_ObjectType(op1, 0);
@@ -2383,11 +2389,9 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
l = ap1->dimensions[ap1->nd-1];
if (ap2->nd > 1) {
matchDim = ap2->nd - 2;
- otherDim = ap2->nd - 1;
}
else {
matchDim = 0;
- otherDim = 0;
}
if (ap2->dimensions[matchDim] != l) {
@@ -2395,12 +2399,6 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
goto fail;
}
- if (l == 0) n1 = n2 = 0;
- else {
- n1 = PyArray_SIZE(ap1)/l;
- n2 = PyArray_SIZE(ap2)/l;
- }
-
nd = ap1->nd+ap2->nd-2;
j = 0;
for(i=0; i<ap1->nd-1; i++) {
@@ -2419,6 +2417,8 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
fprintf(stderr, "\n");
*/
+ is1 = ap1->strides[ap1->nd-1]; is2 = ap2->strides[matchDim];
+
/* Choose which subtype to return */
ret = new_array_for_sum(ap1, ap2, nd, dimensions, typenum);
if (ret == NULL) goto fail;
@@ -2433,28 +2433,27 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
goto fail;
}
- is1 = ap1->strides[ap1->nd-1]; is2 = ap2->strides[matchDim];
- if(ap1->nd > 1)
- is1r = ap1->strides[ap1->nd-2];
- else
- is1r = ap1->strides[ap1->nd-1];
- is2r = ap2->strides[otherDim];
-
op = ret->data; os = ret->descr->elsize;
- ip1 = ap1->data;
- for(i1=0; i1<n1; i1++) {
- ip2 = ap2->data;
- for(i2=0; i2<n2; i2++) {
- dot(ip1, is1, ip2, is2, op, l, ret);
- ip2 += is2r;
+ it1 = (PyArrayIterObject *)\
+ PyArray_IterAllButAxis((PyObject *)ap1, ap1->nd-1);
+ it2 = (PyArrayIterObject *)\
+ PyArray_IterAllButAxis((PyObject *)ap2, matchDim);
+
+ while(1) {
+ while(it2->index < it2->size) {
+ dot(it1->dataptr, is1, it2->dataptr, is2, op, l, ret);
op += os;
+ PyArray_ITER_NEXT(it2);
}
- ip1 += is1r;
+ PyArray_ITER_NEXT(it1);
+ if (it1->index >= it1->size) break;
+ PyArray_ITER_RESET(it2);
}
- if (PyErr_Occurred()) goto fail;
-
-
+ Py_DECREF(it1);
+ Py_DECREF(it2);
+ if (PyErr_Occurred()) goto fail; /* only for OBJECT arrays */
+
Py_DECREF(ap1);
Py_DECREF(ap2);
return (PyObject *)ret;