diff options
-rw-r--r-- | ChangeLog | 2 | ||||
-rw-r--r-- | astroid/node_classes.py | 234 | ||||
-rw-r--r-- | astroid/nodes.py | 20 | ||||
-rw-r--r-- | astroid/rebuilder.py | 92 | ||||
-rw-r--r-- | doc/api/astroid.nodes.rst | 30 | ||||
-rw-r--r-- | setup.cfg | 2 | ||||
-rw-r--r-- | tests/unittest_nodes.py | 176 |
7 files changed, 555 insertions, 1 deletions
@@ -64,6 +64,8 @@ Release Date: TBA * Astroid's tags are now the standard form ``vX.Y.Z`` and not ``astroid-X.Y.Z`` anymore. +* Add initial support for Pattern Matching in Python 3.10 + What's New in astroid 2.5.6? ============================ Release Date: 2021-04-25 diff --git a/astroid/node_classes.py b/astroid/node_classes.py index f8d0b23c..5b92345d 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -36,6 +36,7 @@ import builtins as builtins_mod import itertools import pprint import sys +import typing from functools import lru_cache from functools import singledispatch as _singledispatch @@ -43,6 +44,12 @@ from astroid import as_string, bases from astroid import context as contextmod from astroid import decorators, exceptions, manager, mixins, util +try: + from typing import Literal +except ImportError: + # typing.Literal was added in Python 3.8 + from typing_extensions import Literal + BUILTINS = builtins_mod.__name__ MANAGER = manager.AstroidManager() PY38 = sys.version_info[:2] >= (3, 8) @@ -4801,6 +4808,233 @@ class EvaluatedObject(NodeNG): yield self.value +# Pattern matching ####################################################### + + +class Match(NodeNG): + """Class representing a :class:`ast.Match` node.""" + + _astroid_fields = ("subject", "cases") + subject: typing.Optional[NodeNG] = None + cases: typing.Optional[typing.List["MatchCase"]] = None + + def postinit( + self, + *, + subject: typing.Optional[NodeNG] = None, + cases: typing.Optional[typing.List["MatchCase"]] = None, + ) -> None: + self.subject = subject + self.cases = cases + + def get_children(self) -> typing.Generator[NodeNG, None, None]: + if self.subject is not None: + yield self.subject + if self.cases is not None: + yield from self.cases + + +class MatchCase(NodeNG): + """Class representing a :class:`ast.match_case` node.""" + + _astroid_fields = ("pattern", "guard", "body") + pattern: typing.Optional["PatternTypes"] = None + guard: typing.Optional[NodeNG] = None # can actually be None + body: typing.Optional[typing.List[NodeNG]] = None + + def postinit( + self, + *, + pattern=None, + guard: typing.Optional[NodeNG] = None, + body: typing.Optional[typing.List[NodeNG]] = None, + ) -> None: + self.pattern = pattern + self.guard = guard + self.body = body + + def get_children(self) -> typing.Generator[NodeNG, None, None]: + if self.pattern is not None: + yield self.pattern + if self.guard is not None: + yield self.guard + if self.body is not None: + yield from self.body + + +class MatchValue(NodeNG): + """Class representing a :class:`ast.MatchValue` node.""" + + _astroid_fields = ("value",) + value: typing.Optional[NodeNG] = None + + def postinit(self, *, value: NodeNG) -> None: + self.value = value + + def get_children(self) -> typing.Generator[NodeNG, None, None]: + if self.value is not None: + yield self.value + + +class MatchSingleton(NodeNG): + """Class representing a :class:`ast.MatchSingleton` node.""" + + _other_fields = ("value",) + + def __init__( + self, + lineno: int, + col_offset: int, + parent: NodeNG, + *, + value: Literal[True, False, None], + ) -> None: + self.value = value + super().__init__(lineno, col_offset, parent) + + +class MatchSequence(NodeNG): + """Class representing a :class:`ast.MatchSequence` node.""" + + _astroid_fields = ("patterns",) + patterns: typing.Optional[typing.List["PatternTypes"]] = None + + def postinit( + self, *, patterns: typing.Optional[typing.List["PatternTypes"]] + ) -> None: + self.patterns = patterns + + def get_children(self) -> typing.Generator["PatternTypes", None, None]: + if self.patterns is not None: + yield from self.patterns + + +class MatchMapping(NodeNG): + """Class representing a :class:`ast.MatchMapping` node.""" + + _astroid_fields = ("keys", "patterns") + _other_fields = ("rest",) + keys: typing.Optional[typing.List[NodeNG]] = None + patterns: typing.Optional[typing.List["PatternTypes"]] = None + rest: typing.Optional[str] = None + + def postinit( + self, + *, + keys=None, + patterns: typing.Optional[typing.List["PatternTypes"]] = None, + rest: typing.Optional[str] = None, + ) -> None: + self.keys = keys + self.patterns = patterns + self.rest = rest + + def get_children(self) -> typing.Generator[NodeNG, None, None]: + if self.keys is not None: + yield from self.keys + if self.patterns is not None: + yield from self.patterns + + +class MatchClass(NodeNG): + """Class representing a :class:`ast.MatchClass` node.""" + + _astroid_fields = ("cls", "patterns", "kwd_attrs", "kwd_patterns") + cls: typing.Optional[NodeNG] = None + patterns: typing.Optional[typing.List["PatternTypes"]] = None + kwd_attrs: typing.Optional[typing.List[str]] = None + kwd_patterns: typing.Optional[typing.List["PatternTypes"]] = None + + def postinit( + self, + *, + cls: typing.Optional[NodeNG] = None, + patterns: typing.Optional[typing.List["PatternTypes"]] = None, + kwd_attrs: typing.Optional[typing.List[str]] = None, + kwd_patterns: typing.Optional[typing.List["PatternTypes"]] = None, + ) -> None: + self.cls = cls + self.patterns = patterns + self.kwd_attrs = kwd_attrs + self.kwd_patterns = kwd_patterns + + def get_children(self) -> typing.Generator[NodeNG, None, None]: + if self.cls is not None: + yield self.cls + if self.patterns is not None: + yield from self.patterns + if self.kwd_patterns is not None: + yield from self.kwd_patterns + + +class MatchStar(NodeNG): + """Class representing a :class:`ast.MatchStar` node.""" + + _other_fields = ("name",) + name: typing.Optional[str] = None + + def __init__( + self, + lineno: int, + col_offset: int, + parent: NodeNG, + *, + name: typing.Optional[str], + ) -> None: + self.name = name + super().__init__(lineno, col_offset, parent) + + +class MatchAs(NodeNG): + """Class representing a :class:`ast.MatchAs` node.""" + + _astroid_fields = ("pattern",) + _other_fields = ("name",) + pattern: typing.Optional["PatternTypes"] = None + name: typing.Optional[str] = None + + def postinit( + self, + *, + pattern: typing.Optional["PatternTypes"] = None, + name: typing.Optional[str] = None, + ) -> None: + self.pattern = pattern + self.name = name + + def get_children(self) -> typing.Generator["PatternTypes", None, None]: + if self.pattern is not None: + yield self.pattern + + +class MatchOr(NodeNG): + """Class representing a :class:`ast.MatchOr` node.""" + + _astroid_fields = ("patterns",) + patterns: typing.Optional[typing.List["PatternTypes"]] = None + + def postinit( + self, *, patterns: typing.Optional[typing.List["PatternTypes"]] + ) -> None: + self.patterns = patterns + + def get_children(self) -> typing.Generator["PatternTypes", None, None]: + if self.patterns is not None: + yield from self.patterns + + +PatternTypes = typing.Union[ + MatchValue, + MatchSingleton, + MatchSequence, + MatchMapping, + MatchClass, + MatchStar, + MatchAs, + MatchOr, +] + + # constants ############################################################## CONST_CLS = { diff --git a/astroid/nodes.py b/astroid/nodes.py index af915b01..dfad2ce0 100644 --- a/astroid/nodes.py +++ b/astroid/nodes.py @@ -66,6 +66,16 @@ from astroid.node_classes import ( JoinedStr, Keyword, List, + Match, + MatchAs, + MatchCase, + MatchClass, + MatchMapping, + MatchOr, + MatchSequence, + MatchSingleton, + MatchStar, + MatchValue, Name, NamedExpr, Nonlocal, @@ -152,6 +162,16 @@ ALL_NODE_CLASSES = ( Lambda, List, ListComp, + Match, + MatchAs, + MatchCase, + MatchClass, + MatchMapping, + MatchOr, + MatchSequence, + MatchSingleton, + MatchStar, + MatchValue, Name, NamedExpr, Nonlocal, diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 2532b61d..61195895 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -28,11 +28,15 @@ order to get a single Astroid representation """ import sys -from typing import Optional +from typing import TYPE_CHECKING, Optional import astroid from astroid import nodes from astroid._ast import ParserModule, get_parser_module, parse_function_type_comment +from astroid.node_classes import NodeNG + +if TYPE_CHECKING: + import ast CONST_NAME_TRANSFORMS = {"None": None, "True": True, "False": False} @@ -43,6 +47,7 @@ REDIRECT = { "GenExprFor": "Comprehension", "excepthandler": "ExceptHandler", "keyword": "Keyword", + "match_case": "MatchCase", } PY37 = sys.version_info >= (3, 7) PY38 = sys.version_info >= (3, 8) @@ -1010,3 +1015,88 @@ class TreeRebuilder: if node.value is not None: newnode.postinit(self.visit(node.value, newnode)) return newnode + + def visit_match(self, node: "ast.Match", parent: NodeNG) -> nodes.Match: + newnode = nodes.Match(node.lineno, node.col_offset, parent) + newnode.postinit( + subject=self.visit(node.subject, newnode), + cases=[self.visit(case, newnode) for case in node.cases], + ) + return newnode + + def visit_matchcase( + self, node: "ast.match_case", parent: NodeNG + ) -> nodes.MatchCase: + newnode = nodes.MatchCase(parent=parent) + newnode.postinit( + pattern=self.visit(node.pattern, newnode), + guard=_visit_or_none(node, "guard", self, newnode), + body=[self.visit(child, newnode) for child in node.body], + ) + return newnode + + def visit_matchvalue( + self, node: "ast.MatchValue", parent: NodeNG + ) -> nodes.MatchValue: + newnode = nodes.MatchValue(node.lineno, node.col_offset, parent) + newnode.postinit(value=self.visit(node.value, newnode)) + return newnode + + def visit_matchsingleton( + self, node: "ast.MatchSingleton", parent: NodeNG + ) -> nodes.MatchSingleton: + return nodes.MatchSingleton( + node.lineno, node.col_offset, parent, value=node.value + ) + + def visit_matchsequence( + self, node: "ast.MatchSequence", parent: NodeNG + ) -> nodes.MatchSequence: + newnode = nodes.MatchSequence(node.lineno, node.col_offset, parent) + newnode.postinit( + patterns=[self.visit(pattern, newnode) for pattern in node.patterns] + ) + return newnode + + def visit_matchmapping( + self, node: "ast.MatchMapping", parent: NodeNG + ) -> nodes.MatchMapping: + newnode = nodes.MatchMapping(node.lineno, node.col_offset, parent) + newnode.postinit( + keys=[self.visit(child, newnode) for child in node.keys], + patterns=[self.visit(pattern, newnode) for pattern in node.patterns], + rest=node.rest, + ) + return newnode + + def visit_matchclass( + self, node: "ast.MatchClass", parent: NodeNG + ) -> nodes.MatchClass: + newnode = nodes.MatchClass(node.lineno, node.col_offset, parent) + newnode.postinit( + cls=self.visit(node.cls, newnode), + patterns=[self.visit(pattern, newnode) for pattern in node.patterns], + kwd_attrs=node.kwd_attrs, + kwd_patterns=[ + self.visit(pattern, newnode) for pattern in node.kwd_patterns + ], + ) + return newnode + + def visit_matchstar(self, node: "ast.MatchStar", parent: NodeNG) -> nodes.MatchStar: + return nodes.MatchStar(node.lineno, node.col_offset, parent, name=node.name) + + def visit_matchas(self, node: "ast.MatchAs", parent: NodeNG) -> nodes.MatchAs: + newnode = nodes.MatchAs(None, None, parent) + newnode.postinit( + pattern=_visit_or_none(node, "pattern", self, newnode), + name=node.name, + ) + return newnode + + def visit_matchor(self, node: "ast.MatchOr", parent: NodeNG) -> nodes.MatchOr: + newnode = nodes.MatchOr(None, None, parent) + newnode.postinit( + patterns=[self.visit(pattern, newnode) for pattern in node.patterns] + ) + return newnode diff --git a/doc/api/astroid.nodes.rst b/doc/api/astroid.nodes.rst index 3031c8b5..549ccd25 100644 --- a/doc/api/astroid.nodes.rst +++ b/doc/api/astroid.nodes.rst @@ -60,6 +60,16 @@ Nodes Lambda List ListComp + Match + MatchAs + MatchCase + MatchClass + MatchMapping + MatchOr + MatchSequence + MatchSingleton + MatchStar + MatchValue Module Name Nonlocal @@ -181,6 +191,26 @@ Nodes .. autoclass:: ListComp +.. autoclass:: Match + +.. autoclass:: MatchAs + +.. autoclass:: MatchCase + +.. autoclass:: MatchClass + +.. autoclass:: MatchMapping + +.. autoclass:: MatchOr + +.. autoclass:: MatchSequence + +.. autoclass:: MatchSingleton + +.. autoclass:: MatchStar + +.. autoclass:: MatchValue + .. autoclass:: Module .. autoclass:: Name @@ -21,6 +21,7 @@ classifiers = Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 Programming Language :: Python :: Implementation :: CPython Programming Language :: Python :: Implementation :: PyPy Topic :: Software Development :: Libraries :: Python Modules @@ -37,6 +38,7 @@ install_requires = lazy_object_proxy>=1.4.0 wrapt>=1.11,<1.13 typed-ast>=1.4.0,<1.5;implementation_name=="cpython" and python_version<"3.8" + typing-extensions>=3.7.4;python_version<"3.8" setup_requires = setuptools_scm python_requires = ~=3.6 diff --git a/tests/unittest_nodes.py b/tests/unittest_nodes.py index 14f713a5..9b3971ce 100644 --- a/tests/unittest_nodes.py +++ b/tests/unittest_nodes.py @@ -1371,5 +1371,181 @@ def test_is_generator_for_yield_in_aug_assign(): assert bool(node.is_generator()) +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="pattern matching was added in PY310" +) +class TestPatternMatching: + @staticmethod + def test_match_simple(): + node = builder.extract_node( + """ + match status: + case 200: + pass + case 401 | 402 | 403: + pass + case None: + pass + case _: + pass + """ + ) + assert isinstance(node, nodes.Match) + assert isinstance(node.subject, nodes.Name) + assert node.subject.name == "status" + assert isinstance(node.cases, list) and len(node.cases) == 4 + case0, case1, case2, case3 = node.cases + + assert isinstance(case0.pattern, nodes.MatchValue) + assert ( + isinstance(case0.pattern.value, astroid.Const) + and case0.pattern.value.value == 200 + ) + assert case0.guard is None + assert isinstance(case0.body[0], astroid.Pass) + + assert isinstance(case1.pattern, nodes.MatchOr) + assert ( + isinstance(case1.pattern.patterns, list) + and len(case1.pattern.patterns) == 3 + ) + for i in range(3): + match_value = case1.pattern.patterns[i] + assert isinstance(match_value, nodes.MatchValue) + assert isinstance(match_value.value, nodes.Const) + assert match_value.value.value == (401, 402, 403)[i] + + assert isinstance(case2.pattern, nodes.MatchSingleton) + assert case2.pattern.value is None + + assert isinstance(case3.pattern, nodes.MatchAs) + assert case3.pattern.name is None + assert case3.pattern.pattern is None + + @staticmethod + def test_match_sequence(): + node = builder.extract_node( + """ + match status: + case [x, 2, *rest] as y if x > 2: + pass + """ + ) + assert isinstance(node, nodes.Match) + assert isinstance(node.cases, list) and len(node.cases) == 1 + case = node.cases[0] + + assert isinstance(case.pattern, nodes.MatchAs) + assert case.pattern.name == "y" + assert isinstance(case.guard, nodes.Compare) + assert isinstance(case.body[0], nodes.Pass) + + pattern_as = case.pattern.pattern + assert isinstance(pattern_as, nodes.MatchSequence) + assert isinstance(pattern_as.patterns, list) and len(pattern_as.patterns) == 3 + assert ( + isinstance(pattern_as.patterns[0], nodes.MatchAs) + and pattern_as.patterns[0].name == "x" + and pattern_as.patterns[0].pattern is None + ) + assert ( + isinstance(pattern_as.patterns[1], nodes.MatchValue) + and isinstance(pattern_as.patterns[1].value, nodes.Const) + and pattern_as.patterns[1].value.value == 2 + ) + assert ( + isinstance(pattern_as.patterns[2], nodes.MatchStar) + and pattern_as.patterns[2].name == "rest" + ) + + @staticmethod + def test_match_mapping(): + node = builder.extract_node( + """ + match status: + case {0: x, 1: _}: + pass + case {**rest}: + pass + """ + ) + assert isinstance(node, nodes.Match) + assert isinstance(node.cases, list) and len(node.cases) == 2 + case0, case1 = node.cases + + assert isinstance(case0.pattern, nodes.MatchMapping) + assert case0.pattern.rest is None + assert isinstance(case0.pattern.keys, list) and len(case0.pattern.keys) == 2 + assert ( + isinstance(case0.pattern.patterns, list) + and len(case0.pattern.patterns) == 2 + ) + for i in range(2): + key = case0.pattern.keys[i] + assert isinstance(key, nodes.Const) + assert key.value == i + pattern = case0.pattern.patterns[i] + assert isinstance(pattern, nodes.MatchAs) + assert pattern.name == ("x" if i == 0 else None) + + assert isinstance(case1.pattern, nodes.MatchMapping) + assert case1.pattern.rest == "rest" + assert isinstance(case1.pattern.keys, list) and len(case1.pattern.keys) == 0 + assert ( + isinstance(case1.pattern.patterns, list) + and len(case1.pattern.patterns) == 0 + ) + + @staticmethod + def test_match_class(): + node = builder.extract_node( + """ + match x: + case Point2D(0, 1): + pass + case Point3D(x=0, y=1, z=2): + pass + """ + ) + assert isinstance(node, nodes.Match) + assert isinstance(node.cases, list) and len(node.cases) == 2 + case0, case1 = node.cases + + assert isinstance(case0.pattern, nodes.MatchClass) + assert isinstance(case0.pattern.cls, nodes.Name) + assert case0.pattern.cls.name == "Point2D" + assert ( + isinstance(case0.pattern.patterns, list) + and len(case0.pattern.patterns) == 2 + ) + for i in range(2): + match_value = case0.pattern.patterns[i] + assert isinstance(match_value, nodes.MatchValue) + assert isinstance(match_value.value, nodes.Const) + assert match_value.value.value == i + + assert isinstance(case1.pattern, nodes.MatchClass) + assert isinstance(case1.pattern.cls, nodes.Name) + assert case1.pattern.cls.name == "Point3D" + assert ( + isinstance(case1.pattern.patterns, list) + and len(case1.pattern.patterns) == 0 + ) + assert ( + isinstance(case1.pattern.kwd_attrs, list) + and len(case1.pattern.kwd_attrs) == 3 + ) + assert ( + isinstance(case1.pattern.kwd_patterns, list) + and len(case1.pattern.kwd_patterns) == 3 + ) + for i in range(3): + assert case1.pattern.kwd_attrs[i] == ("x", "y", "z")[i] + kwd_pattern = case1.pattern.kwd_patterns[i] + assert isinstance(kwd_pattern, nodes.MatchValue) + assert isinstance(kwd_pattern.value, nodes.Const) + assert kwd_pattern.value.value == i + + if __name__ == "__main__": unittest.main() |