summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-05-17 09:58:57 -0400
committerGitHub <noreply@github.com>2017-05-17 09:58:57 -0400
commitb9e3ac9abb6e435cdf6bbe33e0bc894d6a879a53 (patch)
tree95a2b664887537e2995c3e73eeae1fdc07d9864d /numpy
parent692655e77b65a9186bda7a701062abd6b62d4ca9 (diff)
parente1df000d940d2367c6e86f754be5201c2051ba99 (diff)
downloadnumpy-b9e3ac9abb6e435cdf6bbe33e0bc894d6a879a53.tar.gz
Merge pull request #9112 from mhvk/array_ufunc_fast_scalar_power
BUG: ndarray.__pow__ does not check result of fast_scalar_power
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/number.c61
-rw-r--r--numpy/core/src/multiarray/number.h1
-rw-r--r--numpy/core/tests/test_multiarray.py21
3 files changed, 50 insertions, 33 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index d86cef5a1..9c1343497 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -91,6 +91,7 @@ PyArray_SetNumericOps(PyObject *dict)
SET(sqrt);
SET(cbrt);
SET(negative);
+ SET(positive);
SET(absolute);
SET(invert);
SET(left_shift);
@@ -143,6 +144,7 @@ PyArray_GetNumericOps(void)
GET(_ones_like);
GET(sqrt);
GET(negative);
+ GET(positive);
GET(absolute);
GET(invert);
GET(left_shift);
@@ -453,9 +455,14 @@ is_scalar_with_conversion(PyObject *o2, double* out_exponent)
return NPY_NOSCALAR;
}
-/* optimize float array or complex array to a scalar power */
-static PyObject *
-fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
+/*
+ * optimize float array or complex array to a scalar power
+ * returns 0 on success, -1 if no optimization is possible
+ * the result is in value (can be NULL if an error occurred)
+ */
+static int
+fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace,
+ PyObject **value)
{
double exponent;
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
@@ -464,17 +471,7 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
- /* we have to do this one special, as the
- "copy" method of array objects isn't set
- up early enough to be added
- by PyArray_SetNumericOps.
- */
- if (inplace) {
- Py_INCREF(a1);
- return (PyObject *)a1;
- } else {
- return PyArray_Copy(a1);
- }
+ fastop = n_ops.positive;
}
else if (exponent == -1.0) {
fastop = n_ops.reciprocal;
@@ -489,15 +486,16 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
fastop = n_ops.square;
}
else {
- return NULL;
+ return -1;
}
if (inplace || can_elide_temp_unary(a1)) {
- return PyArray_GenericInplaceUnaryFunction(a1, fastop);
+ *value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
}
else {
- return PyArray_GenericUnaryFunction(a1, fastop);
+ *value = PyArray_GenericUnaryFunction(a1, fastop);
}
+ return 0;
}
/* Because this is called with all arrays, we need to
* change the output if the kind of the scalar is different
@@ -507,36 +505,35 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
else if (exponent == 2.0) {
fastop = n_ops.square;
if (inplace) {
- return PyArray_GenericInplaceUnaryFunction(a1, fastop);
+ *value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
}
else {
/* We only special-case the FLOAT_SCALAR and integer types */
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
- PyObject *res;
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
PyArray_ISFORTRAN(a1));
- if (a1 == NULL) {
- return NULL;
+ if (a1 != NULL) {
+ /* cast always creates a new array */
+ *value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
+ Py_DECREF(a1);
}
- /* cast always creates a new array */
- res = PyArray_GenericInplaceUnaryFunction(a1, fastop);
- Py_DECREF(a1);
- return res;
}
else {
- return PyArray_GenericUnaryFunction(a1, fastop);
+ *value = PyArray_GenericUnaryFunction(a1, fastop);
}
}
+ return 0;
}
}
- return NULL;
+ /* no fast operation found */
+ return -1;
}
static PyObject *
array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo)
{
- PyObject *value;
+ PyObject *value = NULL;
if (modulo != Py_None) {
/* modular exponentiation is not implemented (gh-8804) */
@@ -545,8 +542,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo)
}
BINOP_GIVE_UP_IF_NEEDED(a1, o2, nb_power, array_power);
- value = fast_scalar_power(a1, o2, 0);
- if (!value) {
+ if (fast_scalar_power(a1, o2, 0, &value) != 0) {
value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power);
}
return value;
@@ -686,12 +682,11 @@ static PyObject *
array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo))
{
/* modulo is ignored! */
- PyObject *value;
+ PyObject *value = NULL;
INPLACE_GIVE_UP_IF_NEEDED(
a1, o2, nb_inplace_power, array_inplace_power);
- value = fast_scalar_power(a1, o2, 1);
- if (!value) {
+ if (fast_scalar_power(a1, o2, 1, &value) != 0) {
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
}
return value;
diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h
index 113fc2475..99a2a722b 100644
--- a/numpy/core/src/multiarray/number.h
+++ b/numpy/core/src/multiarray/number.h
@@ -15,6 +15,7 @@ typedef struct {
PyObject *sqrt;
PyObject *cbrt;
PyObject *negative;
+ PyObject *positive;
PyObject *absolute;
PyObject *invert;
PyObject *left_shift;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 571d0ceb9..30bda20de 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -3072,6 +3072,27 @@ class TestBinop(object):
assert_equal(A[0], 30)
assert_(isinstance(A, OutClass))
+ def test_pow_override_with_errors(self):
+ # regression test for gh-9112
+ class PowerOnly(np.ndarray):
+ def __array_ufunc__(self, ufunc, method, *inputs, **kw):
+ if ufunc is not np.power:
+ raise NotImplementedError
+ return "POWER!"
+ # explicit cast to float, to ensure the fast power path is taken.
+ a = np.array(5., dtype=np.float64).view(PowerOnly)
+ assert_equal(a ** 2.5, "POWER!")
+ with assert_raises(NotImplementedError):
+ a ** 0.5
+ with assert_raises(NotImplementedError):
+ a ** 0
+ with assert_raises(NotImplementedError):
+ a ** 1
+ with assert_raises(NotImplementedError):
+ a ** -1
+ with assert_raises(NotImplementedError):
+ a ** 2
+
class TestTemporaryElide(TestCase):
# elision is only triggered on relatively large arrays