summaryrefslogtreecommitdiff
path: root/Cython/Compiler/MatchCaseNodes.py
blob: 99fa70ccde666588ed65a3df6a013f85f7c825f0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# Nodes for structural pattern matching.
#
# In a separate file because they're unlikely to be useful for much else.

from .Nodes import Node, StatNode
from .Errors import error


class MatchNode(StatNode):
    """
    A ``match`` statement.

    subject  ExprNode         The expression to be matched
    cases    [MatchCaseNode]  list of cases
    """

    child_attrs = ["subject", "cases"]

    def validate_irrefutable(self):
        """
        Report an error when an irrefutable case is followed by further
        cases, and let every case up to (and including) the first
        irrefutable one validate its own pattern.
        """
        earlier_irrefutable = None
        for case in self.cases:
            if earlier_irrefutable is not None:
                # Only the first irrefutable case is reported; the cases
                # after it are not validated any further.
                description = earlier_irrefutable.pattern.irrefutable_message()
                error(
                    earlier_irrefutable.pos,
                    "%s makes remaining patterns unreachable" % description,
                )
                return
            if case.is_irrefutable():
                earlier_irrefutable = case
            case.validate_irrefutable()

    def analyse_expressions(self, env):
        # Analysis of match statements is not implemented yet: report an
        # error and hand back the node unchanged.
        error(self.pos, "Structural pattern match is not yet implemented")
        return self


class MatchCaseNode(Node):
    """
    A single ``case`` clause of a match statement.

    pattern    PatternNode
    body       StatListNode
    guard      ExprNode or None
    """

    child_attrs = ["pattern", "body", "guard"]

    def is_irrefutable(self):
        # A guard may reject the case, so only an unguarded case whose
        # pattern always matches counts as irrefutable.
        if self.guard:
            return False
        return self.pattern.is_irrefutable()

    def validate_targets(self):
        # get_targets() is called for the errors it reports (duplicate or
        # inconsistent bindings); the returned set itself is not needed.
        pattern = self.pattern
        pattern.get_targets()

    def validate_irrefutable(self):
        # Delegate to the pattern tree.
        pattern = self.pattern
        pattern.validate_irrefutable()


