summaryrefslogtreecommitdiff
path: root/astroid/brain/brain_functools.py
blob: e863749499ae0bbfa93d97351eda8ad62a69f5ec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright (c) 2016, 2018-2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2018 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
# Copyright (c) 2021 Alphadelta14 <alpha@alphaservcomputing.solutions>

"""Astroid hooks for understanding functools library module."""
from functools import partial
from itertools import chain
from typing import Iterator, Optional

from astroid import BoundMethod, arguments, extract_node, helpers, nodes, objects
from astroid.context import InferenceContext
from astroid.exceptions import InferenceError, UseInferenceDefault
from astroid.inference_tip import inference_tip
from astroid.interpreter import objectmodel
from astroid.manager import AstroidManager
from astroid.nodes.node_classes import AssignName, Attribute, Call, Name
from astroid.nodes.scoped_nodes import FunctionDef
from astroid.util import Uninferable

# Qualified name of functools.lru_cache.
# NOTE(review): not referenced by the code in this module as visible here;
# presumably kept for external users — confirm before removing.
LRU_CACHE = "functools.lru_cache"


class LruWrappedModel(objectmodel.FunctionModel):
    """Special attribute model for functions decorated with functools.lru_cache.

    At decoration time, the lru_cache decorator attaches extra attributes
    (``__wrapped__``, ``cache_info`` and ``cache_clear``) to the decorated
    function; this model exposes them to inference.
    """

    @property
    def attr___wrapped__(self):
        """Return the function node this model is bound to."""
        return self._instance

    @property
    def attr_cache_info(self):
        """Return a bound method whose call infers to a ``_CacheInfo`` instance."""
        cache_info = extract_node(
            """
        from functools import _CacheInfo
        _CacheInfo(0, 0, 0, 0)
        """
        )

        class CacheInfoBoundMethod(BoundMethod):
            def infer_call_result(self, caller, context=None):
                # safe_infer returns None when inference fails or is
                # ambiguous. Inference results must never be None, so
                # report Uninferable instead of leaking None to callers.
                inferred = helpers.safe_infer(cache_info)
                yield Uninferable if inferred is None else inferred

        return CacheInfoBoundMethod(proxy=self._instance, bound=self._instance)

    @property
    def attr_cache_clear(self):
        """Return a bound method modelling ``cache_clear()`` (body is ``pass``)."""
        node = extract_node("""def cache_clear(self): pass""")
        # Bound to the scope enclosing the decorated function.
        return BoundMethod(proxy=node, bound=self._instance.parent.scope())


def _transform_lru_cache(node, context=None) -> None:
    """Attach the lru_cache special-attribute model to a function node.

    The node is modified in place; ``context`` is not used.
    """
    # TODO: mutating the node is not ideal, because nodes are meant to be
    # immutable, but https://github.com/PyCQA/astroid/issues/354 leaves no
    # better option for now. Replacing the node would only partially work:
    # pylint would still see the old node and report spurious false positives.
    wrapped_model = LruWrappedModel()
    node.special_attributes = wrapped_model(node)


