summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnne Archibald <peridot.faceted@gmail.com>2021-03-28 13:36:25 +0100
committerAnne Archibald <peridot.faceted@gmail.com>2021-03-28 16:06:07 +0100
commit7a4098825a301be93f6e13e69cf12add9a940cd4 (patch)
treec05a3ac4c755b21207f5cd17aa959fd9041a398f
parentf0767ababd553efb30edd3c0c04dcafdc44f1458 (diff)
downloadnumpy-7a4098825a301be93f6e13e69cf12add9a940cd4.tar.gz
BUG: fix segfault in object/longdouble operations
The operation None*np.longdouble(3) was causing infinite recursion as it searched for the appropriate conversion method. This resolves that, both for general operations and for remainders specifically (they fail in a subtly different way). Closes #18548
-rw-r--r--numpy/core/src/umath/scalarmath.c.src61
-rw-r--r--numpy/core/tests/test_scalarmath.py81
2 files changed, 141 insertions, 1 deletions
diff --git a/numpy/core/src/umath/scalarmath.c.src b/numpy/core/src/umath/scalarmath.c.src
index 86dade0f1..32ddb58a2 100644
--- a/numpy/core/src/umath/scalarmath.c.src
+++ b/numpy/core/src/umath/scalarmath.c.src
@@ -737,6 +737,9 @@ _@name@_convert2_to_ctypes(PyObject *a, @type@ *arg1,
{
int ret;
ret = _@name@_convert_to_ctype(a, arg1);
+ if (ret == -2) {
+ ret = -3;
+ }
if (ret < 0) {
return ret;
}
@@ -1029,7 +1032,7 @@ static PyObject *
/**begin repeat
*
- * #name = cfloat, cdouble, clongdouble#
+ * #name = cfloat, cdouble#
*
*/
@@ -1047,6 +1050,62 @@ static PyObject *
/**begin repeat
*
+ * #oper = divmod, remainder#
+ *
+ */
+
+/*
+Complex numbers do not support remainder operations. Unfortunately,
+the type inference for long doubles is complicated, and if a remainder
+operation is not defined - if the relevant field is left NULL - then
+operations between long doubles and objects lead to an infinite recursion
+instead of a TypeError. This should ensure that once everything gets
+converted to complex long doubles you correctly get a reasonably
+informative TypeError. This fixes the last part of bug gh-18548.
+*/
+
+static PyObject *
+clongdouble_@oper@(PyObject *a, PyObject *b)
+{
+ PyObject *ret;
+ npy_clongdouble arg1, arg2;
+ npy_clongdouble out;
+
+ BINOP_GIVE_UP_IF_NEEDED(a, b, nb_@oper@, clongdouble_@oper@);
+
+ switch(_clongdouble_convert2_to_ctypes(a, &arg1, b, &arg2)) {
+ case 0:
+ break;
+ case -1:
+ /* one of them can't be cast safely must be mixed-types*/
+ return PyArray_Type.tp_as_number->nb_@oper@(a,b);
+ case -2:
+ /* use default handling */
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ return PyGenericArrType_Type.tp_as_number->nb_@oper@(a,b);
+ case -3:
+ /*
+ * special case for longdouble and clongdouble
+ * because they have a recursive getitem in their dtype
+ */
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
+ /*
+ * here we do the actual calculation with arg1 and arg2
+ * as a function call.
+ */
+ PyErr_SetString(PyExc_TypeError, "complex long doubles do not support remainder");
+ return NULL;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ *
* #name = half, float, double, longdouble, cfloat, cdouble, clongdouble#
*
*/
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index 0b615edfa..c27a732a7 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -1,9 +1,12 @@
+import contextlib
import sys
import warnings
import itertools
import operator
import platform
import pytest
+from hypothesis import given, settings, Verbosity, assume
+from hypothesis.strategies import sampled_from
import numpy as np
from numpy.testing import (
@@ -707,3 +710,81 @@ class TestBitShifts:
shift_arr = np.array([shift]*32, dtype=dt)
res_arr = op(val_arr, shift_arr)
assert_equal(res_arr, res_scl)
+
+
+@contextlib.contextmanager
+def recursionlimit(n):
+ o = sys.getrecursionlimit()
+ try:
+ sys.setrecursionlimit(n)
+ yield
+ finally:
+ sys.setrecursionlimit(o)
+
+
+objecty_things = [object(), None]
+reasonable_operators_for_scalars = [
+ operator.lt, operator.le, operator.eq, operator.ne, operator.ge,
+ operator.gt, operator.add, operator.floordiv, operator.mod,
+ operator.mul, operator.matmul, operator.pow, operator.sub,
+ operator.truediv,
+]
+
+
+@given(sampled_from(objecty_things),
+ sampled_from(reasonable_operators_for_scalars),
+ sampled_from(types))
+@settings(verbosity=Verbosity.verbose)
+def test_operator_object_left(o, op, type_):
+ try:
+ with recursionlimit(100):
+ op(o, type_(1))
+ except TypeError:
+ pass
+
+
+@given(sampled_from(objecty_things),
+ sampled_from(reasonable_operators_for_scalars),
+ sampled_from(types))
+def test_operator_object_right(o, op, type_):
+ try:
+ with recursionlimit(100):
+ op(type_(1), o)
+ except TypeError:
+ pass
+
+
+@given(sampled_from(reasonable_operators_for_scalars),
+ sampled_from(types),
+ sampled_from(types))
+def test_operator_scalars(op, type1, type2):
+ try:
+ op(type1(1), type2(1))
+ except TypeError:
+ pass
+
+
+@pytest.mark.parametrize("op", reasonable_operators_for_scalars)
+def test_longdouble_inf_loop(op):
+ try:
+ op(np.longdouble(3), None)
+ except TypeError:
+ pass
+ try:
+ op(None, np.longdouble(3))
+ except TypeError:
+ pass
+
+
+@pytest.mark.parametrize("op", reasonable_operators_for_scalars)
+def test_clongdouble_inf_loop(op):
+ if op in {operator.mod} and False:
+ pytest.xfail("The modulo operator is known to be broken")
+ try:
+ op(np.clongdouble(3), None)
+ except TypeError:
+ pass
+ try:
+ op(None, np.longdouble(3))
+ except TypeError:
+ pass