summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py230
1 files changed, 143 insertions, 87 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 906292613..55c7bd1ea 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -1,69 +1,24 @@
-"""Preliminary implementation of NEP-18
-
-TODO: rewrite this in C for performance.
-"""
+"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
+import os
+import textwrap
-from numpy.core.multiarray import ndarray
+from numpy.core._multiarray_umath import (
+ add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
-_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
-
-
-def get_overloaded_types_and_args(relevant_args):
- """Returns a list of arguments on which to call __array_function__.
+ARRAY_FUNCTION_ENABLED = bool(
+ int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
- Parameters
- ----------
- relevant_args : iterable of array-like
- Iterable of array-like arguments to check for __array_function__
- methods.
- Returns
- -------
- overloaded_types : collection of types
- Types of arguments from relevant_args with __array_function__ methods.
- overloaded_args : list
- Arguments from relevant_args on which to call __array_function__
- methods, in the order in which they should be called.
+add_docstring(
+ implement_array_function,
"""
- # Runtime is O(num_arguments * num_unique_types)
- overloaded_types = []
- overloaded_args = []
- for arg in relevant_args:
- arg_type = type(arg)
- # We only collect arguments if they have a unique type, which ensures
- # reasonable performance even with a long list of possibly overloaded
- # arguments.
- if (arg_type not in overloaded_types and
- hasattr(arg_type, '__array_function__')):
-
- overloaded_types.append(arg_type)
-
- # By default, insert this argument at the end, but if it is
- # subclass of another argument, insert it before that argument.
- # This ensures "subclasses before superclasses".
- index = len(overloaded_args)
- for i, old_arg in enumerate(overloaded_args):
- if issubclass(arg_type, type(old_arg)):
- index = i
- break
- overloaded_args.insert(index, arg)
-
- # Special handling for ndarray.__array_function__
- overloaded_args = [
- arg for arg in overloaded_args
- if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
- ]
-
- return overloaded_types, overloaded_args
-
-
-def array_function_implementation_or_override(
- implementation, public_api, relevant_args, args, kwargs):
- """Implement a function with checks for __array_function__ overrides.
+ Implement a function with checks for __array_function__ overrides.
+
+ All arguments are required, and can only be passed by position.
Arguments
---------
@@ -71,43 +26,44 @@ def array_function_implementation_or_override(
Function that implements the operation on NumPy array without
overrides when called like ``implementation(*args, **kwargs)``.
public_api : function
- Function exposed by NumPy's public API riginally called like
- ``public_api(*args, **kwargs`` on which arguments are now being
+ Function exposed by NumPy's public API originally called like
+ ``public_api(*args, **kwargs)`` on which arguments are now being
checked.
relevant_args : iterable
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
- kwargs : tuple
+ kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
- Result from calling `implementation()` or an `__array_function__`
+ Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
- """
- # Check for __array_function__ methods.
- types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- if not overloaded_args:
- return implementation(*args, **kwargs)
+ """)
+
- # Call overrides
- for overloaded_arg in overloaded_args:
- # Use `public_api` instead of `implemenation` so __array_function__
- # implementations can do equality/identity comparisons.
- result = overloaded_arg.__array_function__(
- public_api, types, args, kwargs)
+# exposed for testing purposes; used internally by implement_array_function
+add_docstring(
+ _get_implementing_args,
+ """
+ Collect arguments on which to call __array_function__.
- if result is not NotImplemented:
- return result
+ Parameters
+ ----------
+ relevant_args : iterable of array-like
+ Iterable of possibly array-like arguments to check for
+ __array_function__ methods.
- raise TypeError('no implementation found for {} on types that implement '
- '__array_function__: {}'
- .format(public_api, list(map(type, overloaded_args))))
+ Returns
+ -------
+ Sequence of arguments with __array_function__ methods, in the order in
+ which they should be called.
+ """)
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
@@ -135,20 +91,120 @@ def verify_matching_signatures(implementation, dispatcher):
'default argument values')
-def array_function_dispatch(dispatcher, verify=True):
- """Decorator for adding dispatch with the __array_function__ protocol."""
+def set_module(module):
+ """Decorator for overriding __module__ on a function or class.
+
+ Example usage::
+
+ @set_module('numpy')
+ def example():
+ pass
+
+ assert example.__module__ == 'numpy'
+ """
+ def decorator(func):
+ if module is not None:
+ func.__module__ = module
+ return func
+ return decorator
+
+
+
+# Call textwrap.dedent here instead of in the function so as to avoid
+# calling dedent multiple times on the same text
+_wrapped_func_source = textwrap.dedent("""
+ @functools.wraps(implementation)
+ def {name}(*args, **kwargs):
+ relevant_args = dispatcher(*args, **kwargs)
+ return implement_array_function(
+ implementation, {name}, relevant_args, args, kwargs)
+ """)
+
+
+def array_function_dispatch(dispatcher, module=None, verify=True,
+ docs_from_dispatcher=False):
+ """Decorator for adding dispatch with the __array_function__ protocol.
+
+ See NEP-18 for example usage.
+
+ Parameters
+ ----------
+ dispatcher : callable
+ Function that when called like ``dispatcher(*args, **kwargs)`` with
+ arguments from the NumPy function call returns an iterable of
+ array-like arguments to check for ``__array_function__``.
+ module : str, optional
+ __module__ attribute to set on new function, e.g., ``module='numpy'``.
+ By default, module is copied from the decorated function.
+ verify : bool, optional
+ If True, verify the that the signature of the dispatcher and decorated
+ function signatures match exactly: all required and optional arguments
+ should appear in order with the same names, but the default values for
+ all optional arguments should be ``None``. Only disable verification
+ if the dispatcher's signature needs to deviate for some particular
+ reason, e.g., because the function has a signature like
+ ``func(*args, **kwargs)``.
+ docs_from_dispatcher : bool, optional
+ If True, copy docs from the dispatcher function onto the dispatched
+ function, rather than from the implementation. This is useful for
+ functions defined in C, which otherwise don't have docstrings.
+
+ Returns
+ -------
+ Function suitable for decorating the implementation of a NumPy function.
+ """
+
+ if not ARRAY_FUNCTION_ENABLED:
+ def decorator(implementation):
+ if docs_from_dispatcher:
+ add_docstring(implementation, dispatcher.__doc__)
+ if module is not None:
+ implementation.__module__ = module
+ return implementation
+ return decorator
+
def decorator(implementation):
- # TODO: only do this check when the appropriate flag is enabled or for
- # a dev install. We want this check for testing but don't want to
- # slow down all numpy imports.
if verify:
verify_matching_signatures(implementation, dispatcher)
- @functools.wraps(implementation)
- def public_api(*args, **kwargs):
- relevant_args = dispatcher(*args, **kwargs)
- return array_function_implementation_or_override(
- implementation, public_api, relevant_args, args, kwargs)
+ if docs_from_dispatcher:
+ add_docstring(implementation, dispatcher.__doc__)
+
+ # Equivalently, we could define this function directly instead of using
+ # exec. This version has the advantage of giving the helper function a
+ # more interpettable name. Otherwise, the original function does not
+ # show up at all in many cases, e.g., if it's written in C or if the
+ # dispatcher gets an invalid keyword argument.
+ source = _wrapped_func_source.format(name=implementation.__name__)
+
+ source_object = compile(
+ source, filename='<__array_function__ internals>', mode='exec')
+ scope = {
+ 'implementation': implementation,
+ 'dispatcher': dispatcher,
+ 'functools': functools,
+ 'implement_array_function': implement_array_function,
+ }
+ exec(source_object, scope)
+
+ public_api = scope[implementation.__name__]
+
+ if module is not None:
+ public_api.__module__ = module
+
+ public_api._implementation = implementation
+
return public_api
return decorator
+
+
+def array_function_from_dispatcher(
+ implementation, module=None, verify=True, docs_from_dispatcher=True):
+ """Like array_function_dispatcher, but with function arguments flipped."""
+
+ def decorator(dispatcher):
+ return array_function_dispatch(
+ dispatcher, module, verify=verify,
+ docs_from_dispatcher=docs_from_dispatcher)(implementation)
+ return decorator