diff options
author | Mark Byrne <31762852+mbyrnepr2@users.noreply.github.com> | 2022-05-17 00:39:46 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-16 18:39:46 -0400 |
commit | 4e28c0ff50c77247d2e39e1ed2b625368cd3c1b2 (patch) | |
tree | c67213d2c72decbb29c21240027f8e30def139bf /tests/unittest_protocols.py | |
parent | 3ff65be6c9e3a4480cd9338a68f418cb578e347b (diff) | |
download | astroid-git-4e28c0ff50c77247d2e39e1ed2b625368cd3c1b2.tar.gz |
Fix inference from nested unpacking of iterables (#1311)
Diffstat (limited to 'tests/unittest_protocols.py')
-rw-r--r-- | tests/unittest_protocols.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tests/unittest_protocols.py b/tests/unittest_protocols.py index 69e50f11..9a477d02 100644 --- a/tests/unittest_protocols.py +++ b/tests/unittest_protocols.py @@ -70,6 +70,48 @@ class ProtocolTests(unittest.TestCase): for2_assnode = next(assign_stmts[1].nodes_of_class(nodes.AssignName)) self.assertRaises(InferenceError, list, for2_assnode.assigned_stmts()) + def test_assigned_stmts_nested_for_tuple(self) -> None: + assign_stmts = extract_node( + """ + for a, (b, c) in [(1, (2, 3))]: #@ + pass + """ + ) + + assign_nodes = assign_stmts.nodes_of_class(nodes.AssignName) + + for1_assnode = next(assign_nodes) + assigned = list(for1_assnode.assigned_stmts()) + self.assertConstNodesEqual([1], assigned) + + for2_assnode = next(assign_nodes) + assigned2 = list(for2_assnode.assigned_stmts()) + self.assertConstNodesEqual([2], assigned2) + + def test_assigned_stmts_nested_for_dict(self) -> None: + assign_stmts = extract_node( + """ + for a, (b, c) in {1: ("a", str), 2: ("b", bytes)}.items(): #@ + pass + """ + ) + assign_nodes = assign_stmts.nodes_of_class(nodes.AssignName) + + # assigned: [1, 2] + for1_assnode = next(assign_nodes) + assigned = list(for1_assnode.assigned_stmts()) + self.assertConstNodesEqual([1, 2], assigned) + + # assigned2: ["a", "b"] + for2_assnode = next(assign_nodes) + assigned2 = list(for2_assnode.assigned_stmts()) + self.assertConstNodesEqual(["a", "b"], assigned2) + + # assigned3: [str, bytes] + for3_assnode = next(assign_nodes) + assigned3 = list(for3_assnode.assigned_stmts()) + self.assertNameNodesEqual(["str", "bytes"], assigned3) + def test_assigned_stmts_starred_for(self) -> None: assign_stmts = extract_node( """ |