diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-03-23 03:17:06 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-03-23 03:17:06 +0000 |
commit | 0e72b1b67f1291aef293a77138ad94d459d8448b (patch) | |
tree | 0bff186392f1965c726afd74aa0d68759586c768 /numpy/core/src/arrayobject.c | |
parent | bf8b6bc167fadcec411d1b56f1590e297ae2fab3 (diff) | |
download | numpy-0e72b1b67f1291aef293a77138ad94d459d8448b.tar.gz |
Fix segfault in fast power.
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 122 |
1 files changed, 68 insertions, 54 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 306cfeae7..60449a2c1 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2637,7 +2637,7 @@ array_remainder(PyArrayObject *m1, PyObject *m2) static PyObject *array_float(PyArrayObject *v); - +static PyObject *gentype_float(PyObject *v); static int array_power_is_scalar(PyObject *o2, double* exp) @@ -2645,22 +2645,32 @@ array_power_is_scalar(PyObject *o2, double* exp) PyObject *temp; const int optimize_fpexps = 1; if (PyInt_Check(o2)) { - *exp = (double)PyInt_AsLong(o2); - return 1; + *exp = (double)PyInt_AsLong(o2); + return 1; } if (optimize_fpexps && PyFloat_Check(o2)) { - *exp = PyFloat_AsDouble(o2); - return 1; + *exp = PyFloat_AsDouble(o2); + return 1; } - if (PyArray_CheckScalar(o2)) { - if (PyArray_ISINTEGER(o2) || (optimize_fpexps && PyArray_ISFLOAT(o2))) { - temp = array_float((PyArrayObject *)o2); - if (temp != NULL) { - *exp = PyFloat_AsDouble(o2); - Py_DECREF(temp); - return 1; - } - } + if (PyArray_IsZeroDim(o2)) { + if (PyArray_ISINTEGER(o2) || + (optimize_fpexps && PyArray_ISFLOAT(o2))) { + temp = array_float((PyArrayObject *)o2); + if (temp != NULL) { + *exp = PyFloat_AsDouble(o2); + Py_DECREF(temp); + return 1; + } + } + } + if (PyArray_IsScalar(o2, Integer) || + (optimize_fpexps && PyArray_IsScalar(o2, Floating))) { + temp = gentype_float(o2); + if (temp != NULL) { + *exp = PyFloat_AsDouble(o2); + Py_DECREF(temp); + return 1; + } } return 0; } @@ -2668,51 +2678,55 @@ array_power_is_scalar(PyObject *o2, double* exp) /* optimize float array or complex array to a scalar power */ static PyObject * fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) { - double exp; - if (PyArray_Check(a1) && (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) { - if (array_power_is_scalar(o2, &exp)) { - PyObject *fastop = NULL; - if (exp == 1.0) { - /* we have to do this one special, as the "copy" method of - array objects isn't set up early enough to be added - by PyArray_SetNumericOps. - */ - if (inplace) { - return (PyObject *)a1; - } else { - return PyArray_Copy(a1); - } - } else if (exp == -1.0) { - fastop = n_ops.reciprocal; - } else if (exp == 0.0) { - fastop = n_ops.ones_like; - } else if (exp == 0.5) { - fastop = n_ops.sqrt; - } else if (exp == 2.0) { - fastop = n_ops.square; - } else { - return NULL; - } - if (inplace) { - PyArray_GenericInplaceUnaryFunction(a1, fastop); - } else { - return PyArray_GenericUnaryFunction(a1, fastop); - } - } - } - return NULL; + double exp; + if (PyArray_Check(a1) && (PyArray_ISFLOAT(a1) || + PyArray_ISCOMPLEX(a1))) { + if (array_power_is_scalar(o2, &exp)) { + PyObject *fastop = NULL; + if (exp == 1.0) { + /* we have to do this one special, as the + "copy" method of array objects isn't set + up early enough to be added + by PyArray_SetNumericOps. + */ + if (inplace) { + return (PyObject *)a1; + } else { + return PyArray_Copy(a1); + } + } else if (exp == -1.0) { + fastop = n_ops.reciprocal; + } else if (exp == 0.0) { + fastop = n_ops.ones_like; + } else if (exp == 0.5) { + fastop = n_ops.sqrt; + } else if (exp == 2.0) { + fastop = n_ops.square; + } else { + return NULL; + } + if (inplace) { + PyArray_GenericInplaceUnaryFunction(a1, + fastop); + } else { + return PyArray_GenericUnaryFunction(a1, + fastop); + } + } + } + return NULL; } static PyObject * array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo) { - /* modulo is ignored! */ - PyObject *value; - value = fast_scalar_power(a1, o2, 0); - if (!value) { - value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power); - } - return value; + /* modulo is ignored! */ + PyObject *value; + value = fast_scalar_power(a1, o2, 0); + if (!value) { + value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power); + } + return value; } |