1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
from typing import TYPE_CHECKING, Union
from astroid import nodes
from pylint import checkers, interfaces
from pylint.checkers import utils
if TYPE_CHECKING:
from pylint.lint import PyLinter
_LIST_MODIFIER_METHODS = {"append", "remove"}
_SET_MODIFIER_METHODS = {"add", "remove"}
class ModifiedIterationChecker(checkers.BaseChecker):
    """Checks for modified iterators in for loops iterations.

    Currently supports `for` loops for Sets, Dictionaries and Lists.
    """

    __implements__ = interfaces.IAstroidChecker

    name = "modified_iteration"

    msgs = {
        "W4701": (
            "Iterated list '%s' is being modified inside for loop body, consider iterating through a copy of it "
            "instead.",
            "modified-iterating-list",
            "Emitted when items are added or removed to a list being iterated through. "
            "Doing so can result in unexpected behaviour, that is why it is preferred to use a copy of the list.",
        ),
        "E4702": (
            "Iterated dict '%s' is being modified inside for loop body, iterate through a copy of it instead.",
            "modified-iterating-dict",
            "Emitted when items are added or removed to a dict being iterated through. "
            "Doing so raises a RuntimeError.",
        ),
        "E4703": (
            "Iterated set '%s' is being modified inside for loop body, iterate through a copy of it instead.",
            "modified-iterating-set",
            "Emitted when items are added or removed to a set being iterated through. "
            "Doing so raises a RuntimeError.",
        ),
    }

    options = ()
    priority = -2

    @utils.check_messages(
        "modified-iterating-list", "modified-iterating-dict", "modified-iterating-set"
    )
    def visit_for(self, node: nodes.For) -> None:
        """Report statements in the body of *node* that modify the iterated object.

        Only loops of the form ``for x in <name>:`` are checked.
        """
        iter_obj = node.iter
        # The conditions below compare the modified object's name against
        # ``iter_obj.name``, which only ``nodes.Name`` provides. Other iterables
        # (``obj.attr``, ``func()``, ``seq[0]``, ...) raised AttributeError here
        # as soon as inference matched the modified object, so skip them.
        if not isinstance(iter_obj, nodes.Name):
            return
        for body_node in node.body:
            self._modified_iterating_check_on_node_and_children(body_node, iter_obj)

    def _modified_iterating_check_on_node_and_children(
        self, body_node: nodes.NodeNG, iter_obj: nodes.Name
    ) -> None:
        """See if node or any of its children raises modified iterating messages."""
        self._modified_iterating_check(body_node, iter_obj)
        for child in body_node.get_children():
            self._modified_iterating_check_on_node_and_children(child, iter_obj)

    def _modified_iterating_check(
        self, node: nodes.NodeNG, iter_obj: nodes.Name
    ) -> None:
        """Emit at most one modified-iterating-* message for *node*.

        The list, dict and set conditions are tried in that order; the first
        one that matches decides the message id.
        """
        msg_id = None
        if self._modified_iterating_list_cond(node, iter_obj):
            msg_id = "modified-iterating-list"
        elif self._modified_iterating_dict_cond(node, iter_obj):
            msg_id = "modified-iterating-dict"
        elif self._modified_iterating_set_cond(node, iter_obj):
            msg_id = "modified-iterating-set"
        if msg_id:
            self.add_message(
                msg_id,
                node=node,
                args=(iter_obj.name,),
                confidence=interfaces.INFERENCE,
            )

    @staticmethod
    def _is_node_expr_that_calls_attribute_name(node: nodes.NodeNG) -> bool:
        """Return True if *node* is an expression statement ``<name>.<attr>(...)``."""
        return (
            isinstance(node, nodes.Expr)
            and isinstance(node.value, nodes.Call)
            and isinstance(node.value.func, nodes.Attribute)
            and isinstance(node.value.func.expr, nodes.Name)
        )

    @staticmethod
    def _common_cond_list_set(
        node: nodes.Expr,
        iter_obj: nodes.Name,
        infer_val: Union[nodes.List, nodes.Set],
    ) -> bool:
        """Return True if the call in *node* targets the loop's iterated object.

        Both the called object's name must equal the loop's iterated name, and
        *infer_val* (the called object's inferred value) must be the same
        inferred node as the iterated object.
        """
        # Cheap name comparison first; inference only if the names agree.
        return node.value.func.expr.name == iter_obj.name and (
            infer_val == utils.safe_infer(iter_obj)
        )

    @staticmethod
    def _is_node_assigns_subscript_name(node: nodes.NodeNG) -> bool:
        """Return True if *node* is ``<name>[...] = ...``.

        Only the first assignment target is inspected.
        """
        return isinstance(node, nodes.Assign) and (
            isinstance(node.targets[0], nodes.Subscript)
            and (isinstance(node.targets[0].value, nodes.Name))
        )

    def _modified_iterating_list_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.Name
    ) -> bool:
        """Return True if *node* calls a list modifier method on the iterated list."""
        if not self._is_node_expr_that_calls_attribute_name(node):
            return False
        # Syntactic check before paying for inference.
        if node.value.func.attrname not in _LIST_MODIFIER_METHODS:
            return False
        infer_val = utils.safe_infer(node.value.func.expr)
        if not isinstance(infer_val, nodes.List):
            return False
        return self._common_cond_list_set(node, iter_obj, infer_val)

    def _modified_iterating_dict_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.Name
    ) -> bool:
        """Return True if *node* assigns a key of the iterated dict (``d[k] = v``)."""
        if not self._is_node_assigns_subscript_name(node):
            return False
        subscripted = node.targets[0].value
        # Syntactic check before paying for inference.
        if subscripted.name != iter_obj.name:
            return False
        infer_val = utils.safe_infer(subscripted)
        if not isinstance(infer_val, nodes.Dict):
            return False
        return infer_val == utils.safe_infer(iter_obj)

    def _modified_iterating_set_cond(
        self, node: nodes.NodeNG, iter_obj: nodes.Name
    ) -> bool:
        """Return True if *node* calls a set modifier method on the iterated set."""
        if not self._is_node_expr_that_calls_attribute_name(node):
            return False
        # Syntactic check before paying for inference.
        if node.value.func.attrname not in _SET_MODIFIER_METHODS:
            return False
        infer_val = utils.safe_infer(node.value.func.expr)
        if not isinstance(infer_val, nodes.Set):
            return False
        return self._common_cond_list_set(node, iter_obj, infer_val)
def register(linter: "PyLinter") -> None:
    """Create a ``ModifiedIterationChecker`` bound to *linter* and register it."""
    checker = ModifiedIterationChecker(linter)
    linter.register_checker(checker)
|