diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-05-12 15:23:57 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-05-12 15:26:01 -0400 |
commit | ededeb529406abaaa31edf786c0060ae94f7f8f1 (patch) | |
tree | 3fa8c45ebb630b86076dbaa5c1cb1c0b77573733 /numpy | |
parent | 14ff219a13e194c5e7995218fea3c7648eb1c875 (diff) | |
download | numpy-ededeb529406abaaa31edf786c0060ae94f7f8f1.tar.gz |
BUG: errors in fast_scalar_power are not propagated.
If one raises an ndarray or subclass to particular powers,
fast_scalar_power is used in number.c -- this defers to another
function (e.g., sqrt for 0.5), but does not check whether an error
occurred. In consequece, if a subclass raises an error, and returns
NULL, that gets interpreted as meaning the fast operation is not
possible, and power is tried in turn, likely raising another
exception. This commit fixes fast_scalar_power such that it
returns success any time an operation is attempted.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 61 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.h | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 21 |
3 files changed, 50 insertions, 33 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 1f5523b90..384520652 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -91,6 +91,7 @@ PyArray_SetNumericOps(PyObject *dict) SET(sqrt); SET(cbrt); SET(negative); + SET(positive); SET(absolute); SET(invert); SET(left_shift); @@ -143,6 +144,7 @@ PyArray_GetNumericOps(void) GET(_ones_like); GET(sqrt); GET(negative); + GET(positive); GET(absolute); GET(invert); GET(left_shift); @@ -453,9 +455,14 @@ is_scalar_with_conversion(PyObject *o2, double* out_exponent) return NPY_NOSCALAR; } -/* optimize float array or complex array to a scalar power */ -static PyObject * -fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) +/* + * optimize float array or complex array to a scalar power + * returns 0 on success, -1 if no optimization is possible + * the result is in value (can be NULL if an error occurred) + */ +static int +fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace, + PyObject **value) { double exponent; NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */ @@ -464,17 +471,7 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) PyObject *fastop = NULL; if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) { if (exponent == 1.0) { - /* we have to do this one special, as the - "copy" method of array objects isn't set - up early enough to be added - by PyArray_SetNumericOps. - */ - if (inplace) { - Py_INCREF(a1); - return (PyObject *)a1; - } else { - return PyArray_Copy(a1); - } + fastop = n_ops.positive; } else if (exponent == -1.0) { fastop = n_ops.reciprocal; @@ -489,15 +486,16 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) fastop = n_ops.square; } else { - return NULL; + return -1; } if (inplace || can_elide_temp_unary(a1)) { - return PyArray_GenericInplaceUnaryFunction(a1, fastop); + *value = PyArray_GenericInplaceUnaryFunction(a1, fastop); } else { - return PyArray_GenericUnaryFunction(a1, fastop); + *value = PyArray_GenericUnaryFunction(a1, fastop); } + return 0; } /* Because this is called with all arrays, we need to * change the output if the kind of the scalar is different @@ -507,36 +505,35 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) else if (exponent == 2.0) { fastop = n_ops.square; if (inplace) { - return PyArray_GenericInplaceUnaryFunction(a1, fastop); + *value = PyArray_GenericInplaceUnaryFunction(a1, fastop); } else { /* We only special-case the FLOAT_SCALAR and integer types */ if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) { - PyObject *res; PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE); a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1)); - if (a1 == NULL) { - return NULL; + if (a1 != NULL) { + /* cast always creates a new array */ + *value = PyArray_GenericInplaceUnaryFunction(a1, fastop); + Py_DECREF(a1); } - /* cast always creates a new array */ - res = PyArray_GenericInplaceUnaryFunction(a1, fastop); - Py_DECREF(a1); - return res; } else { - return PyArray_GenericUnaryFunction(a1, fastop); + *value = PyArray_GenericUnaryFunction(a1, fastop); } } + return 0; } } - return NULL; + /* no fast operation found */ + return -1; } static PyObject * array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo) { - PyObject *value; + PyObject *value = NULL; if (modulo != Py_None) { /* modular exponentiation is not implemented (gh-8804) */ @@ -545,8 +542,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo) } BINOP_GIVE_UP_IF_NEEDED(a1, o2, nb_power, array_power); - value = fast_scalar_power(a1, o2, 0); - if (!value) { + if (fast_scalar_power(a1, o2, 0, &value) != 0) { value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power); } return value; @@ -686,12 +682,11 @@ static PyObject * array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo)) { /* modulo is ignored! */ - PyObject *value; + PyObject *value = NULL; INPLACE_GIVE_UP_IF_NEEDED( a1, o2, nb_inplace_power, array_inplace_power); - value = fast_scalar_power(a1, o2, 1); - if (!value) { + if (fast_scalar_power(a1, o2, 1, &value) != 0) { value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power); } return value; diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h index 113fc2475..99a2a722b 100644 --- a/numpy/core/src/multiarray/number.h +++ b/numpy/core/src/multiarray/number.h @@ -15,6 +15,7 @@ typedef struct { PyObject *sqrt; PyObject *cbrt; PyObject *negative; + PyObject *positive; PyObject *absolute; PyObject *invert; PyObject *left_shift; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 835d03528..4d531efb3 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3139,6 +3139,27 @@ class TestBinop(object): assert_equal(A[0], 30) assert_(isinstance(A, OutClass)) + def test_pow_override_with_errors(self): + # regression test for gh-9112 + class PowerOnly(np.ndarray): + def __array_ufunc__(self, ufunc, method, *inputs, **kw): + if ufunc is not np.power: + raise NotImplementedError + return "POWER!" + # explicit cast to float, to ensure the fast power path is taken. + a = np.array(5., dtype=np.float64).view(PowerOnly) + assert_equal(a ** 2.5, "POWER!") + with assert_raises(NotImplementedError): + a ** 0.5 + with assert_raises(NotImplementedError): + a ** 0 + with assert_raises(NotImplementedError): + a ** 1 + with assert_raises(NotImplementedError): + a ** -1 + with assert_raises(NotImplementedError): + a ** 2 + class TestCAPI(TestCase): def test_IsPythonScalar(self): |