diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-06-10 18:49:06 -0700 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2022-06-10 19:41:48 -0700 |
commit | 8fabdecb556354559f9dcc280c75001f3df8aaa8 (patch) | |
tree | 40c1250ef7a0bd2a56bf477a4359b3ddf42de02a | |
parent | 4cb688938be2ff20bf939c3917fde2f7634f0500 (diff) | |
download | numpy-8fabdecb556354559f9dcc280c75001f3df8aaa8.tar.gz |
ENH: Ensure dispatcher TypeErrors report original name
This replaces the name in the TypeError with the actually raised
name. In principle we could add one more check, because a
signature related TypeError must have a traceback with exactly
one entry (so `sys.exc_info()[2].tb_next is None`).
In practice this seems unnecessary though.
This ensures the following message:
>>> np.histogram(asdf=3)
TypeError: histogram() got an unexpected keyword argument 'asdf'
Closes gh-21647
-rw-r--r-- | numpy/core/overrides.py | 23 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 39 |
2 files changed, 61 insertions, 1 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index cb550152e..663436a4c 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -2,6 +2,7 @@ import collections import functools import os +import sys from numpy.core._multiarray_umath import ( add_docstring, implement_array_function, _get_implementing_args) @@ -176,7 +177,27 @@ def array_function_dispatch(dispatcher, module=None, verify=True, @functools.wraps(implementation) def public_api(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) + try: + relevant_args = dispatcher(*args, **kwargs) + except TypeError as exc: + # Try to clean up a signature related TypeError. Such an + # error will be something like: + # dispatcher.__name__() got an unexpected keyword argument + # + # So replace the dispatcher name in this case. In principle + # TypeErrors may be raised from _within_ the dispatcher, so + # we check that the traceback contains a string that starts + # with the name. (In principle we could also check the + # traceback length, as it would be deeper.) + msg = exc.args[0] + disp_name = dispatcher.__name__ + if not isinstance(msg, str) or not msg.startswith(disp_name): + raise + + # Replace with the correct name and re-raise: + new_msg = msg.replace(disp_name, public_api.__name__) + raise TypeError(new_msg) from None + return implement_array_function( implementation, public_api, relevant_args, args, kwargs) diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 36970dbc0..e68406ebd 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -355,6 +355,45 @@ class TestArrayFunctionImplementation: TypeError, "no implementation found for 'my.func'"): func(MyArray()) + def test_signature_error_message(self): + # The lambda function will be named "<lambda>", but the TypeError + # should show the name as "func" + def _dispatcher(): + return () + + @array_function_dispatch(_dispatcher) + def func(): + pass + + try: + func(bad_arg=3) + except TypeError as e: + expected_exception = e + + try: + func(bad_arg=3) + raise AssertionError("must fail") + except TypeError as exc: + assert exc.args == expected_exception.args + + @pytest.mark.parametrize("value", [234, "this func is not replaced"]) + def test_dispatcher_error(self, value): + # If the dispatcher raises an error, we must not attempt to mutate it + error = TypeError(value) + + def dispatcher(): + raise error + + @array_function_dispatch(dispatcher) + def func(): + return 3 + + try: + func() + raise AssertionError("must fail") + except TypeError as exc: + assert exc is error # unmodified exception + class TestNDArrayMethods: |