diff options
author | Anne Archibald <peridot.faceted@gmail.com> | 2021-03-28 13:36:25 +0100 |
---|---|---|
committer | Anne Archibald <peridot.faceted@gmail.com> | 2021-03-28 16:06:07 +0100 |
commit | 7a4098825a301be93f6e13e69cf12add9a940cd4 (patch) | |
tree | c05a3ac4c755b21207f5cd17aa959fd9041a398f | |
parent | f0767ababd553efb30edd3c0c04dcafdc44f1458 (diff) | |
download | numpy-7a4098825a301be93f6e13e69cf12add9a940cd4.tar.gz |
BUG: fix segfault in object/longdouble operations
The operation None*np.longdouble(3) was causing infinite recursion as it
searched for the appropriate conversion method. This resolves that, both
for general operations and for remainders specifically (they fail in a
subtly different way).
Closes #18548
-rw-r--r-- | numpy/core/src/umath/scalarmath.c.src | 61 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarmath.py | 81 |
2 files changed, 141 insertions, 1 deletions
diff --git a/numpy/core/src/umath/scalarmath.c.src b/numpy/core/src/umath/scalarmath.c.src index 86dade0f1..32ddb58a2 100644 --- a/numpy/core/src/umath/scalarmath.c.src +++ b/numpy/core/src/umath/scalarmath.c.src @@ -737,6 +737,9 @@ _@name@_convert2_to_ctypes(PyObject *a, @type@ *arg1, { int ret; ret = _@name@_convert_to_ctype(a, arg1); + if (ret == -2) { + ret = -3; + } if (ret < 0) { return ret; } @@ -1029,7 +1032,7 @@ static PyObject * /**begin repeat * - * #name = cfloat, cdouble, clongdouble# + * #name = cfloat, cdouble# * */ @@ -1047,6 +1050,62 @@ static PyObject * /**begin repeat * + * #oper = divmod, remainder# + * + */ + +/* +Complex numbers do not support remainder operations. Unfortunately, +the type inference for long doubles is complicated, and if a remainder +operation is not defined - if the relevant field is left NULL - then +operations between long doubles and objects lead to an infinite recursion +instead of a TypeError. This should ensure that once everything gets +converted to complex long doubles you correctly get a reasonably +informative TypeError. This fixes the last part of bug gh-18548. +*/ + +static PyObject * +clongdouble_@oper@(PyObject *a, PyObject *b) +{ + PyObject *ret; + npy_clongdouble arg1, arg2; + npy_clongdouble out; + + BINOP_GIVE_UP_IF_NEEDED(a, b, nb_@oper@, clongdouble_@oper@); + + switch(_clongdouble_convert2_to_ctypes(a, &arg1, b, &arg2)) { + case 0: + break; + case -1: + /* one of them can't be cast safely must be mixed-types*/ + return PyArray_Type.tp_as_number->nb_@oper@(a,b); + case -2: + /* use default handling */ + if (PyErr_Occurred()) { + return NULL; + } + return PyGenericArrType_Type.tp_as_number->nb_@oper@(a,b); + case -3: + /* + * special case for longdouble and clongdouble + * because they have a recursive getitem in their dtype + */ + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + /* + * here we do the actual calculation with arg1 and arg2 + * as a function call. + */ + PyErr_SetString(PyExc_TypeError, "complex long doubles do not support remainder"); + return NULL; +} + +/**end repeat**/ + +/**begin repeat + * * #name = half, float, double, longdouble, cfloat, cdouble, clongdouble# * */ diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index 0b615edfa..c27a732a7 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -1,9 +1,12 @@ +import contextlib import sys import warnings import itertools import operator import platform import pytest +from hypothesis import given, settings, Verbosity, assume +from hypothesis.strategies import sampled_from import numpy as np from numpy.testing import ( @@ -707,3 +710,81 @@ class TestBitShifts: shift_arr = np.array([shift]*32, dtype=dt) res_arr = op(val_arr, shift_arr) assert_equal(res_arr, res_scl) + + +@contextlib.contextmanager +def recursionlimit(n): + o = sys.getrecursionlimit() + try: + sys.setrecursionlimit(n) + yield + finally: + sys.setrecursionlimit(o) + + +objecty_things = [object(), None] +reasonable_operators_for_scalars = [ + operator.lt, operator.le, operator.eq, operator.ne, operator.ge, + operator.gt, operator.add, operator.floordiv, operator.mod, + operator.mul, operator.matmul, operator.pow, operator.sub, + operator.truediv, +] + + +@given(sampled_from(objecty_things), + sampled_from(reasonable_operators_for_scalars), + sampled_from(types)) +@settings(verbosity=Verbosity.verbose) +def test_operator_object_left(o, op, type_): + try: + with recursionlimit(100): + op(o, type_(1)) + except TypeError: + pass + + +@given(sampled_from(objecty_things), + sampled_from(reasonable_operators_for_scalars), + sampled_from(types)) +def test_operator_object_right(o, op, type_): + try: + with recursionlimit(100): + op(type_(1), o) + except TypeError: + pass + + +@given(sampled_from(reasonable_operators_for_scalars), + sampled_from(types), + sampled_from(types)) +def test_operator_scalars(op, type1, type2): + try: + op(type1(1), type2(1)) + except TypeError: + pass + + +@pytest.mark.parametrize("op", reasonable_operators_for_scalars) +def test_longdouble_inf_loop(op): + try: + op(np.longdouble(3), None) + except TypeError: + pass + try: + op(None, np.longdouble(3)) + except TypeError: + pass + + +@pytest.mark.parametrize("op", reasonable_operators_for_scalars) +def test_clongdouble_inf_loop(op): + if op in {operator.mod} and False: + pytest.xfail("The modulo operator is known to be broken") + try: + op(np.clongdouble(3), None) + except TypeError: + pass + try: + op(None, np.longdouble(3)) + except TypeError: + pass |