summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-11-10 16:12:18 -0800
committerStephan Hoyer <shoyer@google.com>2018-11-10 16:12:18 -0800
commitb44284ebc42c496e5c5d906acc33ebbc337fd3b1 (patch)
tree52322a497bc9c0def2065c48ce0c00e8eee750a3 /numpy/core/overrides.py
parent56ce2327462eb9e3980c568ce9be628892aad89f (diff)
downloadnumpy-b44284ebc42c496e5c5d906acc33ebbc337fd3b1.tar.gz
MAINT: more fixes for disabling overrides
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py34
1 files changed, 22 insertions, 12 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index efa71dfb6..84e6764af 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -150,21 +150,31 @@ def verify_matching_signatures(implementation, dispatcher):
'default argument values')
+def override_module(module):
+ """Decorator for overriding __module__ on a function or class."""
+ def decorator(func):
+ if module is not None:
+ func.__module__ = module
+ return func
+ return decorator
+
+
def array_function_dispatch(dispatcher, module=None, verify=True):
"""Decorator for adding dispatch with the __array_function__ protocol."""
+
+ if not ENABLE_ARRAY_FUNCTION:
+ # __array_function__ requires an explicit opt-in for now
+ return override_module(module)
+
def decorator(implementation):
- if not ENABLE_ARRAY_FUNCTION:
- # __array_function__ requires an explicit opt-in for now
- public_api = implementation
- else:
- if verify:
- verify_matching_signatures(implementation, dispatcher)
-
- @functools.wraps(implementation)
- def public_api(*args, **kwargs):
- relevant_args = dispatcher(*args, **kwargs)
- return array_function_implementation_or_override(
- implementation, public_api, relevant_args, args, kwargs)
+ if verify:
+ verify_matching_signatures(implementation, dispatcher)
+
+ @functools.wraps(implementation)
+ def public_api(*args, **kwargs):
+ relevant_args = dispatcher(*args, **kwargs)
+ return array_function_implementation_or_override(
+ implementation, public_api, relevant_args, args, kwargs)
if module is not None:
public_api.__module__ = module