summaryrefslogtreecommitdiff
path: root/pylint/checkers/classes/special_methods_checker.py
blob: 025f28562269c225d16e13211df110f9daff9593 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt

"""Module for the special methods checker and its helper function."""

from __future__ import annotations

from collections.abc import Callable

import astroid
from astroid import bases, nodes, util
from astroid.context import InferenceContext
from astroid.typing import InferenceResult

from pylint.checkers import BaseChecker
from pylint.checkers.utils import (
    PYMETHODS,
    SPECIAL_METHODS_PARAMS,
    decorated_with,
    is_function_body_ellipsis,
    only_required_for_messages,
    safe_infer,
)
from pylint.lint.pylinter import PyLinter

# Name of the method looked up on a value returned by ``__iter__`` to decide
# whether that value is an iterator; also interpolated into the E0301 help text.
NEXT_METHOD = "__next__"


def _safe_infer_call_result(
    node: nodes.FunctionDef,
    caller: nodes.FunctionDef,
    context: InferenceContext | None = None,
) -> InferenceResult | None:
    """Infer the single, unambiguous return value of calling ``node``.

    :param node: function whose call result is inferred.
    :param caller: node passed to ``infer_call_result`` as the caller.
    :param context: optional inference context forwarded to astroid.
    :returns: the inferred value when exactly one value is produced;
        ``None`` when inference fails, produces nothing, or produces more
        than one value.
    """
    try:
        results = node.infer_call_result(caller, context=context)
        first = next(results)
    except (astroid.InferenceError, StopIteration):
        # Either inference failed outright or it produced no value at all.
        return None

    try:
        next(results)
    except StopIteration:
        # The iterator is exhausted: ``first`` was the only candidate.
        return first
    except astroid.InferenceError:
        # Asking for a second candidate failed; treat it as ambiguous.
        return None
    # A second candidate exists, so the result is ambiguous.
    return None


