summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
blob: f2997b600d4d65ad3700cefca6d88cd3906475b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Preliminary implementation of NEP-18

TODO: rewrite this in C for performance.
"""
import functools
from numpy.core.multiarray import ndarray


# ndarray's own __array_function__, captured once at import time.
# get_overloaded_types_and_args compares against it by identity (`is not`) to
# drop arguments whose type merely inherits ndarray's implementation, so no
# override call is attempted for them.
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__


def get_overloaded_types_and_args(relevant_args):
    """Collect the types and arguments taking part in __array_function__ dispatch.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Arguments to inspect. An argument is considered overloaded when its
        *type* has an ``__array_function__`` attribute.

    Returns
    -------
    overloaded_types : list of type
        Each distinct overloaded type found in ``relevant_args``, in order of
        first appearance. Types that simply inherit
        ``ndarray.__array_function__`` are still listed here.
    overloaded_args : list
        One argument per distinct overloaded type (its first occurrence),
        ordered so that subclasses come before their superclasses. Arguments
        whose type uses ``ndarray.__array_function__`` unchanged are removed
        from this list.
    """
    # Cost is O(len(relevant_args) * number of distinct overloaded types).
    seen_types = []
    dispatch_order = []
    for candidate in relevant_args:
        candidate_type = type(candidate)
        if candidate_type in seen_types:
            continue
        if not hasattr(candidate_type, '__array_function__'):
            continue

        seen_types.append(candidate_type)

        # Put the new argument just ahead of the first collected argument whose
        # type it subclasses; if there is none, it goes at the end. This keeps
        # "subclasses before superclasses".
        position = next(
            (pos for pos, earlier in enumerate(dispatch_order)
             if issubclass(candidate_type, type(earlier))),
            len(dispatch_order))
        dispatch_order.insert(position, candidate)

    # Arguments that only inherit ndarray's __array_function__ need no override
    # call; drop them here while leaving their types in seen_types.
    overridden = [
        candidate for candidate in dispatch_order
        if type(candidate).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
    ]
    return seen_types, overridden


def array_function_implementation_or_override(
        implementation_func, api_func, dispatcher, args, kwargs):
    """Implement a function with checks for __array_function__ overrides.

    Parameters
    ----------
    implementation_func : function
        Function that implements the operation on NumPy arrays without
        overrides when called like ``implementation_func(*args, **kwargs)``.
    api_func : function
        Function exposed by NumPy's public API on which overrides are being
        checked here. This is the object passed to each
        ``__array_function__`` method.
    dispatcher : callable
        Function that, when called like ``dispatcher(*args, **kwargs)``,
        returns an iterable of the relevant arguments to check for
        ``__array_function__`` attributes.
    args : tuple
        Positional arguments originally passed into ``api_func``.
    kwargs : dict
        Keyword arguments originally passed into ``api_func``.

    Returns
    -------
    Result from calling ``implementation_func(*args, **kwargs)`` when no
    argument overrides ``__array_function__``; otherwise the result of the
    first ``__array_function__`` method that does not return
    ``NotImplemented``.

    Raises
    ------
    TypeError
        If every consulted ``__array_function__`` method returns
        ``NotImplemented``.
    """

    # Collect array-like arguments.
    relevant_arguments = dispatcher(*args, **kwargs)

    # Check for __array_function__ methods.
    types, overloaded_args = get_overloaded_types_and_args(
        relevant_arguments)

    # Fast path: no argument needs an override call (arguments that only
    # inherit ndarray.__array_function__ were already filtered out).
    if not overloaded_args:
        return implementation_func(*args, **kwargs)

    # Call overrides
    for overloaded_arg in overloaded_args:
        # get_overloaded_types_and_args keeps only the *first* occurrence of
        # each argument type, so __array_function__ is called at most once per
        # type. This keeps performance reasonable with a possibly long list of
        # overloaded arguments, for which each __array_function__
        # implementation might reasonably need to check all argument types.
        # We pass api_func (the function exposed in NumPy's public API) rather
        # than implementation_func so __array_function__ implementations can
        # do equality/identity comparisons against the public function.
        result = overloaded_arg.__array_function__(
            api_func, types, args, kwargs)

        if result is not NotImplemented:
            return result

    raise TypeError('no implementation found for {} on types that implement '
                    '__array_function__: {}'
                    .format(api_func, list(map(type, overloaded_args))))


def array_function_dispatch(dispatcher):
    """Return a decorator that adds __array_function__ dispatch to a function.

    Parameters
    ----------
    dispatcher : callable
        Called with the same arguments as the decorated function; it returns
        the arguments to check for ``__array_function__`` overrides.

    Returns
    -------
    callable
        Decorator that takes the plain implementation and returns the public
        API function. The public function carries the implementation's
        metadata (name, docstring, ...) via ``functools.wraps``.
    """
    def decorator(implementation_func):
        # The public function hands everything to the override machinery and
        # passes *itself* along, so overrides see the public API object.
        def public_api(*args, **kwargs):
            return array_function_implementation_or_override(
                implementation_func, public_api, dispatcher, args, kwargs)

        # update_wrapper mutates and returns the same function object.
        functools.wraps(implementation_func)(public_api)
        return public_api
    return decorator