diff options
Diffstat (limited to 'Cython/Compiler/MatchCaseNodes.py')
-rw-r--r-- | Cython/Compiler/MatchCaseNodes.py | 79 |
1 files changed, 42 insertions, 37 deletions
diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py index 3d0d6890b..a7e246bcb 100644 --- a/Cython/Compiler/MatchCaseNodes.py +++ b/Cython/Compiler/MatchCaseNodes.py @@ -1,9 +1,8 @@ # Nodes for structural pattern matching. # -# In a separate file because they're unlikely to be useful -# for much else +# In a separate file because they're unlikely to be useful for much else. -from .Nodes import Node, StatNode +from .Nodes import Node, StatNode, ErrorNode from .Errors import error, local_errors, report_error from . import Nodes, ExprNodes, PyrexTypes, Builtin from .Code import UtilityCode, TempitaUtilityCode @@ -28,7 +27,11 @@ class MatchNode(StatNode): def validate_irrefutable(self): found_irrefutable_case = None - for c in self.cases: + for case in self.cases: + if isinstance(case, ErrorNode): + # This validation happens before error nodes have been + # transformed into actual errors, so we need to ignore them + continue if found_irrefutable_case: error( found_irrefutable_case.pos, @@ -38,9 +41,9 @@ class MatchNode(StatNode): ), ) break - if c.is_irrefutable(): - found_irrefutable_case = c - c.validate_irrefutable() + if case.is_irrefutable(): + found_irrefutable_case = case + case.validate_irrefutable() def refactor_cases(self): # An early transform - changes cases that can be represented as @@ -158,6 +161,8 @@ class MatchCaseNode(Node): child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"] def is_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return True # value doesn't really matter return self.pattern.is_irrefutable() and not self.guard def is_simple_value_comparison(self): @@ -166,9 +171,13 @@ class MatchCaseNode(Node): return self.pattern.is_simple_value_comparison() def validate_targets(self): + if isinstance(self.pattern, ErrorNode): + return self.pattern.get_targets() def validate_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return self.pattern.validate_irrefutable() def is_sequence_or_mapping(self): @@ -263,7 +272,7 @@ class SubstitutedMatchCaseNode(MatchCaseBaseNode): class PatternNode(Node): """ - DW decided that PatternNode shouldn't be an expression because + PatternNode is not an expression because it does several things (evalutating a boolean expression, assignment of targets), and they need to be done at different times. @@ -305,9 +314,9 @@ class PatternNode(Node): child_attrs = ["as_targets"] def __init__(self, pos, **kwds): + if "as_targets" not in kwds: + kwds["as_targets"] = [] super(PatternNode, self).__init__(pos, **kwds) - if not hasattr(self, "as_targets"): - self.as_targets = [] def is_irrefutable(self): return False @@ -322,14 +331,13 @@ class PatternNode(Node): def get_targets(self): targets = self.get_main_pattern_targets() - for t in self.as_targets: - self.add_target_to_targets(targets, t.name) + for target in self.as_targets: + self.add_target_to_targets(targets, target.name) return targets def update_targets_with_targets(self, targets, other_targets): - intersection = targets.intersection(other_targets) - for i in intersection: - error(self.pos, "multiple assignments to name '%s' in pattern" % i) + for name in targets.intersection(other_targets): + error(self.pos, "multiple assignments to name '%s' in pattern" % name) targets.update(other_targets) def add_target_to_targets(self, targets, target): @@ -362,7 +370,7 @@ class PatternNode(Node): def validate_irrefutable(self): for attr in self.child_attrs: child = getattr(self, attr) - if isinstance(child, PatternNode): + if child is not None and isinstance(child, PatternNode): child.validate_irrefutable() def analyse_pattern_expressions(self, env, sequence_mapping_temp): @@ -516,9 +524,9 @@ class OrPatternNode(PatternNode): child_attrs = PatternNode.child_attrs + ["alternatives"] def get_first_irrefutable(self): - for a in self.alternatives: - if a.is_irrefutable(): - return a + for alternative in self.alternatives: + if alternative.is_irrefutable(): + return alternative return None def is_irrefutable(self): @@ -537,17 +545,17 @@ class OrPatternNode(PatternNode): def get_main_pattern_targets(self): child_targets = None - for ch in self.alternatives: - ch_targets = ch.get_targets() - if child_targets is not None and child_targets != ch_targets: + for alternative in self.alternatives: + alternative_targets = alternative.get_targets() + if child_targets is not None and child_targets != alternative_targets: error(self.pos, "alternative patterns bind different names") - child_targets = ch_targets + child_targets = alternative_targets return child_targets def validate_irrefutable(self): super(OrPatternNode, self).validate_irrefutable() found_irrefutable_case = None - for a in self.alternatives: + for alternative in self.alternatives: if found_irrefutable_case: error( found_irrefutable_case.pos, @@ -557,9 +565,9 @@ class OrPatternNode(PatternNode): ), ) break - if a.is_irrefutable(): - found_irrefutable_case = a - a.validate_irrefutable() + if alternative.is_irrefutable(): + found_irrefutable_case = alternative + alternative.validate_irrefutable() def is_simple_value_comparison(self): return all( @@ -739,10 +747,10 @@ class MatchSequencePatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() star_count = 0 - for p in self.patterns: - if p.is_match_and_assign_pattern and p.is_star: + for pattern in self.patterns: + if pattern.is_match_and_assign_pattern and pattern.is_star: star_count += 1 - self.update_targets_with_targets(targets, p.get_targets()) + self.update_targets_with_targets(targets, pattern.get_targets()) if star_count > 1: error(self.pos, "multiple starred names in sequence pattern") return targets @@ -1094,8 +1102,8 @@ class MatchMappingPatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() - for p in self.value_patterns: - self.update_targets_with_targets(targets, p.get_targets()) + for pattern in self.value_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) if self.double_star_capture_target: self.add_target_to_targets(targets, self.double_star_capture_target.name) return targets @@ -1485,11 +1493,8 @@ class ClassPatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() - for p in self.keyword_pattern_patterns: - self.update_targets_with_targets(targets, p.get_targets()) - - for p in self.positional_patterns: - self.update_targets_with_targets(targets, p.get_targets()) + for pattern in self.positional_patterns + self.keyword_pattern_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) return targets def generate_main_pattern_assignment_list(self, subject_node, env): |