summaryrefslogtreecommitdiff
path: root/Cython/Compiler/MatchCaseNodes.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/Compiler/MatchCaseNodes.py')
-rw-r--r--Cython/Compiler/MatchCaseNodes.py76
1 files changed, 42 insertions, 34 deletions
diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py
index 10895d51d..9196b9b1f 100644
--- a/Cython/Compiler/MatchCaseNodes.py
+++ b/Cython/Compiler/MatchCaseNodes.py
@@ -1,9 +1,8 @@
# Nodes for structural pattern matching.
#
-# In a separate file because they're unlikely to be useful
-# for much else
+# In a separate file because they're unlikely to be useful for much else.
-from .Nodes import Node, StatNode
+from .Nodes import Node, StatNode, ErrorNode
from .Errors import error, local_errors, report_error
from . import Nodes, ExprNodes, PyrexTypes, Builtin
from .Code import UtilityCode
@@ -26,7 +25,11 @@ class MatchNode(StatNode):
def validate_irrefutable(self):
found_irrefutable_case = None
- for c in self.cases:
+ for case in self.cases:
+ if isinstance(case, ErrorNode):
+ # This validation happens before error nodes have been
+ # transformed into actual errors, so we need to ignore them
+ continue
if found_irrefutable_case:
error(
found_irrefutable_case.pos,
@@ -36,9 +39,9 @@ class MatchNode(StatNode):
),
)
break
- if c.is_irrefutable():
- found_irrefutable_case = c
- c.validate_irrefutable()
+ if case.is_irrefutable():
+ found_irrefutable_case = case
+ case.validate_irrefutable()
def refactor_cases(self):
# An early transform - changes cases that can be represented as
@@ -156,6 +159,8 @@ class MatchCaseNode(Node):
child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"]
def is_irrefutable(self):
+ if isinstance(self.pattern, ErrorNode):
+ return True # value doesn't really matter
return self.pattern.is_irrefutable() and not self.guard
def is_simple_value_comparison(self):
@@ -164,9 +169,13 @@ class MatchCaseNode(Node):
return self.pattern.is_simple_value_comparison()
def validate_targets(self):
+ if isinstance(self.pattern, ErrorNode):
+ return
self.pattern.get_targets()
def validate_irrefutable(self):
+ if isinstance(self.pattern, ErrorNode):
+ return
self.pattern.validate_irrefutable()
def is_sequence_or_mapping(self):
@@ -263,7 +272,7 @@ class SubstitutedMatchCaseNode(MatchCaseBaseNode):
class PatternNode(Node):
"""
- DW decided that PatternNode shouldn't be an expression because
+ PatternNode is not an expression because
it does several things (evalutating a boolean expression,
assignment of targets), and they need to be done at different
times.
@@ -304,23 +313,22 @@ class PatternNode(Node):
child_attrs = ["as_targets"]
def __init__(self, pos, **kwds):
+ if "as_targets" not in kwds:
+ kwds["as_targets"] = []
super(PatternNode, self).__init__(pos, **kwds)
- if not hasattr(self, "as_targets"):
- self.as_targets = []
def is_irrefutable(self):
return False
def get_targets(self):
targets = self.get_main_pattern_targets()
- for t in self.as_targets:
- self.add_target_to_targets(targets, t.name)
+ for target in self.as_targets:
+ self.add_target_to_targets(targets, target.name)
return targets
def update_targets_with_targets(self, targets, other_targets):
- intersection = targets.intersection(other_targets)
- for i in intersection:
- error(self.pos, "multiple assignments to name '%s' in pattern" % i)
+ for name in targets.intersection(other_targets):
+ error(self.pos, "multiple assignments to name '%s' in pattern" % name)
targets.update(other_targets)
def add_target_to_targets(self, targets, target):
@@ -353,7 +361,7 @@ class PatternNode(Node):
def validate_irrefutable(self):
for attr in self.child_attrs:
child = getattr(self, attr)
- if isinstance(child, PatternNode):
+ if child is not None and isinstance(child, PatternNode):
child.validate_irrefutable()
def analyse_pattern_expressions(self, env, sequence_mapping_temp):
@@ -499,9 +507,9 @@ class OrPatternNode(PatternNode):
child_attrs = PatternNode.child_attrs + ["alternatives"]
def get_first_irrefutable(self):
- for a in self.alternatives:
- if a.is_irrefutable():
- return a
+ for alternative in self.alternatives:
+ if alternative.is_irrefutable():
+ return alternative
return None
def is_irrefutable(self):
@@ -512,17 +520,17 @@ class OrPatternNode(PatternNode):
def get_main_pattern_targets(self):
child_targets = None
- for ch in self.alternatives:
- ch_targets = ch.get_targets()
- if child_targets is not None and child_targets != ch_targets:
+ for alternative in self.alternatives:
+ alternative_targets = alternative.get_targets()
+ if child_targets is not None and child_targets != alternative_targets:
error(self.pos, "alternative patterns bind different names")
- child_targets = ch_targets
+ child_targets = alternative_targets
return child_targets
def validate_irrefutable(self):
super(OrPatternNode, self).validate_irrefutable()
found_irrefutable_case = None
- for a in self.alternatives:
+ for alternative in self.alternatives:
if found_irrefutable_case:
error(
found_irrefutable_case.pos,
@@ -532,9 +540,9 @@ class OrPatternNode(PatternNode):
),
)
break
- if a.is_irrefutable():
- found_irrefutable_case = a
- a.validate_irrefutable()
+ if alternative.is_irrefutable():
+ found_irrefutable_case = alternative
+ alternative.validate_irrefutable()
def is_simple_value_comparison(self):
return all(
@@ -622,10 +630,10 @@ class MatchSequencePatternNode(PatternNode):
def get_main_pattern_targets(self):
targets = set()
star_count = 0
- for p in self.patterns:
- if isinstance(p, MatchAndAssignPatternNode) and p.is_star:
+ for pattern in self.patterns:
+ if isinstance(pattern, MatchAndAssignPatternNode) and pattern.is_star:
star_count += 1
- self.update_targets_with_targets(targets, p.get_targets())
+ self.update_targets_with_targets(targets, pattern.get_targets())
if star_count > 1:
error(self.pos, "multiple starred names in sequence pattern")
return targets
@@ -926,8 +934,8 @@ class MatchMappingPatternNode(PatternNode):
def get_main_pattern_targets(self):
targets = set()
- for p in self.value_patterns:
- self.update_targets_with_targets(targets, p.get_targets())
+ for pattern in self.value_patterns:
+ self.update_targets_with_targets(targets, pattern.get_targets())
if self.double_star_capture_target:
self.add_target_to_targets(targets, self.double_star_capture_target.name)
return targets
@@ -956,8 +964,8 @@ class ClassPatternNode(PatternNode):
def get_main_pattern_targets(self):
targets = set()
- for p in self.positional_patterns + self.keyword_pattern_patterns:
- self.update_targets_with_targets(targets, p.get_targets())
+ for pattern in self.positional_patterns + self.keyword_pattern_patterns:
+ self.update_targets_with_targets(targets, pattern.get_targets())
return targets