summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog2
-rw-r--r--astroid/node_classes.py234
-rw-r--r--astroid/nodes.py20
-rw-r--r--astroid/rebuilder.py92
-rw-r--r--doc/api/astroid.nodes.rst30
-rw-r--r--setup.cfg2
-rw-r--r--tests/unittest_nodes.py176
7 files changed, 555 insertions, 1 deletions
diff --git a/ChangeLog b/ChangeLog
index 0da3c997..fe94a0ea 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -64,6 +64,8 @@ Release Date: TBA
* Astroid's tags are now the standard form ``vX.Y.Z`` and not ``astroid-X.Y.Z`` anymore.
+* Add initial support for Pattern Matching in Python 3.10
+
What's New in astroid 2.5.6?
============================
Release Date: 2021-04-25
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index f8d0b23c..5b92345d 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -36,6 +36,7 @@ import builtins as builtins_mod
import itertools
import pprint
import sys
+import typing
from functools import lru_cache
from functools import singledispatch as _singledispatch
@@ -43,6 +44,12 @@ from astroid import as_string, bases
from astroid import context as contextmod
from astroid import decorators, exceptions, manager, mixins, util
+try:
+ from typing import Literal
+except ImportError:
+ # typing.Literal was added in Python 3.8
+ from typing_extensions import Literal
+
BUILTINS = builtins_mod.__name__
MANAGER = manager.AstroidManager()
PY38 = sys.version_info[:2] >= (3, 8)
@@ -4801,6 +4808,233 @@ class EvaluatedObject(NodeNG):
yield self.value
+# Pattern matching #######################################################
+
+
+class Match(NodeNG):
+ """Class representing a :class:`ast.Match` node."""
+
+ _astroid_fields = ("subject", "cases")
+ subject: typing.Optional[NodeNG] = None
+ cases: typing.Optional[typing.List["MatchCase"]] = None
+
+ def postinit(
+ self,
+ *,
+ subject: typing.Optional[NodeNG] = None,
+ cases: typing.Optional[typing.List["MatchCase"]] = None,
+ ) -> None:
+ self.subject = subject
+ self.cases = cases
+
+ def get_children(self) -> typing.Generator[NodeNG, None, None]:
+ if self.subject is not None:
+ yield self.subject
+ if self.cases is not None:
+ yield from self.cases
+
+
+class MatchCase(NodeNG):
+ """Class representing a :class:`ast.match_case` node."""
+
+ _astroid_fields = ("pattern", "guard", "body")
+ pattern: typing.Optional["PatternTypes"] = None
+ guard: typing.Optional[NodeNG] = None # can actually be None
+ body: typing.Optional[typing.List[NodeNG]] = None
+
+ def postinit(
+ self,
+ *,
+ pattern=None,
+ guard: typing.Optional[NodeNG] = None,
+ body: typing.Optional[typing.List[NodeNG]] = None,
+ ) -> None:
+ self.pattern = pattern
+ self.guard = guard
+ self.body = body
+
+ def get_children(self) -> typing.Generator[NodeNG, None, None]:
+ if self.pattern is not None:
+ yield self.pattern
+ if self.guard is not None:
+ yield self.guard
+ if self.body is not None:
+ yield from self.body
+
+
+class MatchValue(NodeNG):
+ """Class representing a :class:`ast.MatchValue` node."""
+
+ _astroid_fields = ("value",)
+ value: typing.Optional[NodeNG] = None
+
+ def postinit(self, *, value: NodeNG) -> None:
+ self.value = value
+
+ def get_children(self) -> typing.Generator[NodeNG, None, None]:
+ if self.value is not None:
+ yield self.value
+
+
+class MatchSingleton(NodeNG):
+ """Class representing a :class:`ast.MatchSingleton` node."""
+
+ _other_fields = ("value",)
+
+ def __init__(
+ self,
+ lineno: int,
+ col_offset: int,
+ parent: NodeNG,
+ *,
+ value: Literal[True, False, None],
+ ) -> None:
+ self.value = value
+ super().__init__(lineno, col_offset, parent)
+
+
+class MatchSequence(NodeNG):
+ """Class representing a :class:`ast.MatchSequence` node."""
+
+ _astroid_fields = ("patterns",)
+ patterns: typing.Optional[typing.List["PatternTypes"]] = None
+
+ def postinit(
+ self, *, patterns: typing.Optional[typing.List["PatternTypes"]]
+ ) -> None:
+ self.patterns = patterns
+
+ def get_children(self) -> typing.Generator["PatternTypes", None, None]:
+ if self.patterns is not None:
+ yield from self.patterns
+
+
+class MatchMapping(NodeNG):
+ """Class representing a :class:`ast.MatchMapping` node."""
+
+ _astroid_fields = ("keys", "patterns")
+ _other_fields = ("rest",)
+ keys: typing.Optional[typing.List[NodeNG]] = None
+ patterns: typing.Optional[typing.List["PatternTypes"]] = None
+ rest: typing.Optional[str] = None
+
+ def postinit(
+ self,
+ *,
+ keys=None,
+ patterns: typing.Optional[typing.List["PatternTypes"]] = None,
+ rest: typing.Optional[str] = None,
+ ) -> None:
+ self.keys = keys
+ self.patterns = patterns
+ self.rest = rest
+
+ def get_children(self) -> typing.Generator[NodeNG, None, None]:
+ if self.keys is not None:
+ yield from self.keys
+ if self.patterns is not None:
+ yield from self.patterns
+
+
+class MatchClass(NodeNG):
+ """Class representing a :class:`ast.MatchClass` node."""
+
+ _astroid_fields = ("cls", "patterns", "kwd_attrs", "kwd_patterns")
+ cls: typing.Optional[NodeNG] = None
+ patterns: typing.Optional[typing.List["PatternTypes"]] = None
+ kwd_attrs: typing.Optional[typing.List[str]] = None
+ kwd_patterns: typing.Optional[typing.List["PatternTypes"]] = None
+
+ def postinit(
+ self,
+ *,
+ cls: typing.Optional[NodeNG] = None,
+ patterns: typing.Optional[typing.List["PatternTypes"]] = None,
+ kwd_attrs: typing.Optional[typing.List[str]] = None,
+ kwd_patterns: typing.Optional[typing.List["PatternTypes"]] = None,
+ ) -> None:
+ self.cls = cls
+ self.patterns = patterns
+ self.kwd_attrs = kwd_attrs
+ self.kwd_patterns = kwd_patterns
+
+ def get_children(self) -> typing.Generator[NodeNG, None, None]:
+ if self.cls is not None:
+ yield self.cls
+ if self.patterns is not None:
+ yield from self.patterns
+ if self.kwd_patterns is not None:
+ yield from self.kwd_patterns
+
+
+class MatchStar(NodeNG):
+ """Class representing a :class:`ast.MatchStar` node."""
+
+ _other_fields = ("name",)
+ name: typing.Optional[str] = None
+
+ def __init__(
+ self,
+ lineno: int,
+ col_offset: int,
+ parent: NodeNG,
+ *,
+ name: typing.Optional[str],
+ ) -> None:
+ self.name = name
+ super().__init__(lineno, col_offset, parent)
+
+
+class MatchAs(NodeNG):
+ """Class representing a :class:`ast.MatchAs` node."""
+
+ _astroid_fields = ("pattern",)
+ _other_fields = ("name",)
+ pattern: typing.Optional["PatternTypes"] = None
+ name: typing.Optional[str] = None
+
+ def postinit(
+ self,
+ *,
+ pattern: typing.Optional["PatternTypes"] = None,
+ name: typing.Optional[str] = None,
+ ) -> None:
+ self.pattern = pattern
+ self.name = name
+
+ def get_children(self) -> typing.Generator["PatternTypes", None, None]:
+ if self.pattern is not None:
+ yield self.pattern
+
+
+class MatchOr(NodeNG):
+ """Class representing a :class:`ast.MatchOr` node."""
+
+ _astroid_fields = ("patterns",)
+ patterns: typing.Optional[typing.List["PatternTypes"]] = None
+
+ def postinit(
+ self, *, patterns: typing.Optional[typing.List["PatternTypes"]]
+ ) -> None:
+ self.patterns = patterns
+
+ def get_children(self) -> typing.Generator["PatternTypes", None, None]:
+ if self.patterns is not None:
+ yield from self.patterns
+
+
+PatternTypes = typing.Union[
+ MatchValue,
+ MatchSingleton,
+ MatchSequence,
+ MatchMapping,
+ MatchClass,
+ MatchStar,
+ MatchAs,
+ MatchOr,
+]
+
+
# constants ##############################################################
CONST_CLS = {
diff --git a/astroid/nodes.py b/astroid/nodes.py
index af915b01..dfad2ce0 100644
--- a/astroid/nodes.py
+++ b/astroid/nodes.py
@@ -66,6 +66,16 @@ from astroid.node_classes import (
JoinedStr,
Keyword,
List,
+ Match,
+ MatchAs,
+ MatchCase,
+ MatchClass,
+ MatchMapping,
+ MatchOr,
+ MatchSequence,
+ MatchSingleton,
+ MatchStar,
+ MatchValue,
Name,
NamedExpr,
Nonlocal,
@@ -152,6 +162,16 @@ ALL_NODE_CLASSES = (
Lambda,
List,
ListComp,
+ Match,
+ MatchAs,
+ MatchCase,
+ MatchClass,
+ MatchMapping,
+ MatchOr,
+ MatchSequence,
+ MatchSingleton,
+ MatchStar,
+ MatchValue,
Name,
NamedExpr,
Nonlocal,
diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py
index 2532b61d..61195895 100644
--- a/astroid/rebuilder.py
+++ b/astroid/rebuilder.py
@@ -28,11 +28,15 @@ order to get a single Astroid representation
"""
import sys
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
import astroid
from astroid import nodes
from astroid._ast import ParserModule, get_parser_module, parse_function_type_comment
+from astroid.node_classes import NodeNG
+
+if TYPE_CHECKING:
+ import ast
CONST_NAME_TRANSFORMS = {"None": None, "True": True, "False": False}
@@ -43,6 +47,7 @@ REDIRECT = {
"GenExprFor": "Comprehension",
"excepthandler": "ExceptHandler",
"keyword": "Keyword",
+ "match_case": "MatchCase",
}
PY37 = sys.version_info >= (3, 7)
PY38 = sys.version_info >= (3, 8)
@@ -1010,3 +1015,88 @@ class TreeRebuilder:
if node.value is not None:
newnode.postinit(self.visit(node.value, newnode))
return newnode
+
+ def visit_match(self, node: "ast.Match", parent: NodeNG) -> nodes.Match:
+ newnode = nodes.Match(node.lineno, node.col_offset, parent)
+ newnode.postinit(
+ subject=self.visit(node.subject, newnode),
+ cases=[self.visit(case, newnode) for case in node.cases],
+ )
+ return newnode
+
+ def visit_matchcase(
+ self, node: "ast.match_case", parent: NodeNG
+ ) -> nodes.MatchCase:
+ newnode = nodes.MatchCase(parent=parent)
+ newnode.postinit(
+ pattern=self.visit(node.pattern, newnode),
+ guard=_visit_or_none(node, "guard", self, newnode),
+ body=[self.visit(child, newnode) for child in node.body],
+ )
+ return newnode
+
+ def visit_matchvalue(
+ self, node: "ast.MatchValue", parent: NodeNG
+ ) -> nodes.MatchValue:
+ newnode = nodes.MatchValue(node.lineno, node.col_offset, parent)
+ newnode.postinit(value=self.visit(node.value, newnode))
+ return newnode
+
+ def visit_matchsingleton(
+ self, node: "ast.MatchSingleton", parent: NodeNG
+ ) -> nodes.MatchSingleton:
+ return nodes.MatchSingleton(
+ node.lineno, node.col_offset, parent, value=node.value
+ )
+
+ def visit_matchsequence(
+ self, node: "ast.MatchSequence", parent: NodeNG
+ ) -> nodes.MatchSequence:
+ newnode = nodes.MatchSequence(node.lineno, node.col_offset, parent)
+ newnode.postinit(
+ patterns=[self.visit(pattern, newnode) for pattern in node.patterns]
+ )
+ return newnode
+
+ def visit_matchmapping(
+ self, node: "ast.MatchMapping", parent: NodeNG
+ ) -> nodes.MatchMapping:
+ newnode = nodes.MatchMapping(node.lineno, node.col_offset, parent)
+ newnode.postinit(
+ keys=[self.visit(child, newnode) for child in node.keys],
+ patterns=[self.visit(pattern, newnode) for pattern in node.patterns],
+ rest=node.rest,
+ )
+ return newnode
+
+ def visit_matchclass(
+ self, node: "ast.MatchClass", parent: NodeNG
+ ) -> nodes.MatchClass:
+ newnode = nodes.MatchClass(node.lineno, node.col_offset, parent)
+ newnode.postinit(
+ cls=self.visit(node.cls, newnode),
+ patterns=[self.visit(pattern, newnode) for pattern in node.patterns],
+ kwd_attrs=node.kwd_attrs,
+ kwd_patterns=[
+ self.visit(pattern, newnode) for pattern in node.kwd_patterns
+ ],
+ )
+ return newnode
+
+ def visit_matchstar(self, node: "ast.MatchStar", parent: NodeNG) -> nodes.MatchStar:
+ return nodes.MatchStar(node.lineno, node.col_offset, parent, name=node.name)
+
+ def visit_matchas(self, node: "ast.MatchAs", parent: NodeNG) -> nodes.MatchAs:
+ newnode = nodes.MatchAs(None, None, parent)
+ newnode.postinit(
+ pattern=_visit_or_none(node, "pattern", self, newnode),
+ name=node.name,
+ )
+ return newnode
+
+ def visit_matchor(self, node: "ast.MatchOr", parent: NodeNG) -> nodes.MatchOr:
+ newnode = nodes.MatchOr(None, None, parent)
+ newnode.postinit(
+ patterns=[self.visit(pattern, newnode) for pattern in node.patterns]
+ )
+ return newnode
diff --git a/doc/api/astroid.nodes.rst b/doc/api/astroid.nodes.rst
index 3031c8b5..549ccd25 100644
--- a/doc/api/astroid.nodes.rst
+++ b/doc/api/astroid.nodes.rst
@@ -60,6 +60,16 @@ Nodes
Lambda
List
ListComp
+ Match
+ MatchAs
+ MatchCase
+ MatchClass
+ MatchMapping
+ MatchOr
+ MatchSequence
+ MatchSingleton
+ MatchStar
+ MatchValue
Module
Name
Nonlocal
@@ -181,6 +191,26 @@ Nodes
.. autoclass:: ListComp
+.. autoclass:: Match
+
+.. autoclass:: MatchAs
+
+.. autoclass:: MatchCase
+
+.. autoclass:: MatchClass
+
+.. autoclass:: MatchMapping
+
+.. autoclass:: MatchOr
+
+.. autoclass:: MatchSequence
+
+.. autoclass:: MatchSingleton
+
+.. autoclass:: MatchStar
+
+.. autoclass:: MatchValue
+
.. autoclass:: Module
.. autoclass:: Name
diff --git a/setup.cfg b/setup.cfg
index 8dc4228f..38621c35 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,6 +21,7 @@ classifiers =
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
Programming Language :: Python :: Implementation :: CPython
Programming Language :: Python :: Implementation :: PyPy
Topic :: Software Development :: Libraries :: Python Modules
@@ -37,6 +38,7 @@ install_requires =
lazy_object_proxy>=1.4.0
wrapt>=1.11,<1.13
typed-ast>=1.4.0,<1.5;implementation_name=="cpython" and python_version<"3.8"
+ typing-extensions>=3.7.4;python_version<"3.8"
setup_requires =
setuptools_scm
python_requires = ~=3.6
diff --git a/tests/unittest_nodes.py b/tests/unittest_nodes.py
index 14f713a5..9b3971ce 100644
--- a/tests/unittest_nodes.py
+++ b/tests/unittest_nodes.py
@@ -1371,5 +1371,181 @@ def test_is_generator_for_yield_in_aug_assign():
assert bool(node.is_generator())
+@pytest.mark.skipif(
+ sys.version_info < (3, 10), reason="pattern matching was added in PY310"
+)
+class TestPatternMatching:
+ @staticmethod
+ def test_match_simple():
+ node = builder.extract_node(
+ """
+ match status:
+ case 200:
+ pass
+ case 401 | 402 | 403:
+ pass
+ case None:
+ pass
+ case _:
+ pass
+ """
+ )
+ assert isinstance(node, nodes.Match)
+ assert isinstance(node.subject, nodes.Name)
+ assert node.subject.name == "status"
+ assert isinstance(node.cases, list) and len(node.cases) == 4
+ case0, case1, case2, case3 = node.cases
+
+ assert isinstance(case0.pattern, nodes.MatchValue)
+ assert (
+ isinstance(case0.pattern.value, astroid.Const)
+ and case0.pattern.value.value == 200
+ )
+ assert case0.guard is None
+ assert isinstance(case0.body[0], astroid.Pass)
+
+ assert isinstance(case1.pattern, nodes.MatchOr)
+ assert (
+ isinstance(case1.pattern.patterns, list)
+ and len(case1.pattern.patterns) == 3
+ )
+ for i in range(3):
+ match_value = case1.pattern.patterns[i]
+ assert isinstance(match_value, nodes.MatchValue)
+ assert isinstance(match_value.value, nodes.Const)
+ assert match_value.value.value == (401, 402, 403)[i]
+
+ assert isinstance(case2.pattern, nodes.MatchSingleton)
+ assert case2.pattern.value is None
+
+ assert isinstance(case3.pattern, nodes.MatchAs)
+ assert case3.pattern.name is None
+ assert case3.pattern.pattern is None
+
+ @staticmethod
+ def test_match_sequence():
+ node = builder.extract_node(
+ """
+ match status:
+ case [x, 2, *rest] as y if x > 2:
+ pass
+ """
+ )
+ assert isinstance(node, nodes.Match)
+ assert isinstance(node.cases, list) and len(node.cases) == 1
+ case = node.cases[0]
+
+ assert isinstance(case.pattern, nodes.MatchAs)
+ assert case.pattern.name == "y"
+ assert isinstance(case.guard, nodes.Compare)
+ assert isinstance(case.body[0], nodes.Pass)
+
+ pattern_as = case.pattern.pattern
+ assert isinstance(pattern_as, nodes.MatchSequence)
+ assert isinstance(pattern_as.patterns, list) and len(pattern_as.patterns) == 3
+ assert (
+ isinstance(pattern_as.patterns[0], nodes.MatchAs)
+ and pattern_as.patterns[0].name == "x"
+ and pattern_as.patterns[0].pattern is None
+ )
+ assert (
+ isinstance(pattern_as.patterns[1], nodes.MatchValue)
+ and isinstance(pattern_as.patterns[1].value, nodes.Const)
+ and pattern_as.patterns[1].value.value == 2
+ )
+ assert (
+ isinstance(pattern_as.patterns[2], nodes.MatchStar)
+ and pattern_as.patterns[2].name == "rest"
+ )
+
+ @staticmethod
+ def test_match_mapping():
+ node = builder.extract_node(
+ """
+ match status:
+ case {0: x, 1: _}:
+ pass
+ case {**rest}:
+ pass
+ """
+ )
+ assert isinstance(node, nodes.Match)
+ assert isinstance(node.cases, list) and len(node.cases) == 2
+ case0, case1 = node.cases
+
+ assert isinstance(case0.pattern, nodes.MatchMapping)
+ assert case0.pattern.rest is None
+ assert isinstance(case0.pattern.keys, list) and len(case0.pattern.keys) == 2
+ assert (
+ isinstance(case0.pattern.patterns, list)
+ and len(case0.pattern.patterns) == 2
+ )
+ for i in range(2):
+ key = case0.pattern.keys[i]
+ assert isinstance(key, nodes.Const)
+ assert key.value == i
+ pattern = case0.pattern.patterns[i]
+ assert isinstance(pattern, nodes.MatchAs)
+ assert pattern.name == ("x" if i == 0 else None)
+
+ assert isinstance(case1.pattern, nodes.MatchMapping)
+ assert case1.pattern.rest == "rest"
+ assert isinstance(case1.pattern.keys, list) and len(case1.pattern.keys) == 0
+ assert (
+ isinstance(case1.pattern.patterns, list)
+ and len(case1.pattern.patterns) == 0
+ )
+
+ @staticmethod
+ def test_match_class():
+ node = builder.extract_node(
+ """
+ match x:
+ case Point2D(0, 1):
+ pass
+ case Point3D(x=0, y=1, z=2):
+ pass
+ """
+ )
+ assert isinstance(node, nodes.Match)
+ assert isinstance(node.cases, list) and len(node.cases) == 2
+ case0, case1 = node.cases
+
+ assert isinstance(case0.pattern, nodes.MatchClass)
+ assert isinstance(case0.pattern.cls, nodes.Name)
+ assert case0.pattern.cls.name == "Point2D"
+ assert (
+ isinstance(case0.pattern.patterns, list)
+ and len(case0.pattern.patterns) == 2
+ )
+ for i in range(2):
+ match_value = case0.pattern.patterns[i]
+ assert isinstance(match_value, nodes.MatchValue)
+ assert isinstance(match_value.value, nodes.Const)
+ assert match_value.value.value == i
+
+ assert isinstance(case1.pattern, nodes.MatchClass)
+ assert isinstance(case1.pattern.cls, nodes.Name)
+ assert case1.pattern.cls.name == "Point3D"
+ assert (
+ isinstance(case1.pattern.patterns, list)
+ and len(case1.pattern.patterns) == 0
+ )
+ assert (
+ isinstance(case1.pattern.kwd_attrs, list)
+ and len(case1.pattern.kwd_attrs) == 3
+ )
+ assert (
+ isinstance(case1.pattern.kwd_patterns, list)
+ and len(case1.pattern.kwd_patterns) == 3
+ )
+ for i in range(3):
+ assert case1.pattern.kwd_attrs[i] == ("x", "y", "z")[i]
+ kwd_pattern = case1.pattern.kwd_patterns[i]
+ assert isinstance(kwd_pattern, nodes.MatchValue)
+ assert isinstance(kwd_pattern.value, nodes.Const)
+ assert kwd_pattern.value.value == i
+
+
if __name__ == "__main__":
unittest.main()