diff options
Diffstat (limited to 'Cython/Compiler/MatchCaseNodes.py')
-rw-r--r-- | Cython/Compiler/MatchCaseNodes.py | 76 |
1 files changed, 42 insertions, 34 deletions
diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py index 10895d51d..9196b9b1f 100644 --- a/Cython/Compiler/MatchCaseNodes.py +++ b/Cython/Compiler/MatchCaseNodes.py @@ -1,9 +1,8 @@ # Nodes for structural pattern matching. # -# In a separate file because they're unlikely to be useful -# for much else +# In a separate file because they're unlikely to be useful for much else. -from .Nodes import Node, StatNode +from .Nodes import Node, StatNode, ErrorNode from .Errors import error, local_errors, report_error from . import Nodes, ExprNodes, PyrexTypes, Builtin from .Code import UtilityCode @@ -26,7 +25,11 @@ class MatchNode(StatNode): def validate_irrefutable(self): found_irrefutable_case = None - for c in self.cases: + for case in self.cases: + if isinstance(case, ErrorNode): + # This validation happens before error nodes have been + # transformed into actual errors, so we need to ignore them + continue if found_irrefutable_case: error( found_irrefutable_case.pos, @@ -36,9 +39,9 @@ class MatchNode(StatNode): ), ) break - if c.is_irrefutable(): - found_irrefutable_case = c - c.validate_irrefutable() + if case.is_irrefutable(): + found_irrefutable_case = case + case.validate_irrefutable() def refactor_cases(self): # An early transform - changes cases that can be represented as @@ -156,6 +159,8 @@ class MatchCaseNode(Node): child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"] def is_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return True # value doesn't really matter return self.pattern.is_irrefutable() and not self.guard def is_simple_value_comparison(self): @@ -164,9 +169,13 @@ class MatchCaseNode(Node): return self.pattern.is_simple_value_comparison() def validate_targets(self): + if isinstance(self.pattern, ErrorNode): + return self.pattern.get_targets() def validate_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return self.pattern.validate_irrefutable() def is_sequence_or_mapping(self): @@ -263,7 +272,7 @@ class SubstitutedMatchCaseNode(MatchCaseBaseNode): class PatternNode(Node): """ - DW decided that PatternNode shouldn't be an expression because + PatternNode is not an expression because it does several things (evalutating a boolean expression, assignment of targets), and they need to be done at different times. @@ -304,23 +313,22 @@ class PatternNode(Node): child_attrs = ["as_targets"] def __init__(self, pos, **kwds): + if "as_targets" not in kwds: + kwds["as_targets"] = [] super(PatternNode, self).__init__(pos, **kwds) - if not hasattr(self, "as_targets"): - self.as_targets = [] def is_irrefutable(self): return False def get_targets(self): targets = self.get_main_pattern_targets() - for t in self.as_targets: - self.add_target_to_targets(targets, t.name) + for target in self.as_targets: + self.add_target_to_targets(targets, target.name) return targets def update_targets_with_targets(self, targets, other_targets): - intersection = targets.intersection(other_targets) - for i in intersection: - error(self.pos, "multiple assignments to name '%s' in pattern" % i) + for name in targets.intersection(other_targets): + error(self.pos, "multiple assignments to name '%s' in pattern" % name) targets.update(other_targets) def add_target_to_targets(self, targets, target): @@ -353,7 +361,7 @@ class PatternNode(Node): def validate_irrefutable(self): for attr in self.child_attrs: child = getattr(self, attr) - if isinstance(child, PatternNode): + if child is not None and isinstance(child, PatternNode): child.validate_irrefutable() def analyse_pattern_expressions(self, env, sequence_mapping_temp): @@ -499,9 +507,9 @@ class OrPatternNode(PatternNode): child_attrs = PatternNode.child_attrs + ["alternatives"] def get_first_irrefutable(self): - for a in self.alternatives: - if a.is_irrefutable(): - return a + for alternative in self.alternatives: + if alternative.is_irrefutable(): + return alternative return None def is_irrefutable(self): @@ -512,17 +520,17 @@ class OrPatternNode(PatternNode): def get_main_pattern_targets(self): child_targets = None - for ch in self.alternatives: - ch_targets = ch.get_targets() - if child_targets is not None and child_targets != ch_targets: + for alternative in self.alternatives: + alternative_targets = alternative.get_targets() + if child_targets is not None and child_targets != alternative_targets: error(self.pos, "alternative patterns bind different names") - child_targets = ch_targets + child_targets = alternative_targets return child_targets def validate_irrefutable(self): super(OrPatternNode, self).validate_irrefutable() found_irrefutable_case = None - for a in self.alternatives: + for alternative in self.alternatives: if found_irrefutable_case: error( found_irrefutable_case.pos, @@ -532,9 +540,9 @@ class OrPatternNode(PatternNode): ), ) break - if a.is_irrefutable(): - found_irrefutable_case = a - a.validate_irrefutable() + if alternative.is_irrefutable(): + found_irrefutable_case = alternative + alternative.validate_irrefutable() def is_simple_value_comparison(self): return all( @@ -622,10 +630,10 @@ class MatchSequencePatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() star_count = 0 - for p in self.patterns: - if isinstance(p, MatchAndAssignPatternNode) and p.is_star: + for pattern in self.patterns: + if isinstance(pattern, MatchAndAssignPatternNode) and pattern.is_star: star_count += 1 - self.update_targets_with_targets(targets, p.get_targets()) + self.update_targets_with_targets(targets, pattern.get_targets()) if star_count > 1: error(self.pos, "multiple starred names in sequence pattern") return targets @@ -926,8 +934,8 @@ class MatchMappingPatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() - for p in self.value_patterns: - self.update_targets_with_targets(targets, p.get_targets()) + for pattern in self.value_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) if self.double_star_capture_target: self.add_target_to_targets(targets, self.double_star_capture_target.name) return targets @@ -956,8 +964,8 @@ class ClassPatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() - for p in self.positional_patterns + self.keyword_pattern_patterns: - self.update_targets_with_targets(targets, p.get_targets()) + for pattern in self.positional_patterns + self.keyword_pattern_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) return targets |