From 5937d593cb1ed086236713db74828205b2002286 Mon Sep 17 00:00:00 2001 From: Tim Hochberg Date: Sun, 19 Feb 2006 13:57:36 +0000 Subject: Dispatch to reciprocal, ones_like, copy, sqrt, square inside array_power and array_inplace_power when power is a scalar in [-1, 0, 1, 0.5, 1, 2]. Also, added the ufuncs reciprocal and ones_like. --- numpy/core/src/arrayobject.c | 109 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 92 insertions(+), 17 deletions(-) (limited to 'numpy/core/src/arrayobject.c') diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index a90fb7cca..43c148548 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2311,7 +2311,11 @@ typedef struct { *multiply, *divide, *remainder, - *power, + *power, + *square, + *reciprocal, + *ones_like, + *copy, *sqrt, *negative, *absolute, @@ -2341,7 +2345,8 @@ typedef struct { static NumericOps n_ops = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL}; + NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, + NULL}; /* Dictionary can contain any of the numeric operations, by name. Those not present will not be changed @@ -2367,7 +2372,11 @@ PyArray_SetNumericOps(PyObject *dict) SET(multiply); SET(divide); SET(remainder); - SET(power); + SET(power); + SET(square); + SET(reciprocal); + SET(ones_like); + SET(copy); SET(sqrt); SET(negative); SET(absolute); @@ -2412,7 +2421,11 @@ PyArray_GetNumericOps(void) GET(multiply); GET(divide); GET(remainder); - GET(power); + GET(power); + GET(square); + GET(reciprocal); + GET(ones_like); + GET(copy); GET(sqrt); GET(negative); GET(absolute); @@ -2527,6 +2540,16 @@ PyArray_GenericInplaceBinaryFunction(PyArrayObject *m1, } return PyObject_CallFunction(op, "OOO", m1, m2, m1); } + +static PyObject * +PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op) +{ + if (op == NULL) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + return PyObject_CallFunction(op, "OO", m1, m1); +} static PyObject * array_add(PyArrayObject *m1, PyObject *m2) @@ -2557,12 +2580,61 @@ array_remainder(PyArrayObject *m1, PyObject *m2) { return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder); } - -static PyObject * -array_power(PyArrayObject *m1, PyObject *m2) -{ - return PyArray_GenericBinaryFunction(m1, m2, n_ops.power); -} + + +static PyObject *array_float(PyArrayObject *v); + + +static int +array_power_is_scalar(PyObject *o2, longdouble* exp) +{ + PyObject *temp; + const int optimize_fpexps = 1; + if (PyInt_Check(o2)) { + *exp = (longdouble)PyInt_AsLong(o2); + return 1; + } + if (optimize_fpexps && PyFloat_Check(o2)) { + *exp = PyFloat_AsDouble(o2); + return 1; + } + if (PyArray_CheckScalar(o2)) { + if (PyArray_ISINTEGER(o2) || (optimize_fpexps && PyArray_ISFLOAT(o2))) { + temp = array_float(o2); + if (temp != NULL) { + *exp = PyFloat_AsDouble(o2); + Py_DECREF(temp); + return 1; + } + } + } + return 0; +} + +static PyObject * +fast_scalar_power_op(PyArrayObject *a1, PyObject *o2) { + double exp; + if (PyArray_Check(a1) && (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) { + if (array_power_is_scalar(o2, &exp)) { + if (exp == -1.0) return n_ops.reciprocal; + if (exp == 0.0) return n_ops.ones_like; + if (exp == 0.5) return n_ops.sqrt; + if (exp == 1.0) return n_ops.copy; + if (exp == 2.0) return n_ops.square; + } + } + return NULL; +} + +static PyObject * +array_power(PyArrayObject *a1, PyObject *o2) +{ + PyObject *fastop; + fastop = fast_scalar_power_op(a1, o2); + if (fastop) return PyArray_GenericUnaryFunction(a1, fastop); + return PyArray_GenericBinaryFunction(a1, o2, n_ops.power); +} + static PyObject * array_negative(PyArrayObject *m1) @@ -2640,12 +2712,15 @@ static PyObject * array_inplace_remainder(PyArrayObject *m1, PyObject *m2) { return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder); -} - +} + static PyObject * -array_inplace_power(PyArrayObject *m1, PyObject *m2) -{ - return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.power); +array_inplace_power(PyArrayObject *a1, PyObject *o2) +{ + PyObject *fastop; + fastop = fast_scalar_power_op(a1, o2); + if (fastop) return PyArray_GenericInplaceUnaryFunction(a1, fastop); + return PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power); } static PyObject * @@ -2808,8 +2883,8 @@ array_float(PyArrayObject *v) pv = v->descr->f->getitem(v->data, v); if (pv == NULL) return NULL; if (pv->ob_type->tp_as_number == 0) { - PyErr_SetString(PyExc_TypeError, "cannot convert to an "\ - "int; scalar object is not a number"); + PyErr_SetString(PyExc_TypeError, "cannot convert to a "\ + "float; scalar object is not a number"); Py_DECREF(pv); return NULL; } -- cgit v1.2.1