class SpecialMethodsChecker(BaseChecker):
    """Checker which verifies that special methods
    are implemented correctly.

    Two families of checks are run on methods:

    * the inferred return value of the protocol methods registered in
      ``self._protocol_map`` (``__iter__``, ``__len__``, ...) is validated
      against what that protocol requires;
    * methods whose name is in ``PYMETHODS`` have their number of parameters
      compared with the entry in ``SPECIAL_METHODS_PARAMS``.
    """

    name = "classes"
    msgs = {
        "E0301": (
            "__iter__ returns non-iterator",
            "non-iterator-returned",
            "Used when an __iter__ method returns something which is not an "
            f"iterable (i.e. has no `{NEXT_METHOD}` method)",
            {
                "old_names": [
                    ("W0234", "old-non-iterator-returned-1"),
                    ("E0234", "old-non-iterator-returned-2"),
                ]
            },
        ),
        "E0302": (
            "The special method %r expects %s param(s), %d %s given",
            "unexpected-special-method-signature",
            "Emitted when a special method was defined with an "
            "invalid number of parameters. If it has too few or "
            "too many, it might not work at all.",
            {"old_names": [("E0235", "bad-context-manager")]},
        ),
        "E0303": (
            "__len__ does not return non-negative integer",
            "invalid-length-returned",
            "Used when a __len__ method returns something which is not a "
            "non-negative integer",
        ),
        "E0304": (
            "__bool__ does not return bool",
            "invalid-bool-returned",
            "Used when a __bool__ method returns something which is not a bool",
        ),
        "E0305": (
            "__index__ does not return int",
            "invalid-index-returned",
            "Used when an __index__ method returns something which is not "
            "an integer",
        ),
        "E0306": (
            "__repr__ does not return str",
            "invalid-repr-returned",
            "Used when a __repr__ method returns something which is not a string",
        ),
        "E0307": (
            "__str__ does not return str",
            "invalid-str-returned",
            "Used when a __str__ method returns something which is not a string",
        ),
        "E0308": (
            "__bytes__ does not return bytes",
            "invalid-bytes-returned",
            "Used when a __bytes__ method returns something which is not bytes",
        ),
        "E0309": (
            "__hash__ does not return int",
            "invalid-hash-returned",
            "Used when a __hash__ method returns something which is not an integer",
        ),
        "E0310": (
            "__length_hint__ does not return non-negative integer",
            "invalid-length-hint-returned",
            "Used when a __length_hint__ method returns something which is not a "
            "non-negative integer",
        ),
        "E0311": (
            "__format__ does not return str",
            "invalid-format-returned",
            "Used when a __format__ method returns something which is not a string",
        ),
        "E0312": (
            "__getnewargs__ does not return a tuple",
            "invalid-getnewargs-returned",
            "Used when a __getnewargs__ method returns something which is not "
            "a tuple",
        ),
        "E0313": (
            "__getnewargs_ex__ does not return a tuple containing (tuple, dict)",
            "invalid-getnewargs-ex-returned",
            "Used when a __getnewargs_ex__ method returns something which is not "
            "of the form tuple(tuple, dict)",
        ),
    }

    def __init__(self, linter: PyLinter) -> None:
        """Register the checker and build the method-name -> check dispatch table."""
        super().__init__(linter)
        # Maps a special method name to the bound method validating the
        # inferred return value of that special method.
        self._protocol_map: dict[
            str, Callable[[nodes.FunctionDef, InferenceResult], None]
        ] = {
            "__iter__": self._check_iter,
            "__len__": self._check_len,
            "__bool__": self._check_bool,
            "__index__": self._check_index,
            "__repr__": self._check_repr,
            "__str__": self._check_str,
            "__bytes__": self._check_bytes,
            "__hash__": self._check_hash,
            "__length_hint__": self._check_length_hint,
            "__format__": self._check_format,
            "__getnewargs__": self._check_getnewargs,
            "__getnewargs_ex__": self._check_getnewargs_ex,
        }

    @only_required_for_messages(
        "unexpected-special-method-signature",
        "non-iterator-returned",
        "invalid-length-returned",
        "invalid-bool-returned",
        "invalid-index-returned",
        "invalid-repr-returned",
        "invalid-str-returned",
        "invalid-bytes-returned",
        "invalid-hash-returned",
        "invalid-length-hint-returned",
        "invalid-format-returned",
        "invalid-getnewargs-returned",
        "invalid-getnewargs-ex-returned",
    )
    def visit_functiondef(self, node: nodes.FunctionDef) -> None:
        """Run the return-value and signature checks on ``node`` if it is a method."""
        if not node.is_method():
            return

        # The call result is only needed for methods that have a registered
        # return-value check and a real body, so test those cheap conditions
        # before asking astroid to infer anything.
        check_return = self._protocol_map.get(node.name)
        if check_return is not None and not is_function_body_ellipsis(node):
            inferred = _safe_infer_call_result(node, node)
            # Only want to check types that we are able to infer; a falsy
            # result means nothing usable was inferred.
            if inferred:
                check_return(node, inferred)

        if node.name in PYMETHODS:
            self._check_unexpected_method_signature(node)

    visit_asyncfunctiondef = visit_functiondef

    def _check_unexpected_method_signature(self, node: nodes.FunctionDef) -> None:
        """Emit ``unexpected-special-method-signature`` on a wrong parameter count.

        The expected count comes from ``SPECIAL_METHODS_PARAMS[node.name]`` and
        is either ``None`` (no check), a tuple of accepted counts, or a single
        integer. The implicit first parameter is not counted unless the method
        is decorated with ``staticmethod``.
        """
        expected_params = SPECIAL_METHODS_PARAMS[node.name]

        if expected_params is None:
            # This can support a variable number of parameters.
            return

        arguments = node.args
        # Positional-only parameters (``def m(self, other, /)``) are stored
        # separately from ``args`` but are positional parameters all the same,
        # and ``defaults`` is counted over both groups together.
        positional = [*(arguments.posonlyargs or ()), *(arguments.args or ())]
        if not positional and not arguments.vararg:
            # Method has no parameter, will be caught
            # by no-method-argument.
            return

        if decorated_with(node, ["builtins.staticmethod"]):
            # We expect to not take in consideration self.
            all_args = positional
        else:
            all_args = positional[1:]
        optional = len(arguments.defaults)
        mandatory = len(all_args) - optional
        current_params = mandatory + optional

        emit = False  # If we don't know we choose a false negative
        # What gets interpolated into the message for the expected count.
        expected_display: int | str = expected_params
        if isinstance(expected_params, tuple):
            # The expected number of parameters can be any value from this
            # tuple, although the user should implement the method
            # to take all of them in consideration.
            emit = mandatory not in expected_params
            expected_display = (
                f"between {expected_params[0]} or {expected_params[1]}"
            )
        else:
            # If the number of mandatory parameters doesn't
            # suffice, the expected parameters for this
            # function will be deduced from the optional
            # parameters.
            missing = expected_params - mandatory
            if missing < 0:
                # More mandatory parameters than the protocol will ever pass.
                emit = True
            elif missing > 0:
                # Too few mandatory parameters is fine as long as optional
                # parameters or ``*args`` can absorb the remaining ones.
                emit = not (optional >= missing or arguments.vararg)

        if emit:
            verb = "was" if current_params <= 1 else "were"
            self.add_message(
                "unexpected-special-method-signature",
                args=(node.name, expected_display, current_params, verb),
                node=node,
            )

    @staticmethod
    def _is_wrapped_type(node: InferenceResult, type_: str) -> bool:
        """Return True if ``node`` is a non-``Const`` instance named ``type_``."""
        return (
            isinstance(node, bases.Instance)
            and node.name == type_
            and not isinstance(node, nodes.Const)
        )

    @staticmethod
    def _is_value_of_type(
        node: InferenceResult, type_name: str, python_type: type
    ) -> bool:
        """Return True if ``node`` stands for a value of the given builtin type.

        That is the case when ``node`` is either an instance wrapping
        ``type_name`` or a ``Const`` whose value is a ``python_type``.
        """
        if SpecialMethodsChecker._is_wrapped_type(node, type_name):
            return True
        return isinstance(node, nodes.Const) and isinstance(node.value, python_type)

    @staticmethod
    def _is_int(node: InferenceResult) -> bool:
        """Return True if ``node`` was inferred as an ``int``."""
        return SpecialMethodsChecker._is_value_of_type(node, "int", int)

    @staticmethod
    def _is_str(node: InferenceResult) -> bool:
        """Return True if ``node`` was inferred as a ``str``."""
        return SpecialMethodsChecker._is_value_of_type(node, "str", str)

    @staticmethod
    def _is_bool(node: InferenceResult) -> bool:
        """Return True if ``node`` was inferred as a ``bool``."""
        return SpecialMethodsChecker._is_value_of_type(node, "bool", bool)

    @staticmethod
    def _is_bytes(node: InferenceResult) -> bool:
        """Return True if ``node`` was inferred as ``bytes``."""
        return SpecialMethodsChecker._is_value_of_type(node, "bytes", bytes)

    @staticmethod
    def _is_tuple(node: InferenceResult) -> bool:
        """Return True if ``node`` was inferred as a ``tuple``."""
        return SpecialMethodsChecker._is_value_of_type(node, "tuple", tuple)

    @staticmethod
    def _is_dict(node: InferenceResult) -> bool:
        """Return True if ``node`` was inferred as a ``dict``."""
        return SpecialMethodsChecker._is_value_of_type(node, "dict", dict)

    @staticmethod
    def _is_iterator(node: InferenceResult) -> bool:
        """Return True if ``node`` can be used as an iterator.

        Generators and comprehension scopes always qualify. An instance
        qualifies when it has a local ``__next__`` attribute, and a class
        qualifies when its metaclass is a class with a local ``__next__``.
        """
        if isinstance(node, bases.Generator):
            # Generators can be iterated.
            return True
        if isinstance(node, nodes.ComprehensionScope):
            # Comprehensions can be iterated.
            return True

        if isinstance(node, bases.Instance):
            owner = node
        elif isinstance(node, nodes.ClassDef):
            # For a class object the lookup has to happen on its metaclass.
            owner = node.metaclass()
            if not (owner and isinstance(owner, nodes.ClassDef)):
                return False
        else:
            return False

        try:
            owner.local_attr(NEXT_METHOD)
        except astroid.NotFoundError:
            return False
        return True

    def _is_valid_length(self, inferred: InferenceResult) -> bool:
        """Return True if ``inferred`` is an int not known to be negative.

        Only ``Const`` values can be compared against zero; any other integer
        instance is accepted because its value is not available.
        """
        if not self._is_int(inferred):
            return False
        return not (isinstance(inferred, nodes.Const) and inferred.value < 0)

    def _check_iter(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``non-iterator-returned`` if ``inferred`` is not an iterator."""
        if not self._is_iterator(inferred):
            self.add_message("non-iterator-returned", node=node)

    def _check_len(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-length-returned`` unless a non-negative int is returned."""
        if not self._is_valid_length(inferred):
            self.add_message("invalid-length-returned", node=node)

    def _check_bool(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-bool-returned`` if ``inferred`` is not a bool."""
        if not self._is_bool(inferred):
            self.add_message("invalid-bool-returned", node=node)

    def _check_index(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-index-returned`` if ``inferred`` is not an int."""
        if not self._is_int(inferred):
            self.add_message("invalid-index-returned", node=node)

    def _check_repr(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-repr-returned`` if ``inferred`` is not a str."""
        if not self._is_str(inferred):
            self.add_message("invalid-repr-returned", node=node)

    def _check_str(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-str-returned`` if ``inferred`` is not a str."""
        if not self._is_str(inferred):
            self.add_message("invalid-str-returned", node=node)

    def _check_bytes(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-bytes-returned`` if ``inferred`` is not bytes."""
        if not self._is_bytes(inferred):
            self.add_message("invalid-bytes-returned", node=node)

    def _check_hash(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-hash-returned`` if ``inferred`` is not an int."""
        if not self._is_int(inferred):
            self.add_message("invalid-hash-returned", node=node)

    def _check_length_hint(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        """Emit ``invalid-length-hint-returned`` unless a non-negative int is returned."""
        if not self._is_valid_length(inferred):
            self.add_message("invalid-length-hint-returned", node=node)

    def _check_format(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        """Emit ``invalid-format-returned`` if ``inferred`` is not a str."""
        if not self._is_str(inferred):
            self.add_message("invalid-format-returned", node=node)

    def _check_getnewargs(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        """Emit ``invalid-getnewargs-returned`` if ``inferred`` is not a tuple."""
        if not self._is_tuple(inferred):
            self.add_message("invalid-getnewargs-returned", node=node)

    def _check_getnewargs_ex(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        """Emit ``invalid-getnewargs-ex-returned`` unless ``(tuple, dict)`` is returned.

        The message is emitted when the result is not a tuple, when a tuple
        literal does not have exactly two elements, or when an element that
        can be inferred is not a tuple (first) or a dict (second). Elements
        that cannot be inferred are given the benefit of the doubt.
        """
        if not self._is_tuple(inferred):
            self.add_message("invalid-getnewargs-ex-returned", node=node)
            return

        if not isinstance(inferred, nodes.Tuple):
            # If it's not an astroid.Tuple we can't analyze it further
            return

        if len(inferred.elts) != 2:
            self.add_message("invalid-getnewargs-ex-returned", node=node)
            return

        for element, check in zip(inferred.elts, (self._is_tuple, self._is_dict)):
            # The elements are the expressions written in the return
            # statement, so names, attributes and calls have to be inferred
            # before their type can be judged.
            value = safe_infer(element)
            if not value or isinstance(value, util.UninferableBase):
                continue
            if not check(value):
                self.add_message("invalid-getnewargs-ex-returned", node=node)
                return