summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--astroid/inference.py14
-rw-r--r--astroid/node_classes.py4
-rw-r--r--astroid/protocols.py17
-rw-r--r--astroid/tests/unittest_protocols.py54
4 files changed, 86 insertions, 3 deletions
diff --git a/astroid/inference.py b/astroid/inference.py
index 1ef5849f..1ea45248 100644
--- a/astroid/inference.py
+++ b/astroid/inference.py
@@ -71,6 +71,11 @@ def _infer_sequence_helper(node, context=None):
if not hasattr(starred, "elts"):
raise exceptions.InferenceError(node=node, context=context)
values.extend(_infer_sequence_helper(starred))
+ elif isinstance(elt, nodes.NamedExpr):
+ value = helpers.safe_infer(elt.value, context)
+ if not value:
+ raise exceptions.InferenceError(node=node, context=context)
+ values.append(value)
else:
values.append(elt)
return values
@@ -78,9 +83,10 @@ def _infer_sequence_helper(node, context=None):
@decorators.raise_if_nothing_inferred
def infer_sequence(self, context=None):
- if not any(isinstance(e, nodes.Starred) for e in self.elts):
- yield self
- else:
+ has_starred_named_expr = any(
+ isinstance(e, (nodes.Starred, nodes.NamedExpr)) for e in self.elts
+ )
+ if has_starred_named_expr:
values = _infer_sequence_helper(self, context)
new_seq = type(self)(
lineno=self.lineno, col_offset=self.col_offset, parent=self.parent
@@ -88,6 +94,8 @@ def infer_sequence(self, context=None):
new_seq.postinit(values)
yield new_seq
+ else:
+ yield self
nodes.List._infer = infer_sequence
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index d3d7634d..44fe3df4 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -1201,6 +1201,10 @@ class LookupMixIn:
_stmt_parents = [stmt.parent]
continue
+ if isinstance(assign_type, NamedExpr):
+ _stmts = [node]
+ continue
+
# XXX comment various branches below!!!
try:
pindex = _stmt_parents.index(stmt.parent)
diff --git a/astroid/protocols.py b/astroid/protocols.py
index bf497ff1..b598ec3d 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -572,6 +572,23 @@ def with_assigned_stmts(self, node=None, context=None, assign_path=None):
nodes.With.assigned_stmts = with_assigned_stmts
+@decorators.raise_if_nothing_inferred
+def named_expr_assigned_stmts(self, node, context=None, assign_path=None):
+ """Infer names and other nodes from an assignment expression"""
+ if self.target == node:
+ yield from self.value.infer(context=context)
+ else:
+ raise exceptions.InferenceError(
+ "Cannot infer NamedExpr node {node!r}",
+ node=self,
+ assign_path=assign_path,
+ context=context,
+ )
+
+
+nodes.NamedExpr.assigned_stmts = named_expr_assigned_stmts
+
+
@decorators.yes_if_nothing_inferred
def starred_assigned_stmts(self, node=None, context=None, assign_path=None):
"""
diff --git a/astroid/tests/unittest_protocols.py b/astroid/tests/unittest_protocols.py
index 79648dcf..e05c43cc 100644
--- a/astroid/tests/unittest_protocols.py
+++ b/astroid/tests/unittest_protocols.py
@@ -213,5 +213,59 @@ class ProtocolTests(unittest.TestCase):
parsed.accept(Visitor())
+def test_named_expr_inference():
+ code = """
+ if (a := 2) == 2:
+ a #@
+
+
+ # Test a function call
+ def test():
+ return 24
+
+ if (a := test()):
+ a #@
+
+ # Normal assignments in sequences
+ { (a:= 4) } #@
+ [ (a:= 5) ] #@
+
+ # Something more complicated
+ def test(value=(p := 24)): return p
+ [ y:= test()] #@
+
+ # Priority assignment
+ (x := 1, 2)
+ x #@
+ """
+ ast_nodes = extract_node(code)
+ node = next(ast_nodes[0].infer())
+ assert isinstance(node, nodes.Const)
+ assert node.value == 2
+
+ node = next(ast_nodes[1].infer())
+ assert isinstance(node, nodes.Const)
+ assert node.value == 24
+
+ node = next(ast_nodes[2].infer())
+ assert isinstance(node, nodes.Set)
+ assert isinstance(node.elts[0], nodes.Const)
+ assert node.elts[0].value == 4
+
+ node = next(ast_nodes[3].infer())
+ assert isinstance(node, nodes.List)
+ assert isinstance(node.elts[0], nodes.Const)
+ assert node.elts[0].value == 5
+
+ node = next(ast_nodes[4].infer())
+ assert isinstance(node, nodes.List)
+ assert isinstance(node.elts[0], nodes.Const)
+ assert node.elts[0].value == 24
+
+ node = next(ast_nodes[4].infer())
+ assert isinstance(node, nodes.Const)
+ assert node.value == 1
+
+
if __name__ == "__main__":
unittest.main()