diff options
-rw-r--r-- | astroid/inference.py | 14 | ||||
-rw-r--r-- | astroid/node_classes.py | 4 | ||||
-rw-r--r-- | astroid/protocols.py | 17 | ||||
-rw-r--r-- | astroid/tests/unittest_protocols.py | 54 |
4 files changed, 86 insertions, 3 deletions
diff --git a/astroid/inference.py b/astroid/inference.py index 1ef5849f..1ea45248 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -71,6 +71,11 @@ def _infer_sequence_helper(node, context=None): if not hasattr(starred, "elts"): raise exceptions.InferenceError(node=node, context=context) values.extend(_infer_sequence_helper(starred)) + elif isinstance(elt, nodes.NamedExpr): + value = helpers.safe_infer(elt.value, context) + if not value: + raise exceptions.InferenceError(node=node, context=context) + values.append(value) else: values.append(elt) return values @@ -78,9 +83,10 @@ def _infer_sequence_helper(node, context=None): @decorators.raise_if_nothing_inferred def infer_sequence(self, context=None): - if not any(isinstance(e, nodes.Starred) for e in self.elts): - yield self - else: + has_starred_named_expr = any( + isinstance(e, (nodes.Starred, nodes.NamedExpr)) for e in self.elts + ) + if has_starred_named_expr: values = _infer_sequence_helper(self, context) new_seq = type(self)( lineno=self.lineno, col_offset=self.col_offset, parent=self.parent @@ -88,6 +94,8 @@ def infer_sequence(self, context=None): new_seq.postinit(values) yield new_seq + else: + yield self nodes.List._infer = infer_sequence diff --git a/astroid/node_classes.py b/astroid/node_classes.py index d3d7634d..44fe3df4 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1201,6 +1201,10 @@ class LookupMixIn: _stmt_parents = [stmt.parent] continue + if isinstance(assign_type, NamedExpr): + _stmts = [node] + continue + # XXX comment various branches below!!! try: pindex = _stmt_parents.index(stmt.parent) diff --git a/astroid/protocols.py b/astroid/protocols.py index bf497ff1..b598ec3d 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -572,6 +572,23 @@ def with_assigned_stmts(self, node=None, context=None, assign_path=None): nodes.With.assigned_stmts = with_assigned_stmts +@decorators.raise_if_nothing_inferred +def named_expr_assigned_stmts(self, node, context=None, assign_path=None): + """Infer names and other nodes from an assignment expression""" + if self.target == node: + yield from self.value.infer(context=context) + else: + raise exceptions.InferenceError( + "Cannot infer NamedExpr node {node!r}", + node=self, + assign_path=assign_path, + context=context, + ) + + +nodes.NamedExpr.assigned_stmts = named_expr_assigned_stmts + + @decorators.yes_if_nothing_inferred def starred_assigned_stmts(self, node=None, context=None, assign_path=None): """ diff --git a/astroid/tests/unittest_protocols.py b/astroid/tests/unittest_protocols.py index 79648dcf..e05c43cc 100644 --- a/astroid/tests/unittest_protocols.py +++ b/astroid/tests/unittest_protocols.py @@ -213,5 +213,59 @@ class ProtocolTests(unittest.TestCase): parsed.accept(Visitor()) +def test_named_expr_inference(): + code = """ + if (a := 2) == 2: + a #@ + + + # Test a function call + def test(): + return 24 + + if (a := test()): + a #@ + + # Normal assignments in sequences + { (a:= 4) } #@ + [ (a:= 5) ] #@ + + # Something more complicated + def test(value=(p := 24)): return p + [ y:= test()] #@ + + # Priority assignment + (x := 1, 2) + x #@ + """ + ast_nodes = extract_node(code) + node = next(ast_nodes[0].infer()) + assert isinstance(node, nodes.Const) + assert node.value == 2 + + node = next(ast_nodes[1].infer()) + assert isinstance(node, nodes.Const) + assert node.value == 24 + + node = next(ast_nodes[2].infer()) + assert isinstance(node, nodes.Set) + assert isinstance(node.elts[0], nodes.Const) + assert node.elts[0].value == 4 + + node = next(ast_nodes[3].infer()) + assert isinstance(node, nodes.List) + assert isinstance(node.elts[0], nodes.Const) + assert node.elts[0].value == 5 + + node = next(ast_nodes[4].infer()) + assert isinstance(node, nodes.List) + assert isinstance(node.elts[0], nodes.Const) + assert node.elts[0].value == 24 + + node = next(ast_nodes[4].infer()) + assert isinstance(node, nodes.Const) + assert node.value == 1 + + if __name__ == "__main__": unittest.main() |