diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 230 |
1 files changed, 143 insertions, 87 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 906292613..55c7bd1ea 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -1,69 +1,24 @@ -"""Preliminary implementation of NEP-18 - -TODO: rewrite this in C for performance. -""" +"""Implementation of __array_function__ overrides from NEP-18.""" import collections import functools +import os +import textwrap -from numpy.core.multiarray import ndarray +from numpy.core._multiarray_umath import ( + add_docstring, implement_array_function, _get_implementing_args) from numpy.compat._inspect import getargspec -_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ - - -def get_overloaded_types_and_args(relevant_args): - """Returns a list of arguments on which to call __array_function__. +ARRAY_FUNCTION_ENABLED = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1))) - Parameters - ---------- - relevant_args : iterable of array-like - Iterable of array-like arguments to check for __array_function__ - methods. - Returns - ------- - overloaded_types : collection of types - Types of arguments from relevant_args with __array_function__ methods. - overloaded_args : list - Arguments from relevant_args on which to call __array_function__ - methods, in the order in which they should be called. +add_docstring( + implement_array_function, """ - # Runtime is O(num_arguments * num_unique_types) - overloaded_types = [] - overloaded_args = [] - for arg in relevant_args: - arg_type = type(arg) - # We only collect arguments if they have a unique type, which ensures - # reasonable performance even with a long list of possibly overloaded - # arguments. - if (arg_type not in overloaded_types and - hasattr(arg_type, '__array_function__')): - - overloaded_types.append(arg_type) - - # By default, insert this argument at the end, but if it is - # subclass of another argument, insert it before that argument. - # This ensures "subclasses before superclasses". - index = len(overloaded_args) - for i, old_arg in enumerate(overloaded_args): - if issubclass(arg_type, type(old_arg)): - index = i - break - overloaded_args.insert(index, arg) - - # Special handling for ndarray.__array_function__ - overloaded_args = [ - arg for arg in overloaded_args - if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION - ] - - return overloaded_types, overloaded_args - - -def array_function_implementation_or_override( - implementation, public_api, relevant_args, args, kwargs): - """Implement a function with checks for __array_function__ overrides. + Implement a function with checks for __array_function__ overrides. + + All arguments are required, and can only be passed by position. Arguments --------- @@ -71,43 +26,44 @@ def array_function_implementation_or_override( Function that implements the operation on NumPy array without overrides when called like ``implementation(*args, **kwargs)``. public_api : function - Function exposed by NumPy's public API riginally called like - ``public_api(*args, **kwargs`` on which arguments are now being + Function exposed by NumPy's public API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being checked. relevant_args : iterable Iterable of arguments to check for __array_function__ methods. args : tuple Arbitrary positional arguments originally passed into ``public_api``. - kwargs : tuple + kwargs : dict Arbitrary keyword arguments originally passed into ``public_api``. Returns ------- - Result from calling `implementation()` or an `__array_function__` + Result from calling ``implementation()`` or an ``__array_function__`` method, as appropriate. Raises ------ TypeError : if no implementation is found. - """ - # Check for __array_function__ methods. - types, overloaded_args = get_overloaded_types_and_args(relevant_args) - if not overloaded_args: - return implementation(*args, **kwargs) + """) + - # Call overrides - for overloaded_arg in overloaded_args: - # Use `public_api` instead of `implemenation` so __array_function__ - # implementations can do equality/identity comparisons. - result = overloaded_arg.__array_function__( - public_api, types, args, kwargs) +# exposed for testing purposes; used internally by implement_array_function +add_docstring( + _get_implementing_args, + """ + Collect arguments on which to call __array_function__. - if result is not NotImplemented: - return result + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of possibly array-like arguments to check for + __array_function__ methods. - raise TypeError('no implementation found for {} on types that implement ' - '__array_function__: {}' - .format(public_api, list(map(type, overloaded_args)))) + Returns + ------- + Sequence of arguments with __array_function__ methods, in the order in + which they should be called. + """) ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') @@ -135,20 +91,120 @@ def verify_matching_signatures(implementation, dispatcher): 'default argument values') -def array_function_dispatch(dispatcher, verify=True): - """Decorator for adding dispatch with the __array_function__ protocol.""" +def set_module(module): + """Decorator for overriding __module__ on a function or class. + + Example usage:: + + @set_module('numpy') + def example(): + pass + + assert example.__module__ == 'numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator + + + +# Call textwrap.dedent here instead of in the function so as to avoid +# calling dedent multiple times on the same text +_wrapped_func_source = textwrap.dedent(""" + @functools.wraps(implementation) + def {name}(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return implement_array_function( + implementation, {name}, relevant_args, args, kwargs) + """) + + +def array_function_dispatch(dispatcher, module=None, verify=True, + docs_from_dispatcher=False): + """Decorator for adding dispatch with the __array_function__ protocol. + + See NEP-18 for example usage. + + Parameters + ---------- + dispatcher : callable + Function that when called like ``dispatcher(*args, **kwargs)`` with + arguments from the NumPy function call returns an iterable of + array-like arguments to check for ``__array_function__``. + module : str, optional + __module__ attribute to set on new function, e.g., ``module='numpy'``. + By default, module is copied from the decorated function. + verify : bool, optional + If True, verify the that the signature of the dispatcher and decorated + function signatures match exactly: all required and optional arguments + should appear in order with the same names, but the default values for + all optional arguments should be ``None``. Only disable verification + if the dispatcher's signature needs to deviate for some particular + reason, e.g., because the function has a signature like + ``func(*args, **kwargs)``. + docs_from_dispatcher : bool, optional + If True, copy docs from the dispatcher function onto the dispatched + function, rather than from the implementation. This is useful for + functions defined in C, which otherwise don't have docstrings. + + Returns + ------- + Function suitable for decorating the implementation of a NumPy function. + """ + + if not ARRAY_FUNCTION_ENABLED: + def decorator(implementation): + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + if module is not None: + implementation.__module__ = module + return implementation + return decorator + def decorator(implementation): - # TODO: only do this check when the appropriate flag is enabled or for - # a dev install. We want this check for testing but don't want to - # slow down all numpy imports. if verify: verify_matching_signatures(implementation, dispatcher) - @functools.wraps(implementation) - def public_api(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) - return array_function_implementation_or_override( - implementation, public_api, relevant_args, args, kwargs) + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + + # Equivalently, we could define this function directly instead of using + # exec. This version has the advantage of giving the helper function a + # more interpettable name. Otherwise, the original function does not + # show up at all in many cases, e.g., if it's written in C or if the + # dispatcher gets an invalid keyword argument. + source = _wrapped_func_source.format(name=implementation.__name__) + + source_object = compile( + source, filename='<__array_function__ internals>', mode='exec') + scope = { + 'implementation': implementation, + 'dispatcher': dispatcher, + 'functools': functools, + 'implement_array_function': implement_array_function, + } + exec(source_object, scope) + + public_api = scope[implementation.__name__] + + if module is not None: + public_api.__module__ = module + + public_api._implementation = implementation + return public_api return decorator + + +def array_function_from_dispatcher( + implementation, module=None, verify=True, docs_from_dispatcher=True): + """Like array_function_dispatcher, but with function arguments flipped.""" + + def decorator(dispatcher): + return array_function_dispatch( + dispatcher, module, verify=verify, + docs_from_dispatcher=docs_from_dispatcher)(implementation) + return decorator |