summaryrefslogtreecommitdiff
path: root/tests/unittest_protocols.py
diff options
context:
space:
mode:
authorMark Byrne <31762852+mbyrnepr2@users.noreply.github.com>2022-05-17 00:39:46 +0200
committerGitHub <noreply@github.com>2022-05-16 18:39:46 -0400
commit4e28c0ff50c77247d2e39e1ed2b625368cd3c1b2 (patch)
treec67213d2c72decbb29c21240027f8e30def139bf /tests/unittest_protocols.py
parent3ff65be6c9e3a4480cd9338a68f418cb578e347b (diff)
downloadastroid-git-4e28c0ff50c77247d2e39e1ed2b625368cd3c1b2.tar.gz
Fix inference from nested unpacking of iterables (#1311)
Diffstat (limited to 'tests/unittest_protocols.py')
-rw-r--r--tests/unittest_protocols.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tests/unittest_protocols.py b/tests/unittest_protocols.py
index 69e50f11..9a477d02 100644
--- a/tests/unittest_protocols.py
+++ b/tests/unittest_protocols.py
@@ -70,6 +70,48 @@ class ProtocolTests(unittest.TestCase):
for2_assnode = next(assign_stmts[1].nodes_of_class(nodes.AssignName))
self.assertRaises(InferenceError, list, for2_assnode.assigned_stmts())
+ def test_assigned_stmts_nested_for_tuple(self) -> None:
+ assign_stmts = extract_node(
+ """
+ for a, (b, c) in [(1, (2, 3))]: #@
+ pass
+ """
+ )
+
+ assign_nodes = assign_stmts.nodes_of_class(nodes.AssignName)
+
+ for1_assnode = next(assign_nodes)
+ assigned = list(for1_assnode.assigned_stmts())
+ self.assertConstNodesEqual([1], assigned)
+
+ for2_assnode = next(assign_nodes)
+ assigned2 = list(for2_assnode.assigned_stmts())
+ self.assertConstNodesEqual([2], assigned2)
+
+ def test_assigned_stmts_nested_for_dict(self) -> None:
+ assign_stmts = extract_node(
+ """
+ for a, (b, c) in {1: ("a", str), 2: ("b", bytes)}.items(): #@
+ pass
+ """
+ )
+ assign_nodes = assign_stmts.nodes_of_class(nodes.AssignName)
+
+ # assigned: [1, 2]
+ for1_assnode = next(assign_nodes)
+ assigned = list(for1_assnode.assigned_stmts())
+ self.assertConstNodesEqual([1, 2], assigned)
+
+ # assigned2: ["a", "b"]
+ for2_assnode = next(assign_nodes)
+ assigned2 = list(for2_assnode.assigned_stmts())
+ self.assertConstNodesEqual(["a", "b"], assigned2)
+
+ # assigned3: [str, bytes]
+ for3_assnode = next(assign_nodes)
+ assigned3 = list(for3_assnode.assigned_stmts())
+ self.assertNameNodesEqual(["str", "bytes"], assigned3)
+
def test_assigned_stmts_starred_for(self) -> None:
assign_stmts = extract_node(
"""