summaryrefslogtreecommitdiff
path: root/pylint/checkers/modified_iterating_checker.py
blob: bdc8fff7f9c7a49c63bf37e84fd60020e538ffdf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt

from __future__ import annotations

from typing import TYPE_CHECKING

from astroid import nodes

from pylint import checkers, interfaces
from pylint.checkers import utils

if TYPE_CHECKING:
    from pylint.lint import PyLinter


_LIST_MODIFIER_METHODS = {"append", "remove"}
_SET_MODIFIER_METHODS = {"add", "remove"}


class ModifiedIterationChecker(checkers.BaseChecker):
    """Checks for modified iterators in for loops iterations.

    Currently supports `for` loops for Sets, Dictionaries and Lists.

    For every ``for`` loop, each statement of the loop body and all of its
    descendants are compared against the loop's iterable. A message is emitted
    when one of them:

    * calls a method listed in ``_LIST_MODIFIER_METHODS`` /
      ``_SET_MODIFIER_METHODS`` on the iterated list / set,
    * assigns to a subscript of the iterated dict (except when the subscript
      is the loop variable itself), or
    * deletes a name bound by the loop target while the iterable is inferred
      to be a list, dict or set.
    """

    name = "modified_iteration"

    msgs = {
        "W4701": (
            "Iterated list '%s' is being modified inside for loop body, consider iterating through a copy of it "
            "instead.",
            "modified-iterating-list",
            "Emitted when items are added or removed to a list being iterated through. "
            "Doing so can result in unexpected behaviour, that is why it is preferred to use a copy of the list.",
        ),
        "E4702": (
            "Iterated dict '%s' is being modified inside for loop body, iterate through a copy of it instead.",
            "modified-iterating-dict",
            "Emitted when items are added or removed to a dict being iterated through. "
            "Doing so raises a RuntimeError.",
        ),
        "E4703": (
            "Iterated set '%s' is being modified inside for loop body, iterate through a copy of it instead.",
            "modified-iterating-set",
            "Emitted when items are added or removed to a set being iterated through. "
            "Doing so raises a RuntimeError.",
        ),
    }

    options = ()

    @utils.only_required_for_messages(
        "modified-iterating-list", "modified-iterating-dict", "modified-iterating-set"
    )
    def visit_for(self, node: nodes.For) -> None:
        """Check every statement of the loop body against the loop's iterable.

        Only ``node.body`` is inspected; the ``else`` clause of the loop is
        not visited by this method.
        """
        iter_obj = node.iter
        for body_node in node.body:
            self._modified_iterating_check_on_node_and_children(body_node, iter_obj)

    def _modified_iterating_check_on_node_and_children(
        self, body_node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> None:
        """See if node or any of its children raises modified iterating messages."""
        self._modified_iterating_check(body_node, iter_obj)
        for child in body_node.get_children():
            self._modified_iterating_check_on_node_and_children(child, iter_obj)

    def _modified_iterating_check(
        self, node: nodes.NodeNG, iter_obj: nodes.NodeNG
    ) -> None:
        """Emit the relevant message if ``node`` modifies ``iter_obj``.

        :param node: A node found inside the body of the ``for`` loop.
        :param iter_obj: The node being iterated over by the ``for`` loop.
        """
        msg_id = None
        if isinstance(node, nodes.Delete) and any(
            self._deleted_iteration_target_cond(t, iter_obj) for t in node.targets
        ):
            # This branch accepts any kind of iterable node (a literal, a
            # call, ...), so the message kind is picked from the inferred value.
            inferred = utils.safe_infer(iter_obj)
            if isinstance(inferred, nodes.List):
                msg_id = "modified-iterating-list"
            elif isinstance(inferred, nodes.Dict):
                msg_id = "modified-iterating-dict"
            elif isinstance(inferred, nodes.Set):
                msg_id = "modified-iterating-set"
        elif not isinstance(iter_obj, (nodes.Name, nodes.Attribute)):
            # The remaining conditions compare names, which only makes sense
            # when the iterable is itself a name or an attribute access.
            pass
        elif self._modified_iterating_list_cond(node, iter_obj):
            msg_id = "modified-iterating-list"
        elif self._modified_iterating_dict_cond(node, iter_obj):
            msg_id = "modified-iterating-dict"
        elif self._modified_iterating_set_cond(node, iter_obj):
            msg_id = "modified-iterating-set"
        if msg_id:
            self.add_message(
                msg_id,
                node=node,
                # The helper copes with iterables that have no ``name``
                # attribute (reachable through the ``Delete`` branch above)
                # instead of raising AttributeError.
                args=(self._get_iter_obj_name(iter_obj),),
                confidence=interfaces.INFERENCE,
            )

    @staticmethod
    def _get_iter_obj_name(iter_obj: nodes.NodeNG) -> str:
        """Return the text used to designate ``iter_obj``.

        :returns: ``attrname`` for an attribute access, ``name`` for a plain
            name, and the source representation (``as_string()``) for any
            other node, which has neither of those attributes.
        """
        if isinstance(iter_obj, nodes.Attribute):
            return iter_obj.attrname  # type: ignore[no-any-return]
        if isinstance(iter_obj, nodes.Name):
            return iter_obj.name  # type: ignore[no-any-return]
        return iter_obj.as_string()  # type: ignore[no-any-return]

    @staticmethod
    def _is_node_expr_that_calls_attribute_name(node: nodes.NodeNG) -> bool:
        """Tell whether ``node`` is an expression statement like ``name.method(...)``."""
        return (
            isinstance(node, nodes.Expr)
            and isinstance(node.value, nodes.Call)
            and isinstance(node.value.func, nodes.Attribute)
            and isinstance(node.value.func.expr, nodes.Name)
        )

    @staticmethod
    def _common_cond_list_set(
        node: nodes.Expr,
        iter_obj: nodes.Name | nodes.Attribute,
        infer_val: nodes.List | nodes.Set,
    ) -> bool:
        """Tell whether the method call in ``node`` targets the iterated object.

        :param node: An expression statement already validated by
            ``_is_node_expr_that_calls_attribute_name``.
        :param iter_obj: The node being iterated over by the ``for`` loop.
        :param infer_val: The value inferred for the object the method is
            called on.
        :returns: True when ``infer_val`` equals the value inferred for
            ``iter_obj`` and the object the method is called on has the same
            name as ``iter_obj``.
        """
        iter_obj_name = ModifiedIterationChecker._get_iter_obj_name(iter_obj)
        return (infer_val == utils.safe_infer(iter_obj)) and (  # type: ignore[no-any-return]
            node.value.func.expr.name == iter_obj_name
        )

    @staticmethod
    def _is_node_assigns_subscript_name(node: nodes.NodeNG) -> bool:
        """Tell whether ``node`` is an assignment like ``name[...] = value``.

        Only the first target of the assignment is inspected.
        """
        return isinstance(node, nodes.Assign) and (
            isinstance(node.targets[0], nodes.Subscript)
            and (isinstance(node.targets[0].value, nodes.Name))
        )

    def _modified_iterating_list_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.Name | nodes.Attribute
    ) -> bool:
        """Tell whether ``node`` calls a list modifier method on the iterated list."""
        if not self._is_node_expr_that_calls_attribute_name(node):
            return False
        infer_val = utils.safe_infer(node.value.func.expr)
        if not isinstance(infer_val, nodes.List):
            return False
        return (
            self._common_cond_list_set(node, iter_obj, infer_val)
            and node.value.func.attrname in _LIST_MODIFIER_METHODS
        )

    def _modified_iterating_dict_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.Name | nodes.Attribute
    ) -> bool:
        """Tell whether ``node`` assigns to a subscript of the iterated dict."""
        if not self._is_node_assigns_subscript_name(node):
            return False
        # Do not emit when merely updating the same key being iterated
        if (
            isinstance(iter_obj, nodes.Name)
            and iter_obj.name == node.targets[0].value.name
            and isinstance(iter_obj.parent.target, nodes.AssignName)
            and isinstance(node.targets[0].slice, nodes.Name)
            and iter_obj.parent.target.name == node.targets[0].slice.name
        ):
            return False
        infer_val = utils.safe_infer(node.targets[0].value)
        if not isinstance(infer_val, nodes.Dict):
            return False
        if infer_val != utils.safe_infer(iter_obj):
            return False
        iter_obj_name = self._get_iter_obj_name(iter_obj)
        return node.targets[0].value.name == iter_obj_name  # type: ignore[no-any-return]

    def _modified_iterating_set_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.Name | nodes.Attribute
    ) -> bool:
        """Tell whether ``node`` calls a set modifier method on the iterated set."""
        if not self._is_node_expr_that_calls_attribute_name(node):
            return False
        infer_val = utils.safe_infer(node.value.func.expr)
        if not isinstance(infer_val, nodes.Set):
            return False
        return (
            self._common_cond_list_set(node, iter_obj, infer_val)
            and node.value.func.attrname in _SET_MODIFIER_METHODS
        )

    def _deleted_iteration_target_cond(
        self, node: nodes.DelName, iter_obj: nodes.NodeNG
    ) -> bool:
        """Tell whether ``node`` deletes a name bound by the loop target.

        :param node: One target of a ``del`` statement.
        :param iter_obj: The node being iterated over by the ``for`` loop.
        :returns: True when ``node`` is a ``DelName``, the parent of
            ``iter_obj`` is a ``For`` whose target is an ``AssignName`` or a
            ``BaseContainer``, and the deleted name is one of the names
            assigned by that target.
        """
        if not isinstance(node, nodes.DelName):
            return False
        if not isinstance(iter_obj.parent, nodes.For):
            return False
        if not isinstance(
            iter_obj.parent.target, (nodes.AssignName, nodes.BaseContainer)
        ):
            return False
        return any(
            t == node.name
            for t in utils.find_assigned_names_recursive(iter_obj.parent.target)
        )


def register(linter: PyLinter) -> None:
    """Create a ``ModifiedIterationChecker`` for ``linter`` and register it."""
    checker = ModifiedIterationChecker(linter)
    linter.register_checker(checker)