summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-04-25 12:22:53 +0200
committerSebastian Berg <sebastianb@nvidia.com>2023-04-25 12:39:33 +0200
commit61610e74340a4a22f2782274600ae34bd882b929 (patch)
tree5ed0568467fcdb3ca534b8d58432ec6d6e579bac
parent6f3e1f458e04d13bdd56cff5669f9fd96a25fb66 (diff)
downloadnumpy-61610e74340a4a22f2782274600ae34bd882b929.tar.gz
ENH: Restore TypeError cleanup in array function dispatching
When the dispathcer raises a TypeError and it starts with the dispatchers name (or actually __qualname__ not that it normally matters), then it is nicer for users if we just raise a new error with the public symbol name. Python does not seem to normalize exception and goes down the unicode path, but I assume that e.g. PyPy may not do that. And there might be other weirder reason why we go down the full path. I have manually tested it by forcing Normalization. Closes gh-23029
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c82
-rw-r--r--numpy/core/tests/test_overrides.py13
2 files changed, 93 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
index 04768504e..63d109ecb 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.c
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -419,6 +419,9 @@ typedef struct {
PyObject *dict;
PyObject *relevant_arg_func;
PyObject *default_impl;
+ /* The following fields are used to clean up TypeError messages only: */
+ PyObject *dispatcher_name;
+ PyObject *public_name;
} PyArray_ArrayFunctionDispatcherObject;
@@ -428,10 +431,69 @@ dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self)
Py_CLEAR(self->relevant_arg_func);
Py_CLEAR(self->default_impl);
Py_CLEAR(self->dict);
+ Py_CLEAR(self->dispatcher_name);
+ Py_CLEAR(self->public_name);
PyObject_FREE(self);
}
+static void
+fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self)
+{
+ if (!PyErr_ExceptionMatches(PyExc_TypeError)) {
+ return;
+ }
+
+ PyObject *exc, *val, *tb, *message;
+ PyErr_Fetch(&exc, &val, &tb);
+
+ if (!PyUnicode_CheckExact(val)) {
+ /*
+ * We expect the error to be unnormalized, but maybe it isn't always
+ * the case, so normalize and fetch args[0] if it isn't a string.
+ */
+ PyErr_NormalizeException(&exc, &val, &tb);
+
+ PyObject *args = PyObject_GetAttrString(val, "args");
+ if (args == NULL || !PyTuple_CheckExact(args)
+ || PyTuple_GET_SIZE(args) != 1) {
+ Py_XDECREF(args);
+ goto restore_error;
+ }
+ message = PyTuple_GET_ITEM(args, 0);
+ Py_INCREF(message);
+ Py_DECREF(args);
+ if (!PyUnicode_CheckExact(message)) {
+ Py_DECREF(message);
+ goto restore_error;
+ }
+ }
+ else {
+ Py_INCREF(val);
+ message = val;
+ }
+
+ Py_ssize_t cmp = PyUnicode_Tailmatch(
+ message, self->dispatcher_name, 0, -1, 0);
+ if (cmp <= 0) {
+ Py_DECREF(message);
+ goto restore_error;
+ }
+ Py_SETREF(message, PyUnicode_Replace(
+ message, self->dispatcher_name, self->public_name, 1));
+ if (message == NULL) {
+ goto restore_error;
+ }
+ PyErr_SetObject(PyExc_TypeError, message);
+ Py_DECREF(message);
+ return;
+
+ restore_error:
+ /* replacement not successful, so restore original error */
+ PyErr_Restore(exc, val, tb);
+}
+
+
static PyObject *
dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
@@ -458,6 +520,7 @@ dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
relevant_args = PyObject_Vectorcall(
self->relevant_arg_func, args, len_args, kwnames);
if (relevant_args == NULL) {
+ fix_name_if_typeerror(self);
return NULL;
}
Py_SETREF(relevant_args, PySequence_Fast(relevant_args,
@@ -600,14 +663,31 @@ dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs)
}
self->vectorcall = (vectorcallfunc)dispatcher_vectorcall;
+ Py_INCREF(self->default_impl);
+ self->dict = NULL;
+ self->dispatcher_name = NULL;
+ self->public_name = NULL;
+
if (self->relevant_arg_func == Py_None) {
/* NULL in the relevant arg function means we use `like=` */
Py_CLEAR(self->relevant_arg_func);
}
else {
+ /* Fetch names to clean up TypeErrors (show actual name) */
Py_INCREF(self->relevant_arg_func);
+ self->dispatcher_name = PyObject_GetAttrString(
+ self->relevant_arg_func, "__qualname__");
+ if (self->dispatcher_name == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
+ self->public_name = PyObject_GetAttrString(
+ self->default_impl, "__qualname__");
+ if (self->public_name == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
}
- Py_INCREF(self->default_impl);
/* Need to be like a Python function that has arbitrary attributes */
self->dict = PyDict_New();
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 25f551f6f..65155b207 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -359,6 +359,17 @@ class TestArrayFunctionImplementation:
TypeError, "no implementation found for 'my.func'"):
func(MyArray())
+ @pytest.mark.parametrize("name", ["concatenate", "mean", "asarray"])
+ def test_signature_error_message_simple(self, name):
+ func = getattr(np, name)
+ try:
+ # all of these functions need an argument:
+ func()
+ except TypeError as e:
+ exc = e
+
+ assert exc.args[0].startswith(f"{name}()")
+
def test_signature_error_message(self):
# The lambda function will be named "<lambda>", but the TypeError
# should show the name as "func"
@@ -370,7 +381,7 @@ class TestArrayFunctionImplementation:
pass
try:
- func(bad_arg=3)
+ func._implementation(bad_arg=3)
except TypeError as e:
expected_exception = e