summary | refs | log | tree | commit | diff
path: root/pylint/checkers/modified_iterating_checker.py
blob: 711d37bb0e0a8408c08c8eebdf5276ab5199eb40 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE

from typing import TYPE_CHECKING, Union

from astroid import nodes

from pylint import checkers, interfaces
from pylint.checkers import utils

if TYPE_CHECKING:
    from pylint.lint import PyLinter


# Names of ``list`` methods treated as modifying the list; matched against the
# called attribute name in ``_modified_iterating_list_cond``.
_LIST_MODIFIER_METHODS = {"append", "remove"}
# Names of ``set`` methods treated as modifying the set; matched against the
# called attribute name in ``_modified_iterating_set_cond``.
_SET_MODIFIER_METHODS = {"add", "remove"}


class ModifiedIterationChecker(checkers.BaseChecker):
    """Flag ``for`` loops whose body modifies the collection being iterated.

    Only loops that iterate directly over a plain name (``for x in items:``)
    are analysed, because every check matches the modified object against the
    iterable by variable name. Three patterns are recognised, at any depth
    inside the loop body:

    * ``items.append(...)`` / ``items.remove(...)`` on an inferred list,
    * ``items[key] = value`` on an inferred dict,
    * ``items.add(...)`` / ``items.remove(...)`` on an inferred set.
    """

    __implements__ = interfaces.IAstroidChecker

    name = "modified_iteration"

    msgs = {
        "W4701": (
            "Iterated list '%s' is being modified inside for loop body, consider iterating through a copy of it "
            "instead.",
            "modified-iterating-list",
            "Emitted when items are added or removed to a list being iterated through. "
            "Doing so can result in unexpected behaviour, that is why it is preferred to use a copy of the list.",
        ),
        "E4702": (
            "Iterated dict '%s' is being modified inside for loop body, iterate through a copy of it instead.",
            "modified-iterating-dict",
            "Emitted when items are added or removed to a dict being iterated through. "
            "Doing so raises a RuntimeError.",
        ),
        "E4703": (
            "Iterated set '%s' is being modified inside for loop body, iterate through a copy of it instead.",
            "modified-iterating-set",
            "Emitted when items are added or removed to a set being iterated through. "
            "Doing so raises a RuntimeError.",
        ),
    }

    options = ()
    priority = -2

    @utils.check_messages(
        "modified-iterating-list", "modified-iterating-dict", "modified-iterating-set"
    )
    def visit_for(self, node: nodes.For) -> None:
        """Check every statement in the body of ``node`` against its iterable."""
        iter_obj = node.iter
        # Every check compares variable names, so an iterable that is not a
        # plain name (attribute, call, subscript, literal, ...) can never
        # match and is not guaranteed to expose ``.name``. Skip the traversal
        # instead of risking an AttributeError further down.
        if not isinstance(iter_obj, nodes.Name):
            return
        for body_node in node.body:
            self._modified_iterating_check_on_node_and_children(body_node, iter_obj)

    def _modified_iterating_check_on_node_and_children(
        self, body_node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> None:
        """See if node or any of its children raises modified iterating messages."""
        self._modified_iterating_check(body_node, iter_obj)
        for child in body_node.get_children():
            self._modified_iterating_check_on_node_and_children(child, iter_obj)

    def _modified_iterating_check(
        self, node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> None:
        """Emit the matching message if ``node`` modifies ``iter_obj``.

        The list, dict and set conditions are tried in that order and at most
        one message is emitted for ``node``.
        """
        msg_id = None
        if self._modified_iterating_list_cond(node, iter_obj):
            msg_id = "modified-iterating-list"
        elif self._modified_iterating_dict_cond(node, iter_obj):
            msg_id = "modified-iterating-dict"
        elif self._modified_iterating_set_cond(node, iter_obj):
            msg_id = "modified-iterating-set"
        if msg_id:
            # A condition only succeeds when ``iter_obj`` is a ``nodes.Name``,
            # so reading ``.name`` here is safe.
            self.add_message(
                msg_id,
                node=node,
                args=(iter_obj.name,),
                confidence=interfaces.INFERENCE,
            )

    @staticmethod
    def _is_node_expr_that_calls_attribute_name(node: nodes.NodeNG) -> bool:
        """Return True for an expression statement shaped like ``name.attr(...)``."""
        return (
            isinstance(node, nodes.Expr)
            and isinstance(node.value, nodes.Call)
            and isinstance(node.value.func, nodes.Attribute)
            and isinstance(node.value.func.expr, nodes.Name)
        )

    @staticmethod
    def _common_cond_list_set(
        node: nodes.Expr,
        iter_obj: nodes.NodeNG,
        infer_val: Union[nodes.List, nodes.Set],
    ) -> bool:
        """Return True if the call in ``node`` targets the iterated object.

        ``node`` must already satisfy
        ``_is_node_expr_that_calls_attribute_name``. The receiver of the call
        has to infer to the same value as ``iter_obj`` and carry the same
        variable name.
        """
        return (
            # Guard first: only ``nodes.Name`` is known to have ``.name``.
            isinstance(iter_obj, nodes.Name)
            and infer_val == utils.safe_infer(iter_obj)
            and node.value.func.expr.name == iter_obj.name
        )

    @staticmethod
    def _is_node_assigns_subscript_name(node: nodes.NodeNG) -> bool:
        """Return True for an assignment whose first target is ``name[...]``."""
        return isinstance(node, nodes.Assign) and (
            isinstance(node.targets[0], nodes.Subscript)
            and (isinstance(node.targets[0].value, nodes.Name))
        )

    def _modified_iterating_list_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> bool:
        """Return True if ``node`` calls a list-modifying method on ``iter_obj``."""
        if not self._is_node_expr_that_calls_attribute_name(node):
            return False
        infer_val = utils.safe_infer(node.value.func.expr)
        if not isinstance(infer_val, nodes.List):
            return False
        return (
            self._common_cond_list_set(node, iter_obj, infer_val)
            and node.value.func.attrname in _LIST_MODIFIER_METHODS
        )

    def _modified_iterating_dict_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> bool:
        """Return True if ``node`` assigns to a key of the iterated dict."""
        if not self._is_node_assigns_subscript_name(node):
            return False
        # Only a plain name can be compared by ``.name`` below.
        if not isinstance(iter_obj, nodes.Name):
            return False
        infer_val = utils.safe_infer(node.targets[0].value)
        if not isinstance(infer_val, nodes.Dict):
            return False
        if infer_val != utils.safe_infer(iter_obj):
            return False
        return node.targets[0].value.name == iter_obj.name

    def _modified_iterating_set_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> bool:
        """Return True if ``node`` calls a set-modifying method on ``iter_obj``."""
        if not self._is_node_expr_that_calls_attribute_name(node):
            return False
        infer_val = utils.safe_infer(node.value.func.expr)
        if not isinstance(infer_val, nodes.Set):
            return False
        return (
            self._common_cond_list_set(node, iter_obj, infer_val)
            and node.value.func.attrname in _SET_MODIFIER_METHODS
        )


def register(linter: "PyLinter") -> None:
    """Create a ``ModifiedIterationChecker`` for ``linter`` and register it."""
    checker = ModifiedIterationChecker(linter)
    linter.register_checker(checker)