diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2023-05-16 09:28:34 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-16 09:28:34 -0600 |
commit | 6a4abb065afa5cd9967b92476cd8a245edf3dfcd (patch) | |
tree | 7a16f842650373e869fccdf2299259bd27a989cd | |
parent | a4a951d2ba256f6a391ae6dca30bee2bb491a59f (diff) | |
parent | 8b7f69ceae5cd99592f79121a1bd7b014af4833c (diff) | |
download | numpy-6a4abb065afa5cd9967b92476cd8a245edf3dfcd.tar.gz |
Merge pull request #23659 from seberg/issue-23029
ENH: Restore TypeError cleanup in array function dispatching
-rw-r--r-- | numpy/core/src/multiarray/arrayfunction_override.c | 85 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 19 |
2 files changed, 102 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c index 04768504e..3c55e2164 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.c +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -419,6 +419,9 @@ typedef struct { PyObject *dict; PyObject *relevant_arg_func; PyObject *default_impl; + /* The following fields are used to clean up TypeError messages only: */ + PyObject *dispatcher_name; + PyObject *public_name; } PyArray_ArrayFunctionDispatcherObject; @@ -428,10 +431,72 @@ dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self) Py_CLEAR(self->relevant_arg_func); Py_CLEAR(self->default_impl); Py_CLEAR(self->dict); + Py_CLEAR(self->dispatcher_name); + Py_CLEAR(self->public_name); PyObject_FREE(self); } +static void +fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self) +{ + if (!PyErr_ExceptionMatches(PyExc_TypeError)) { + return; + } + + PyObject *exc, *val, *tb, *message; + PyErr_Fetch(&exc, &val, &tb); + + if (!PyUnicode_CheckExact(val)) { + /* + * We expect the error to be unnormalized, but maybe it isn't always + * the case, so normalize and fetch args[0] if it isn't a string. + */ + PyErr_NormalizeException(&exc, &val, &tb); + + PyObject *args = PyObject_GetAttrString(val, "args"); + if (args == NULL || !PyTuple_CheckExact(args) + || PyTuple_GET_SIZE(args) != 1) { + Py_XDECREF(args); + goto restore_error; + } + message = PyTuple_GET_ITEM(args, 0); + Py_INCREF(message); + Py_DECREF(args); + if (!PyUnicode_CheckExact(message)) { + Py_DECREF(message); + goto restore_error; + } + } + else { + Py_INCREF(val); + message = val; + } + + Py_ssize_t cmp = PyUnicode_Tailmatch( + message, self->dispatcher_name, 0, -1, -1); + if (cmp <= 0) { + Py_DECREF(message); + goto restore_error; + } + Py_SETREF(message, PyUnicode_Replace( + message, self->dispatcher_name, self->public_name, 1)); + if (message == NULL) { + goto restore_error; + } + PyErr_SetObject(PyExc_TypeError, message); + Py_DECREF(exc); + Py_XDECREF(val); + Py_XDECREF(tb); + Py_DECREF(message); + return; + + restore_error: + /* replacement not successful, so restore original error */ + PyErr_Restore(exc, val, tb); +} + + static PyObject * dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) @@ -458,6 +523,7 @@ dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, relevant_args = PyObject_Vectorcall( self->relevant_arg_func, args, len_args, kwnames); if (relevant_args == NULL) { + fix_name_if_typeerror(self); return NULL; } Py_SETREF(relevant_args, PySequence_Fast(relevant_args, @@ -600,14 +666,31 @@ dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs) } self->vectorcall = (vectorcallfunc)dispatcher_vectorcall; + Py_INCREF(self->default_impl); + self->dict = NULL; + self->dispatcher_name = NULL; + self->public_name = NULL; + if (self->relevant_arg_func == Py_None) { /* NULL in the relevant arg function means we use `like=` */ Py_CLEAR(self->relevant_arg_func); } else { + /* Fetch names to clean up TypeErrors (show actual name) */ Py_INCREF(self->relevant_arg_func); + self->dispatcher_name = PyObject_GetAttrString( + self->relevant_arg_func, "__qualname__"); + if (self->dispatcher_name == NULL) { + Py_DECREF(self); + return NULL; + } + self->public_name = PyObject_GetAttrString( + self->default_impl, "__qualname__"); + if (self->public_name == NULL) { + Py_DECREF(self); + return NULL; + } } - Py_INCREF(self->default_impl); /* Need to be like a Python function that has arbitrary attributes */ self->dict = PyDict_New(); diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 25f551f6f..5924358ea 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -359,6 +359,17 @@ class TestArrayFunctionImplementation: TypeError, "no implementation found for 'my.func'"): func(MyArray()) + @pytest.mark.parametrize("name", ["concatenate", "mean", "asarray"]) + def test_signature_error_message_simple(self, name): + func = getattr(np, name) + try: + # all of these functions need an argument: + func() + except TypeError as e: + exc = e + + assert exc.args[0].startswith(f"{name}()") + def test_signature_error_message(self): # The lambda function will be named "<lambda>", but the TypeError # should show the name as "func" @@ -370,7 +381,7 @@ class TestArrayFunctionImplementation: pass try: - func(bad_arg=3) + func._implementation(bad_arg=3) except TypeError as e: expected_exception = e @@ -378,6 +389,12 @@ class TestArrayFunctionImplementation: func(bad_arg=3) raise AssertionError("must fail") except TypeError as exc: + if exc.args[0].startswith("_dispatcher"): + # We replace the qualname currently, but it used `__name__` + # (relevant functions have the same name and qualname anyway) + pytest.skip("Python version is not using __qualname__ for " + "TypeError formatting.") + assert exc.args == expected_exception.args @pytest.mark.parametrize("value", [234, "this func is not replaced"]) |