def _functools_partial_inference(
    node: nodes.Call, context: Optional[InferenceContext] = None
) -> Iterator[objects.PartialFunction]:
    """Infer a ``functools.partial(func, *args, **kwargs)`` call.

    Produces a single ``objects.PartialFunction`` that reuses the wrapped
    function's arguments, body, decorators and return annotations, and
    records the pre-filled arguments of the ``partial`` call.

    :param node: the ``partial(...)`` call node being inferred.
    :param context: the inference context, if any.
    :returns: an iterator yielding exactly one ``PartialFunction``.
    :raises UseInferenceDefault: if the call has no positional argument, has
        no pre-filled arguments, the wrapped function cannot be inferred or
        is not a function definition, or a pre-filled keyword is not accepted
        by the wrapped function.
    """
    call = arguments.CallSite.from_call(node, context=context)
    number_of_positional = len(call.positional_arguments)
    if number_of_positional < 1:
        raise UseInferenceDefault("functools.partial takes at least one argument")
    if number_of_positional == 1 and not call.keyword_arguments:
        raise UseInferenceDefault(
            "functools.partial needs at least to have some filled arguments"
        )

    wrapped_function = call.positional_arguments[0]
    try:
        inferred_wrapped_function = next(wrapped_function.infer(context=context))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc
    if inferred_wrapped_function is Uninferable:
        raise UseInferenceDefault("Cannot infer the wrapped function")
    if not isinstance(inferred_wrapped_function, FunctionDef):
        raise UseInferenceDefault("The wrapped function is not a function")

    # Determine if the passed keywords into the callsite are supported
    # by the wrapped function. A function declaring ``**kwargs`` accepts
    # any keyword, so no keyword can be rejected in that case.
    function_arguments = inferred_wrapped_function.args
    accepts_var_keywords = bool(function_arguments and function_arguments.kwarg)
    if not accepts_var_keywords:
        if not function_arguments:
            function_parameters = ()
        else:
            function_parameters = chain(
                function_arguments.args or (),
                function_arguments.posonlyargs or (),
                function_arguments.kwonlyargs or (),
            )
        parameter_names = {
            param.name
            for param in function_parameters
            if isinstance(param, AssignName)
        }
        if set(call.keyword_arguments) - parameter_names:
            raise UseInferenceDefault("wrapped function received unknown parameters")

    partial_function = objects.PartialFunction(
        call,
        name=inferred_wrapped_function.name,
        lineno=inferred_wrapped_function.lineno,
        col_offset=inferred_wrapped_function.col_offset,
        parent=node.parent,
    )
    partial_function.postinit(
        args=inferred_wrapped_function.args,
        body=inferred_wrapped_function.body,
        decorators=inferred_wrapped_function.decorators,
        returns=inferred_wrapped_function.returns,
        type_comment_returns=inferred_wrapped_function.type_comment_returns,
        type_comment_args=inferred_wrapped_function.type_comment_args,
        doc_node=inferred_wrapped_function.doc_node,
    )
    return iter((partial_function,))


def _looks_like_lru_cache(node):
    """Check if the given function node is decorated with lru_cache.

    Recognises the called forms ``@lru_cache(...)`` and
    ``@functools.lru_cache(...)`` as well as the bare forms ``@lru_cache`` and
    ``@functools.lru_cache`` (allowed since Python 3.8). Matching is purely
    syntactic; decorator names are not inferred.
    """
    if not node.decorators:
        return False
    for decorator in node.decorators.nodes:
        if isinstance(decorator, Call):
            if _looks_like_functools_member(decorator, "lru_cache"):
                return True
        elif isinstance(decorator, Name):
            # Bare ``@lru_cache`` (e.g. after ``from functools import lru_cache``).
            if decorator.name == "lru_cache":
                return True
        elif isinstance(decorator, Attribute):
            # Bare ``@functools.lru_cache``.
            if (
                decorator.attrname == "lru_cache"
                and isinstance(decorator.expr, Name)
                and decorator.expr.name == "functools"
            ):
                return True
    return False


def _looks_like_functools_member(node: Call, member: str) -> bool:
    """Check if the given Call node calls the functools member named ``member``.

    Matches both ``member(...)`` (e.g. after ``from functools import member``)
    and ``functools.member(...)``. Only the syntax is checked; the callee is
    not inferred, so a same-named function from elsewhere also matches.
    """
    func = node.func
    if isinstance(func, Name):
        return func.name == member
    if isinstance(func, Attribute):
        return (
            func.attrname == member
            and isinstance(func.expr, Name)
            and func.expr.name == "functools"
        )
    return False


# Predicate matching ``partial(...)`` and ``functools.partial(...)`` Call nodes.
_looks_like_partial = partial(_looks_like_functools_member, member="partial")


# Attach the lru_cache special-attribute model to every FunctionDef that
# _looks_like_lru_cache accepts. Runs once, at module import time.
AstroidManager().register_transform(
    FunctionDef, _transform_lru_cache, _looks_like_lru_cache
)


# Infer matching functools.partial(...) calls through
# _functools_partial_inference, wrapped as an inference tip.
AstroidManager().register_transform(
    Call,
    inference_tip(_functools_partial_inference),
    _looks_like_partial,
)