class PatternNode(Node):
    """
    Base class for the nodes that make up a ``case`` pattern.

    DW decided that PatternNode shouldn't be an expression because
    it does several things (evaluating a boolean expression,
    assignment of targets), and they need to be done at different
    times.

    as_targets   [NameNode]    any targets assigned by "as"
    """

    child_attrs = ["as_targets"]

    def __init__(self, pos, **kwds):
        super(PatternNode, self).__init__(pos, **kwds)
        # Default to a fresh, per-instance empty list.
        if "as_targets" not in kwds:
            self.as_targets = []

    def is_irrefutable(self):
        """Return True if the pattern always matches; subclasses override."""
        return False

    def get_targets(self):
        """
        Return the set of names bound by this pattern, including its "as"
        targets, reporting an error for any name that is bound twice.
        """
        targets = self.get_main_pattern_targets()
        for t in self.as_targets:
            self.add_target_to_targets(targets, t.name)
        return targets

    def update_targets_with_targets(self, targets, other_targets):
        """
        Merge the name set ``other_targets`` into ``targets`` in place,
        reporting an error for every name that appears in both.
        """
        # Sorted so that the errors come out in a deterministic order
        # (iteration order of a set of strings depends on hash seeding).
        for name in sorted(targets.intersection(other_targets)):
            error(self.pos, "multiple assignments to name '%s' in pattern" % name)
        targets.update(other_targets)

    def add_target_to_targets(self, targets, target):
        """
        Add the single name ``target`` to ``targets`` in place, reporting an
        error if it is already present.
        """
        if target in targets:
            error(self.pos, "multiple assignments to name '%s' in pattern" % target)
        targets.add(target)

    def get_main_pattern_targets(self):
        """
        Return the set of names bound by the pattern itself, excluding its
        "as" targets. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def validate_irrefutable(self):
        """
        Recurse into child attributes that hold a single PatternNode.

        Attributes holding *lists* of patterns are not visited here;
        subclasses that need those checked must override this method.
        """
        for attr in self.child_attrs:
            child = getattr(self, attr)
            if isinstance(child, PatternNode):
                child.validate_irrefutable()


class MatchValuePatternNode(PatternNode):
    """
    Pattern that checks the subject against a value.

    value         ExprNode   # todo be more specific
    is_is_check   bool       Picks "is" or equality check
    """

    is_is_check = False
    child_attrs = PatternNode.child_attrs + ["value"]

    def get_main_pattern_targets(self):
        # A value pattern binds no names of its own.
        no_names = set()
        return no_names


class MatchAndAssignPatternNode(PatternNode):
    """
    Capture pattern (binds a name) or wildcard (binds nothing).

    target   NameNode or None  the target to assign to (None = wildcard)
    is_star  bool              starred capture; never treated as irrefutable
    """

    target = None
    is_star = False

    # "target" is a child node, so it belongs in child_attrs (this was
    # misspelled "child_atts", which left it out of child iteration).
    child_attrs = PatternNode.child_attrs + ["target"]

    def is_irrefutable(self):
        # A plain capture or wildcard always matches; starred ones are not
        # reported as irrefutable.
        return not self.is_star

    def irrefutable_message(self):
        """Describe this pattern for the "makes remaining patterns unreachable" error."""
        if self.target:
            return "name capture '%s'" % self.target.name
        else:
            return "wildcard"

    def get_main_pattern_targets(self):
        # A capture binds its target's name; a wildcard binds nothing.
        if self.target:
            return {self.target.name}
        else:
            return set()


class OrPatternNode(PatternNode):
    """
    Pattern made of several alternatives.

    alternatives   list of PatternNodes
    """

    child_attrs = PatternNode.child_attrs + ["alternatives"]

    def get_first_irrefutable(self):
        """Return the first irrefutable alternative, or None if there is none."""
        return next(
            (alt for alt in self.alternatives if alt.is_irrefutable()), None
        )

    def is_irrefutable(self):
        # Irrefutable as soon as any single alternative is.
        return self.get_first_irrefutable() is not None

    def irrefutable_message(self):
        # Only valid when is_irrefutable() is True (otherwise there is no
        # alternative to describe).
        first = self.get_first_irrefutable()
        return first.irrefutable_message()

    def get_main_pattern_targets(self):
        """
        Return the names bound by the alternatives, reporting an error each
        time an alternative binds a different set of names than the previous
        one. The set of the last alternative is returned.
        """
        previous = None
        for alt in self.alternatives:
            current = alt.get_targets()
            if previous is not None and current != previous:
                error(self.pos, "alternative patterns bind different names")
            previous = current
        return previous

    def validate_irrefutable(self):
        """
        Report an irrefutable alternative that is followed by further
        alternatives, validating each alternative up to and including it.
        """
        super(OrPatternNode, self).validate_irrefutable()
        irrefutable_alt = None
        for alt in self.alternatives:
            if irrefutable_alt is not None:
                error(
                    irrefutable_alt.pos,
                    "%s makes remaining patterns unreachable"
                    % irrefutable_alt.irrefutable_message(),
                )
                return
            if alt.is_irrefutable():
                irrefutable_alt = alt
            alt.validate_irrefutable()


class MatchSequencePatternNode(PatternNode):
    """
    Sequence pattern.

    patterns   list of PatternNodes
    """

    child_attrs = PatternNode.child_attrs + ["patterns"]

    def get_main_pattern_targets(self):
        """Return the names bound by all subpatterns, reporting duplicates."""
        targets = set()
        for p in self.patterns:
            self.update_targets_with_targets(targets, p.get_targets())
        return targets

    def validate_irrefutable(self):
        # PatternNode.validate_irrefutable() only descends into attributes
        # holding a single PatternNode, so the subpatterns stored in the
        # "patterns" list must be visited explicitly. Without this, e.g.
        # "case [x | y]:" would not report that "x" makes "y" unreachable.
        super(MatchSequencePatternNode, self).validate_irrefutable()
        for p in self.patterns:
            p.validate_irrefutable()


class MatchMappingPatternNode(PatternNode):
    """
    Mapping pattern.

    keys   list of NameNodes
    value_patterns  list of PatternNodes of equal length to keys
    double_star_capture_target  NameNode or None
    """

    keys = []
    value_patterns = []
    double_star_capture_target = None

    child_attrs = PatternNode.child_attrs + [
        "keys",
        "value_patterns",
        "double_star_capture_target",
    ]

    def get_main_pattern_targets(self):
        """
        Return the names bound by the value patterns plus the ``**`` capture
        target (if any), reporting duplicates.
        """
        targets = set()
        for p in self.value_patterns:
            self.update_targets_with_targets(targets, p.get_targets())
        if self.double_star_capture_target:
            self.add_target_to_targets(targets, self.double_star_capture_target.name)
        return targets

    def validate_irrefutable(self):
        # The value patterns live in a list, which the generic
        # PatternNode.validate_irrefutable() does not descend into, so
        # nested or-patterns would otherwise go unchecked.
        super(MatchMappingPatternNode, self).validate_irrefutable()
        for p in self.value_patterns:
            p.validate_irrefutable()


class ClassPatternNode(PatternNode):
    """
    Class pattern.

    class_  NameNode or AttributeNode
    positional_patterns  list of PatternNodes
    keyword_pattern_names    list of NameNodes
    keyword_pattern_patterns    list of PatternNodes
                                (same length as keyword_pattern_names)
    """

    class_ = None
    positional_patterns = []
    keyword_pattern_names = []
    keyword_pattern_patterns = []

    child_attrs = PatternNode.child_attrs + [
        "class_",
        "positional_patterns",
        "keyword_pattern_names",
        "keyword_pattern_patterns",
    ]

    def get_main_pattern_targets(self):
        """Return the names bound by all subpatterns, reporting duplicates."""
        targets = set()
        for p in self.positional_patterns + self.keyword_pattern_patterns:
            self.update_targets_with_targets(targets, p.get_targets())
        return targets

    def validate_irrefutable(self):
        # Positional and keyword subpatterns are stored in lists, which the
        # generic PatternNode.validate_irrefutable() does not descend into,
        # so nested or-patterns would otherwise go unchecked.
        super(ClassPatternNode, self).validate_irrefutable()
        for p in self.positional_patterns + self.keyword_pattern_patterns:
            p.validate_irrefutable()