summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
authorTim Hochberg <tim_hochberg@local>2006-02-19 13:57:36 +0000
committerTim Hochberg <tim_hochberg@local>2006-02-19 13:57:36 +0000
commit5937d593cb1ed086236713db74828205b2002286 (patch)
treeaabb2cd8c3ab4f474daa5fea17d81ad40cef28a4 /numpy/core/src/arrayobject.c
parent8d9af69878ff93bc21859494f301f689cef8e0e4 (diff)
downloadnumpy-5937d593cb1ed086236713db74828205b2002286.tar.gz
Dispatch to reciprocal, ones_like, copy, sqrt, square inside array_power and array_inplace_power when power is a scalar in [-1, 0, 1, 0.5, 1, 2]. Also, added the ufuncs reciprocal and ones_like.
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c109
1 files changed, 92 insertions, 17 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index a90fb7cca..43c148548 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2311,7 +2311,11 @@ typedef struct {
*multiply,
*divide,
*remainder,
- *power,
+ *power,
+ *square,
+ *reciprocal,
+ *ones_like,
+ *copy,
*sqrt,
*negative,
*absolute,
@@ -2341,7 +2345,8 @@ typedef struct {
static NumericOps n_ops = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
- NULL, NULL, NULL, NULL, NULL};
+ NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
+ NULL};
/* Dictionary can contain any of the numeric operations, by name.
Those not present will not be changed
@@ -2367,7 +2372,11 @@ PyArray_SetNumericOps(PyObject *dict)
SET(multiply);
SET(divide);
SET(remainder);
- SET(power);
+ SET(power);
+ SET(square);
+ SET(reciprocal);
+ SET(ones_like);
+ SET(copy);
SET(sqrt);
SET(negative);
SET(absolute);
@@ -2412,7 +2421,11 @@ PyArray_GetNumericOps(void)
GET(multiply);
GET(divide);
GET(remainder);
- GET(power);
+ GET(power);
+ GET(square);
+ GET(reciprocal);
+ GET(ones_like);
+ GET(copy);
GET(sqrt);
GET(negative);
GET(absolute);
@@ -2527,6 +2540,16 @@ PyArray_GenericInplaceBinaryFunction(PyArrayObject *m1,
}
return PyObject_CallFunction(op, "OOO", m1, m2, m1);
}
+
+static PyObject *
+PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op)
+{
+ if (op == NULL) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ return PyObject_CallFunction(op, "OO", m1, m1);
+}
static PyObject *
array_add(PyArrayObject *m1, PyObject *m2)
@@ -2557,12 +2580,61 @@ array_remainder(PyArrayObject *m1, PyObject *m2)
{
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}
-
-static PyObject *
-array_power(PyArrayObject *m1, PyObject *m2)
-{
- return PyArray_GenericBinaryFunction(m1, m2, n_ops.power);
-}
+
+
+static PyObject *array_float(PyArrayObject *v);
+
+
+static int
+array_power_is_scalar(PyObject *o2, longdouble* exp)
+{
+ PyObject *temp;
+ const int optimize_fpexps = 1;
+ if (PyInt_Check(o2)) {
+ *exp = (longdouble)PyInt_AsLong(o2);
+ return 1;
+ }
+ if (optimize_fpexps && PyFloat_Check(o2)) {
+ *exp = PyFloat_AsDouble(o2);
+ return 1;
+ }
+ if (PyArray_CheckScalar(o2)) {
+ if (PyArray_ISINTEGER(o2) || (optimize_fpexps && PyArray_ISFLOAT(o2))) {
+ temp = array_float(o2);
+ if (temp != NULL) {
+ *exp = PyFloat_AsDouble(o2);
+ Py_DECREF(temp);
+ return 1;
+ }
+ }
+ }
+ return 0;
+}
+
+static PyObject *
+fast_scalar_power_op(PyArrayObject *a1, PyObject *o2) {
+ double exp;
+ if (PyArray_Check(a1) && (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) {
+ if (array_power_is_scalar(o2, &exp)) {
+ if (exp == -1.0) return n_ops.reciprocal;
+ if (exp == 0.0) return n_ops.ones_like;
+ if (exp == 0.5) return n_ops.sqrt;
+ if (exp == 1.0) return n_ops.copy;
+ if (exp == 2.0) return n_ops.square;
+ }
+ }
+ return NULL;
+}
+
+static PyObject *
+array_power(PyArrayObject *a1, PyObject *o2)
+{
+ PyObject *fastop;
+ fastop = fast_scalar_power_op(a1, o2);
+ if (fastop) return PyArray_GenericUnaryFunction(a1, fastop);
+ return PyArray_GenericBinaryFunction(a1, o2, n_ops.power);
+}
+
static PyObject *
array_negative(PyArrayObject *m1)
@@ -2640,12 +2712,15 @@ static PyObject *
array_inplace_remainder(PyArrayObject *m1, PyObject *m2)
{
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder);
-}
-
+}
+
static PyObject *
-array_inplace_power(PyArrayObject *m1, PyObject *m2)
-{
- return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.power);
+array_inplace_power(PyArrayObject *a1, PyObject *o2)
+{
+ PyObject *fastop;
+ fastop = fast_scalar_power_op(a1, o2);
+ if (fastop) return PyArray_GenericInplaceUnaryFunction(a1, fastop);
+ return PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
}
static PyObject *
@@ -2808,8 +2883,8 @@ array_float(PyArrayObject *v)
pv = v->descr->f->getitem(v->data, v);
if (pv == NULL) return NULL;
if (pv->ob_type->tp_as_number == 0) {
- PyErr_SetString(PyExc_TypeError, "cannot convert to an "\
- "int; scalar object is not a number");
+ PyErr_SetString(PyExc_TypeError, "cannot convert to a "\
+ "float; scalar object is not a number");
Py_DECREF(pv);
return NULL;
}