summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--astroid/nodes/__init__.py1
-rw-r--r--astroid/nodes/node_classes.py101
-rw-r--r--astroid/rebuilder.py16
-rw-r--r--tests/test_group_exceptions.py111
4 files changed, 229 insertions, 0 deletions
diff --git a/astroid/nodes/__init__.py b/astroid/nodes/__init__.py
index 68ddad74..b527ff7c 100644
--- a/astroid/nodes/__init__.py
+++ b/astroid/nodes/__init__.py
@@ -84,6 +84,7 @@ from astroid.nodes.node_classes import ( # pylint: disable=redefined-builtin (E
Subscript,
TryExcept,
TryFinally,
+ TryStar,
Tuple,
UnaryOp,
Unknown,
diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py
index 3cec0891..b7772c3c 100644
--- a/astroid/nodes/node_classes.py
+++ b/astroid/nodes/node_classes.py
@@ -4216,6 +4216,107 @@ class TryFinally(_base_nodes.MultiLineWithElseBlockNode, _base_nodes.Statement):
yield from self.finalbody
+class TryStar(_base_nodes.MultiLineWithElseBlockNode, _base_nodes.Statement):
+ """Class representing an :class:`ast.TryStar` node."""
+
+ _astroid_fields = ("body", "handlers", "orelse", "finalbody")
+ _multi_line_block_fields = ("body", "handlers", "orelse", "finalbody")
+
+ def __init__(
+ self,
+ *,
+ lineno: int | None = None,
+ col_offset: int | None = None,
+ end_lineno: int | None = None,
+ end_col_offset: int | None = None,
+ parent: NodeNG | None = None,
+ ) -> None:
+ """
+ :param lineno: The line that this node appears on in the source code.
+ :param col_offset: The column that this node appears on in the
+ source code.
+ :param parent: The parent node in the syntax tree.
+ :param end_lineno: The last line this node appears on in the source code.
+ :param end_col_offset: The end column this node appears on in the
+ source code. Note: This is after the last symbol.
+ """
+ self.body: list[NodeNG] = []
+ """The contents of the block to catch exceptions from."""
+
+ self.handlers: list[ExceptHandler] = []
+ """The exception handlers."""
+
+ self.orelse: list[NodeNG] = []
+ """The contents of the ``else`` block."""
+
+ self.finalbody: list[NodeNG] = []
+ """The contents of the ``finally`` block."""
+
+ super().__init__(
+ lineno=lineno,
+ col_offset=col_offset,
+ end_lineno=end_lineno,
+ end_col_offset=end_col_offset,
+ parent=parent,
+ )
+
+ def postinit(
+ self,
+ *,
+ body: list[NodeNG] | None = None,
+ handlers: list[ExceptHandler] | None = None,
+ orelse: list[NodeNG] | None = None,
+ finalbody: list[NodeNG] | None = None,
+ ) -> None:
+ """Do some setup after initialisation.
+ :param body: The contents of the block to catch exceptions from.
+ :param handlers: The exception handlers.
+ :param orelse: The contents of the ``else`` block.
+ :param finalbody: The contents of the ``finally`` block.
+ """
+ if body:
+ self.body = body
+ if handlers:
+ self.handlers = handlers
+ if orelse:
+ self.orelse = orelse
+ if finalbody:
+ self.finalbody = finalbody
+
+ def _infer_name(self, frame, name):
+ return name
+
+ def block_range(self, lineno: int) -> tuple[int, int]:
+ """Get a range from a given line number to where this node ends."""
+ if lineno == self.fromlineno:
+ return lineno, lineno
+ if self.body and self.body[0].fromlineno <= lineno <= self.body[-1].tolineno:
+ # Inside try body - return from lineno till end of try body
+ return lineno, self.body[-1].tolineno
+ for exhandler in self.handlers:
+ if exhandler.type and lineno == exhandler.type.fromlineno:
+ return lineno, lineno
+ if exhandler.body[0].fromlineno <= lineno <= exhandler.body[-1].tolineno:
+ return lineno, exhandler.body[-1].tolineno
+ if self.orelse:
+ if self.orelse[0].fromlineno - 1 == lineno:
+ return lineno, lineno
+ if self.orelse[0].fromlineno <= lineno <= self.orelse[-1].tolineno:
+ return lineno, self.orelse[-1].tolineno
+ if self.finalbody:
+ if self.finalbody[0].fromlineno - 1 == lineno:
+ return lineno, lineno
+ if self.finalbody[0].fromlineno <= lineno <= self.finalbody[-1].tolineno:
+ return lineno, self.finalbody[-1].tolineno
+ return lineno, self.tolineno
+
+ def get_children(self):
+ yield from self.body
+ yield from self.handlers
+ yield from self.orelse
+ yield from self.finalbody
+
+
class Tuple(BaseContainer):
"""Class representing an :class:`ast.Tuple` node.
diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py
index 0407dbfb..6e996def 100644
--- a/astroid/rebuilder.py
+++ b/astroid/rebuilder.py
@@ -1822,6 +1822,22 @@ class TreeRebuilder:
return self.visit_tryexcept(node, parent)
return None
+ def visit_trystar(self, node: ast.TryStar, parent: NodeNG) -> nodes.TryStar:
+ newnode = nodes.TryStar(
+ lineno=node.lineno,
+ col_offset=node.col_offset,
+ end_lineno=getattr(node, "end_lineno", None),
+ end_col_offset=getattr(node, "end_col_offset", None),
+ parent=parent,
+ )
+ newnode.postinit(
+ body=[self.visit(n, newnode) for n in node.body],
+ handlers=[self.visit(n, newnode) for n in node.handlers],
+ orelse=[self.visit(n, newnode) for n in node.orelse],
+ finalbody=[self.visit(n, newnode) for n in node.finalbody],
+ )
+ return newnode
+
def visit_tuple(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple:
"""Visit a Tuple node by returning a fresh instance of it."""
context = self._get_context(node)
diff --git a/tests/test_group_exceptions.py b/tests/test_group_exceptions.py
new file mode 100644
index 00000000..173c25ed
--- /dev/null
+++ b/tests/test_group_exceptions.py
@@ -0,0 +1,111 @@
+# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
+# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
+# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
+import textwrap
+
+import pytest
+
+from astroid import (
+ AssignName,
+ ExceptHandler,
+ For,
+ Name,
+ TryExcept,
+ Uninferable,
+ bases,
+ extract_node,
+)
+from astroid.const import PY311_PLUS
+from astroid.context import InferenceContext
+from astroid.nodes import Expr, Raise, TryStar
+
+
+@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher")
+def test_group_exceptions() -> None:
+ node = extract_node(
+ textwrap.dedent(
+ """
+ try:
+ raise ExceptionGroup("group", [ValueError(654)])
+ except ExceptionGroup as eg:
+ for err in eg.exceptions:
+ if isinstance(err, ValueError):
+ print("Handling ValueError")
+ elif isinstance(err, TypeError):
+ print("Handling TypeError")"""
+ )
+ )
+ assert isinstance(node, TryExcept)
+ handler = node.handlers[0]
+ exception_group_block_range = (1, 4)
+ assert node.block_range(lineno=1) == exception_group_block_range
+ assert node.block_range(lineno=2) == (2, 2)
+ assert node.block_range(lineno=5) == (5, 9)
+ assert isinstance(handler, ExceptHandler)
+ assert handler.type.name == "ExceptionGroup"
+ children = list(handler.get_children())
+ assert len(children) == 3
+ exception_group, short_name, for_loop = children
+ assert isinstance(exception_group, Name)
+ assert exception_group.block_range(1) == exception_group_block_range
+ assert isinstance(short_name, AssignName)
+ assert isinstance(for_loop, For)
+
+
+@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher")
+def test_star_exceptions() -> None:
+ node = extract_node(
+ textwrap.dedent(
+ """
+ try:
+ raise ExceptionGroup("group", [ValueError(654)])
+ except* ValueError:
+ print("Handling ValueError")
+ except* TypeError:
+ print("Handling TypeError")
+ else:
+ sys.exit(127)
+ finally:
+ sys.exit(0)"""
+ )
+ )
+ assert isinstance(node, TryStar)
+ assert isinstance(node.body[0], Raise)
+ assert node.block_range(1) == (1, 11)
+ assert node.block_range(2) == (2, 2)
+ assert node.block_range(3) == (3, 3)
+ assert node.block_range(4) == (4, 4)
+ assert node.block_range(5) == (5, 5)
+ assert node.block_range(6) == (6, 6)
+ assert node.block_range(7) == (7, 7)
+ assert node.block_range(8) == (8, 8)
+ assert node.block_range(9) == (9, 9)
+ assert node.block_range(10) == (10, 10)
+ assert node.block_range(11) == (11, 11)
+ assert node.handlers
+ handler = node.handlers[0]
+ assert isinstance(handler, ExceptHandler)
+ assert handler.type.name == "ValueError"
+ orelse = node.orelse[0]
+ assert isinstance(orelse, Expr)
+ assert orelse.value.args[0].value == 127
+ final = node.finalbody[0]
+ assert isinstance(final, Expr)
+ assert final.value.args[0].value == 0
+
+
+@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher")
+def test_star_exceptions_infer_name() -> None:
+ trystar = extract_node(
+ """
+try:
+ 1/0
+except* ValueError:
+ pass"""
+ )
+ name = "arbitraryName"
+ context = InferenceContext()
+ context.lookupname = name
+ stmts = bases._infer_stmts([trystar], context)
+ assert list(stmts) == [Uninferable]
+ assert context.lookupname == name