summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-06-10 18:49:06 -0700
committerSebastian Berg <sebastian@sipsolutions.net>2022-06-10 19:41:48 -0700
commit8fabdecb556354559f9dcc280c75001f3df8aaa8 (patch)
tree40c1250ef7a0bd2a56bf477a4359b3ddf42de02a
parent4cb688938be2ff20bf939c3917fde2f7634f0500 (diff)
downloadnumpy-8fabdecb556354559f9dcc280c75001f3df8aaa8.tar.gz
ENH: Ensure dispatcher TypeErrors report original name
This replaces the name in the TypeError with the actually raised name. In principle we could add one more check, because a signature related TypeError must have a traceback with exactly one entry (so `sys.exc_info()[2].tb_next is None`). In practice this seems unnecessary though. This ensures the following message: >>> np.histogram(asdf=3) TypeError: histogram() got an unexpected keyword argument 'asdf' Closes gh-21647
-rw-r--r--numpy/core/overrides.py23
-rw-r--r--numpy/core/tests/test_overrides.py39
2 files changed, 61 insertions, 1 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index cb550152e..663436a4c 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -2,6 +2,7 @@
import collections
import functools
import os
+import sys
from numpy.core._multiarray_umath import (
add_docstring, implement_array_function, _get_implementing_args)
@@ -176,7 +177,27 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
@functools.wraps(implementation)
def public_api(*args, **kwargs):
- relevant_args = dispatcher(*args, **kwargs)
+ try:
+ relevant_args = dispatcher(*args, **kwargs)
+ except TypeError as exc:
+ # Try to clean up a signature related TypeError. Such an
+ # error will be something like:
+ # dispatcher.__name__() got an unexpected keyword argument
+ #
+ # So replace the dispatcher name in this case. In principle
+ # TypeErrors may be raised from _within_ the dispatcher, so
+ # we check that the traceback contains a string that starts
+ # with the name. (In principle we could also check the
+ # traceback length, as it would be deeper.)
+ msg = exc.args[0]
+ disp_name = dispatcher.__name__
+ if not isinstance(msg, str) or not msg.startswith(disp_name):
+ raise
+
+ # Replace with the correct name and re-raise:
+ new_msg = msg.replace(disp_name, public_api.__name__)
+ raise TypeError(new_msg) from None
+
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 36970dbc0..e68406ebd 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -355,6 +355,45 @@ class TestArrayFunctionImplementation:
TypeError, "no implementation found for 'my.func'"):
func(MyArray())
+ def test_signature_error_message(self):
+ # The lambda function will be named "<lambda>", but the TypeError
+ # should show the name as "func"
+ def _dispatcher():
+ return ()
+
+ @array_function_dispatch(_dispatcher)
+ def func():
+ pass
+
+ try:
+ func(bad_arg=3)
+ except TypeError as e:
+ expected_exception = e
+
+ try:
+ func(bad_arg=3)
+ raise AssertionError("must fail")
+ except TypeError as exc:
+ assert exc.args == expected_exception.args
+
+ @pytest.mark.parametrize("value", [234, "this func is not replaced"])
+ def test_dispatcher_error(self, value):
+ # If the dispatcher raises an error, we must not attempt to mutate it
+ error = TypeError(value)
+
+ def dispatcher():
+ raise error
+
+ @array_function_dispatch(dispatcher)
+ def func():
+ return 3
+
+ try:
+ func()
+ raise AssertionError("must fail")
+ except TypeError as exc:
+ assert exc is error # unmodified exception
+
class TestNDArrayMethods: