diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-05-17 09:58:57 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-17 09:58:57 -0400 |
commit | b9e3ac9abb6e435cdf6bbe33e0bc894d6a879a53 (patch) | |
tree | 95a2b664887537e2995c3e73eeae1fdc07d9864d /numpy | |
parent | 692655e77b65a9186bda7a701062abd6b62d4ca9 (diff) | |
parent | e1df000d940d2367c6e86f754be5201c2051ba99 (diff) | |
download | numpy-b9e3ac9abb6e435cdf6bbe33e0bc894d6a879a53.tar.gz |
Merge pull request #9112 from mhvk/array_ufunc_fast_scalar_power
BUG: ndarray.__pow__ does not check result of fast_scalar_power
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 61 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.h | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 21 |
3 files changed, 50 insertions, 33 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index d86cef5a1..9c1343497 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -91,6 +91,7 @@ PyArray_SetNumericOps(PyObject *dict) SET(sqrt); SET(cbrt); SET(negative); + SET(positive); SET(absolute); SET(invert); SET(left_shift); @@ -143,6 +144,7 @@ PyArray_GetNumericOps(void) GET(_ones_like); GET(sqrt); GET(negative); + GET(positive); GET(absolute); GET(invert); GET(left_shift); @@ -453,9 +455,14 @@ is_scalar_with_conversion(PyObject *o2, double* out_exponent) return NPY_NOSCALAR; } -/* optimize float array or complex array to a scalar power */ -static PyObject * -fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) +/* + * optimize float array or complex array to a scalar power + * returns 0 on success, -1 if no optimization is possible + * the result is in value (can be NULL if an error occurred) + */ +static int +fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace, + PyObject **value) { double exponent; NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */ @@ -464,17 +471,7 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) PyObject *fastop = NULL; if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) { if (exponent == 1.0) { - /* we have to do this one special, as the - "copy" method of array objects isn't set - up early enough to be added - by PyArray_SetNumericOps. - */ - if (inplace) { - Py_INCREF(a1); - return (PyObject *)a1; - } else { - return PyArray_Copy(a1); - } + fastop = n_ops.positive; } else if (exponent == -1.0) { fastop = n_ops.reciprocal; @@ -489,15 +486,16 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) fastop = n_ops.square; } else { - return NULL; + return -1; } if (inplace || can_elide_temp_unary(a1)) { - return PyArray_GenericInplaceUnaryFunction(a1, fastop); + *value = PyArray_GenericInplaceUnaryFunction(a1, fastop); } else { - return PyArray_GenericUnaryFunction(a1, fastop); + *value = PyArray_GenericUnaryFunction(a1, fastop); } + return 0; } /* Because this is called with all arrays, we need to * change the output if the kind of the scalar is different @@ -507,36 +505,35 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) else if (exponent == 2.0) { fastop = n_ops.square; if (inplace) { - return PyArray_GenericInplaceUnaryFunction(a1, fastop); + *value = PyArray_GenericInplaceUnaryFunction(a1, fastop); } else { /* We only special-case the FLOAT_SCALAR and integer types */ if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) { - PyObject *res; PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE); a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1)); - if (a1 == NULL) { - return NULL; + if (a1 != NULL) { + /* cast always creates a new array */ + *value = PyArray_GenericInplaceUnaryFunction(a1, fastop); + Py_DECREF(a1); } - /* cast always creates a new array */ - res = PyArray_GenericInplaceUnaryFunction(a1, fastop); - Py_DECREF(a1); - return res; } else { - return PyArray_GenericUnaryFunction(a1, fastop); + *value = PyArray_GenericUnaryFunction(a1, fastop); } } + return 0; } } - return NULL; + /* no fast operation found */ + return -1; } static PyObject * array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo) { - PyObject *value; + PyObject *value = NULL; if (modulo != Py_None) { /* modular exponentiation is not implemented (gh-8804) */ @@ -545,8 +542,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo) } BINOP_GIVE_UP_IF_NEEDED(a1, o2, nb_power, array_power); - value = fast_scalar_power(a1, o2, 0); - if (!value) { + if (fast_scalar_power(a1, o2, 0, &value) != 0) { value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power); } return value; @@ -686,12 +682,11 @@ static PyObject * array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo)) { /* modulo is ignored! */ - PyObject *value; + PyObject *value = NULL; INPLACE_GIVE_UP_IF_NEEDED( a1, o2, nb_inplace_power, array_inplace_power); - value = fast_scalar_power(a1, o2, 1); - if (!value) { + if (fast_scalar_power(a1, o2, 1, &value) != 0) { value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power); } return value; diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h index 113fc2475..99a2a722b 100644 --- a/numpy/core/src/multiarray/number.h +++ b/numpy/core/src/multiarray/number.h @@ -15,6 +15,7 @@ typedef struct { PyObject *sqrt; PyObject *cbrt; PyObject *negative; + PyObject *positive; PyObject *absolute; PyObject *invert; PyObject *left_shift; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 571d0ceb9..30bda20de 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3072,6 +3072,27 @@ class TestBinop(object): assert_equal(A[0], 30) assert_(isinstance(A, OutClass)) + def test_pow_override_with_errors(self): + # regression test for gh-9112 + class PowerOnly(np.ndarray): + def __array_ufunc__(self, ufunc, method, *inputs, **kw): + if ufunc is not np.power: + raise NotImplementedError + return "POWER!" + # explicit cast to float, to ensure the fast power path is taken. + a = np.array(5., dtype=np.float64).view(PowerOnly) + assert_equal(a ** 2.5, "POWER!") + with assert_raises(NotImplementedError): + a ** 0.5 + with assert_raises(NotImplementedError): + a ** 0 + with assert_raises(NotImplementedError): + a ** 1 + with assert_raises(NotImplementedError): + a ** -1 + with assert_raises(NotImplementedError): + a ** 2 + class TestTemporaryElide(TestCase): # elision is only triggered on relatively large arrays |