diff options
-rw-r--r-- | Cython/Compiler/MatchCaseNodes.py | 906 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.py | 9 | ||||
-rw-r--r-- | Cython/Compiler/Visitor.py | 3 | ||||
-rw-r--r-- | Cython/Utility/MatchCase.c | 390 | ||||
-rw-r--r-- | Cython/Utility/MatchCase_Cy.pyx | 12 | ||||
-rw-r--r-- | tests/run/extra_patma.pyx | 86 | ||||
-rw-r--r-- | tests/run/extra_patma_py.py | 20 | ||||
-rw-r--r-- | tests/run/test_patma.py | 62 |
8 files changed, 1328 insertions, 160 deletions
diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py index 461e6ecd9..9196b9b1f 100644 --- a/Cython/Compiler/MatchCaseNodes.py +++ b/Cython/Compiler/MatchCaseNodes.py @@ -3,20 +3,25 @@ # In a separate file because they're unlikely to be useful for much else. from .Nodes import Node, StatNode, ErrorNode -from . import Nodes -from .Errors import error -from . import ExprNodes +from .Errors import error, local_errors, report_error +from . import Nodes, ExprNodes, PyrexTypes, Builtin +from .Code import UtilityCode +from .Options import copy_inherited_directives +from contextlib import contextmanager class MatchNode(StatNode): """ subject ExprNode The expression to be matched cases [MatchCaseBaseNode] list of cases + + sequence_mapping_temp None or AssignableTempNode an int temp to store result of sequence/mapping tests """ child_attrs = ["subject", "cases"] subject_clonenode = None # set to a value if we require a temp + sequence_mapping_temp = None def validate_irrefutable(self): found_irrefutable_case = None @@ -51,16 +56,14 @@ class MatchNode(StatNode): for n, c in enumerate(self.cases + [None]): # The None is dummy at the end if c is not None and c.is_simple_value_comparison(): body = SubstitutedIfStatListNode( - c.body.pos, - stats = c.body.stats, - match_node = self + c.body.pos, stats=c.body.stats, match_node=self ) if_clause = Nodes.IfClauseNode( c.pos, condition=c.pattern.get_simple_comparison_node(subject), body=body, ) - assignments = c.pattern.generate_target_assignments(subject) + assignments = c.pattern.generate_target_assignments(subject, None) if assignments: if_clause.body.stats.insert(0, assignments) if not current_if_statement: @@ -72,7 +75,7 @@ class MatchNode(StatNode): elif current_if_statement: # this cannot be simplified, but previous case(s) were self.cases[n - 1] = SubstitutedMatchCaseNode( - current_if_statement.pos, body = current_if_statement + current_if_statement.pos, body=current_if_statement ) current_if_statement = None # eliminate optimized cases @@ -84,20 +87,46 @@ class MatchNode(StatNode): c.analyse_case_declarations(self.subject_clonenode, env) def analyse_expressions(self, env): + sequence_mapping_count = 0 + for c in self.cases: + if c.is_sequence_or_mapping(): + sequence_mapping_count += 1 + if sequence_mapping_count >= 2: + self.sequence_mapping_temp = AssignableTempNode( + self.pos, PyrexTypes.c_uint_type + ) + self.sequence_mapping_temp.is_addressable = lambda: True + self.subject = self.subject.analyse_expressions(env) assert isinstance(self.subject, ExprNodes.ProxyNode) if not self.subject.arg.is_literal: self.subject.arg = self.subject.arg.coerce_to_temp(env) - subject = self.subject_clonenode - self.cases = [c.analyse_case_expressions(subject, env) for c in self.cases] + subject = self.subject_clonenode.analyse_expressions(env) + self.cases = [ + c.analyse_case_expressions(subject, env, self.sequence_mapping_temp) + for c in self.cases + ] + self.cases = [c for c in self.cases if c is not None] return self def generate_execution_code(self, code): + if self.sequence_mapping_temp: + self.sequence_mapping_temp.allocate(code) + code.putln( + "%s = 0; /* sequence/mapping test temp */" + % self.sequence_mapping_temp.result() + ) + # For things that are a sequence at compile-time it's difficult + # to avoid generating the sequence mapping temp. Therefore, silence + # an "unused error" + code.putln("(void)%s;" % self.sequence_mapping_temp.result()) end_label = self.end_label = code.new_label() if self.subject_clonenode: self.subject.generate_evaluation_code(code) for c in self.cases: c.generate_execution_code(code, end_label) + if self.sequence_mapping_temp: + self.sequence_mapping_temp.release(code) if code.label_used(end_label): code.put_label(end_label) if self.subject_clonenode: @@ -121,12 +150,13 @@ class MatchCaseNode(Node): guard ExprNode or None generated: - original_pattern PatternNode (not coerced to temp) target_assignments [ SingleAssignmentNodes ] + comp_node ExprNode that evaluates to bool """ target_assignments = None - child_attrs = ["pattern", "target_assignments", "guard", "body"] + comp_node = None + child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"] def is_irrefutable(self): if isinstance(self.pattern, ErrorNode): @@ -148,21 +178,38 @@ class MatchCaseNode(Node): return self.pattern.validate_irrefutable() + def is_sequence_or_mapping(self): + return isinstance( + self.pattern, (MatchSequencePatternNode, MatchMappingPatternNode) + ) + def analyse_case_declarations(self, subject_node, env): self.pattern.analyse_declarations(env) - self.target_assignments = self.pattern.generate_target_assignments(subject_node) + self.target_assignments = self.pattern.generate_target_assignments( + subject_node, env + ) if self.target_assignments: self.target_assignments.analyse_declarations(env) if self.guard: self.guard.analyse_declarations(env) self.body.analyse_declarations(env) - def analyse_case_expressions(self, subject_node, env): - self.pattern = self.pattern.analyse_pattern_expressions(subject_node, env) - self.original_pattern = self.pattern - self.pattern.comp_node = self.pattern.comp_node.coerce_to_boolean( - env - ).coerce_to_simple(env) + def analyse_case_expressions(self, subject_node, env, sequence_mapping_temp): + with local_errors(True) as errors: + self.pattern = self.pattern.analyse_pattern_expressions(env, sequence_mapping_temp) + self.comp_node = self.pattern.get_comparison_node(subject_node, sequence_mapping_temp) + self.comp_node = self.comp_node.analyse_types(env) + + if self.comp_node and self.comp_node.is_literal: + self.comp_node.calculate_constant_result() + if not self.comp_node.constant_result: + # we know this pattern can't succeed. Ignore any errors and return None + return None + for error in errors: + report_error(error) + + self.comp_node = self.comp_node.coerce_to_boolean(env).coerce_to_simple(env) + if self.target_assignments: self.target_assignments = self.target_assignments.analyse_expressions(env) if self.guard: @@ -171,36 +218,51 @@ class MatchCaseNode(Node): return self def generate_execution_code(self, code, end_label): - self.pattern.generate_comparison_evaluation_code(code) - code.putln("if (%s) { /* pattern */" % self.pattern.comparison_result()) - self.pattern.generate_comparison_disposal_code(code) - self.pattern.free_comparison_temps(code) + self.pattern.allocate_subject_temps(code) + self.comp_node.generate_evaluation_code(code) + + end_of_case_label = code.new_label() + + code.putln("if (!%s) { /* !pattern */" % self.comp_node.result()) + self.pattern.dispose_of_subject_temps(code) # failed, don't need the subjects + code.put_goto(end_of_case_label) + + code.putln("} else { /* pattern */") + self.comp_node.generate_disposal_code(code) + self.comp_node.free_temps(code) if self.target_assignments: self.target_assignments.generate_execution_code(code) + self.pattern.dispose_of_subject_temps(code) + self.pattern.release_subject_temps(code) # we're done with the subjects here if self.guard: self.guard.generate_evaluation_code(code) code.putln("if (%s) { /* guard */" % self.guard.result()) self.guard.generate_disposal_code(code) self.guard.free_temps(code) + # body_insertion_point = code.insertion_point() self.body.generate_execution_code(code) if not self.body.is_terminator: code.put_goto(end_label) if self.guard: code.putln("} /* guard */") code.putln("} /* pattern */") + code.put_label(end_of_case_label) class SubstitutedMatchCaseNode(MatchCaseBaseNode): # body - Node - The (probably) if statement that it's replaced with child_attrs = ["body"] + def is_sequence_or_mapping(self): + return False + def analyse_case_declarations(self, subject_node, env): self.analyse_declarations(env) def analyse_declarations(self, env): self.body.analyse_declarations(env) - def analyse_case_expressions(self, subject_node, env): + def analyse_case_expressions(self, subject_node, env, sequence_mapping_temp): self.body = self.body.analyse_expressions(env) return self @@ -219,29 +281,42 @@ class PatternNode(Node): Generated in analysis: comp_node ExprNode node to evaluate for the pattern + + ---------------------------------------- + How these nodes are processed: + 1. During "analyse_declarations" PatternNode.generate_target_assignments + is called on the main PatternNode of the case. This calls its + sub-patterns generate_target_assignments recursively. + This creates a StatListNode that is held by the + MatchCaseNode. + 2. In the "analyse_expressions" phases, the MatchCaseNode calls + PatternNode.analyse_pattern_expressions, which calls its + sub-pattern recursively. + 3. At the end of the "analyse_expressions" stage the MatchCaseNode + class PatternNode.get_comparison_node (which calls + PatternNode.get_comparison_node for its sub-patterns). This + returns an ExprNode which can be evaluated to determine if the + pattern has matched. + While generating the comparison we try quite hard not to + analyse it until right at the end, because otherwise it'll lead + to a lot of repeated work for deeply nested patterns. + 4. In the code generation stage, PatternNodes hardly generate any + code themselves. However, they do set up whatever temps they + need (mainly for sub-pattern subjects), with "allocate_subject_temps", + "release_subject_temps", and "dispose_of_subject_temps" (which + they also call recursively on their sub-patterns) """ + # useful for type tests is_match_value_pattern = False - comp_node = None - - # When pattern nodes are analysed it changes which children are important. - # Therefore have two different list of child_attrs and switch - initial_child_attrs = ["as_targets"] - post_analysis_child_attrs = ["comp_node"] + child_attrs = ["as_targets"] def __init__(self, pos, **kwds): if "as_targets" not in kwds: kwds["as_targets"] = [] super(PatternNode, self).__init__(pos, **kwds) - @property - def child_attrs(self): - if self.comp_node is None: - return self.initial_child_attrs - else: - return self.post_analysis_child_attrs - def is_irrefutable(self): return False @@ -279,44 +354,29 @@ class PatternNode(Node): """ raise NotImplementedError + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + error(self.pos, "This type of pattern is not currently supported %s" % self) + raise NotImplementedError + def validate_irrefutable(self): for attr in self.child_attrs: child = getattr(self, attr) if child is not None and isinstance(child, PatternNode): child.validate_irrefutable() - def analyse_pattern_expressions(self, subject_node, env): - error(self.pos, "This type of pattern is not currently supported") - return self - - def calculate_result_code(self): - return self.comp_node.result() + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + error(self.pos, "This type of pattern is not currently supported %s" % self) + raise NotImplementedError def generate_result_code(self, code): pass - def generate_comparison_evaluation_code(self, code): - self.comp_node.generate_evaluation_code(code) - - def comparison_result(self): - return self.comp_node.result() - - def generate_comparison_disposal_code(self, code): - self.comp_node.generate_disposal_code(code) - - def free_comparison_temps(self, code): - self.comp_node.free_temps(code) - - def generate_target_assignments(self, subject_node): + def generate_target_assignments(self, subject_node, env): # Generates the assignment code needed to initialize all the targets. # Returns either a StatListNode or None assignments = [] for target in self.as_targets: - if ( - self.is_match_value_pattern - and self.value - and self.value.is_simple() - ): + if self.is_match_value_pattern and self.value and self.value.is_simple(): # in this case we can optimize slightly and just take the value subject_node = self.value.clone_node() assignments.append( @@ -324,27 +384,39 @@ class PatternNode(Node): target.pos, lhs=target.clone_node(), rhs=subject_node ) ) - assignments.extend(self.generate_main_pattern_assignment_list(subject_node)) + assignments.extend( + self.generate_main_pattern_assignment_list(subject_node, env) + ) if assignments: return Nodes.StatListNode(self.pos, stats=assignments) else: return None - def generate_main_pattern_assignment_list(self, subject_node): + def generate_main_pattern_assignment_list(self, subject_node, env): # generates assignments for everything except the "as_target". # Override in subclasses. # Returns a list of Nodes return [] + def allocate_subject_temps(self, code): + pass # Implement in nodes that need it + + def release_subject_temps(self, code): + pass # Implement in nodes that need it + + def dispose_of_subject_temps(self, code): + pass # Implement in nodes that need it + class MatchValuePatternNode(PatternNode): """ value ExprNode is_is_check bool Picks "is" or equality check """ + is_match_value_pattern = True - initial_child_attrs = PatternNode.initial_child_attrs + ["value"] + child_attrs = PatternNode.child_attrs + ["value"] is_is_check = False @@ -354,25 +426,26 @@ class MatchValuePatternNode(PatternNode): def is_simple_value_comparison(self): return True - def get_comparison_node(self, subject_node): + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + # for this node the comparison and "simple" comparison are the same + return LazyCoerceToBool(self.pos, + arg=self.get_simple_comparison_node(subject_node) + ) + + def get_simple_comparison_node(self, subject_node): op = "is" if self.is_is_check else "==" return ExprNodes.PrimaryCmpNode( self.pos, operator=op, operand1=subject_node, operand2=self.value ) - def get_simple_comparison_node(self, subject_node): - # for this node the comparison and "simple" comparison are the same - return self.get_comparison_node(subject_node) - def analyse_declarations(self, env): super(MatchValuePatternNode, self).analyse_declarations(env) if self.value: self.value.analyse_declarations(env) - def analyse_pattern_expressions(self, subject_node, env): + def analyse_pattern_expressions(self, env, sequence_mapping_temp): if self.value: self.value = self.value.analyse_expressions(env) - self.comp_node = self.get_comparison_node(subject_node).analyse_expressions(env) return self @@ -385,10 +458,10 @@ class MatchAndAssignPatternNode(PatternNode): target = None is_star = False - initial_child_attrs = PatternNode.initial_child_attrs + ["target"] + child_attrs = PatternNode.child_attrs + ["target"] def is_irrefutable(self): - return not self.is_star + return True def irrefutable_message(self): if self.target: @@ -407,12 +480,12 @@ class MatchAndAssignPatternNode(PatternNode): def get_simple_comparison_node(self, subject_node): assert self.is_simple_value_comparison() - return ExprNodes.BoolNode(self.pos, value=True) + return self.get_comparison_node(subject_node, None) - def get_comparison_node(self, subject_node): - return self.get_simple_comparison_node(subject_node) + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + return ExprNodes.BoolNode(self.pos, value=True) - def generate_main_pattern_assignment_list(self, subject_node): + def generate_main_pattern_assignment_list(self, subject_node, env): if self.target: return [ Nodes.SingleAssignmentNode( @@ -422,16 +495,8 @@ class MatchAndAssignPatternNode(PatternNode): else: return [] - def analyse_pattern_expressions(self, subject_node, env): - if self.is_star: - return super(MatchAndAssignPatternNode, self).analyse_pattern_expressions( - subject_node, env - ) - else: - self.comp_node = self.get_comparison_node(subject_node).analyse_expressions( - env - ) - return self + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + return self # nothing to analyse class OrPatternNode(PatternNode): @@ -439,7 +504,7 @@ class OrPatternNode(PatternNode): alternatives list of PatternNodes """ - initial_child_attrs = PatternNode.initial_child_attrs + ["alternatives"] + child_attrs = PatternNode.child_attrs + ["alternatives"] def get_first_irrefutable(self): for alternative in self.alternatives: @@ -505,24 +570,26 @@ class OrPatternNode(PatternNode): ) return binop + def get_comparison_node(self, subject_node, sequence_mapping_temp): + error(self.pos, "'or' cases aren't fully implemented yet") + return ExprNodes.BoolNode(self.pos, value=False) + def analyse_declarations(self, env): super(OrPatternNode, self).analyse_declarations(env) for a in self.alternatives: a.analyse_declarations(env) - def analyse_pattern_expressions(self, subject_node, env): + def analyse_pattern_expressions(self, env, sequence_mapping_temp): self.alternatives = [ - a.analyse_pattern_expressions(subject_node, env) for a in self.alternatives + a.analyse_pattern_expressions(env, sequence_mapping_temp) + for a in self.alternatives ] - self.comp_node = self.get_comparison_node( - subject_node - ).analyse_temp_boolean_expression(env) return self - def generate_main_pattern_assignment_list(self, subject_node): + def generate_main_pattern_assignment_list(self, subject_node, env): assignments = [] for a in self.alternatives: - a_assignment = a.generate_target_assignments(subject_node) + a_assignment = a.generate_target_assignments(subject_node, env) if a_assignment: # Switch code paths depending on which node gets assigned error(self.pos, "Need to handle assignments in or nodes correctly") @@ -533,16 +600,320 @@ class OrPatternNode(PatternNode): class MatchSequencePatternNode(PatternNode): """ patterns list of PatternNodes + + generated: + subjects [TrackTypeTempNode] individual subsubjects can be assigned to these """ - initial_child_attrs = PatternNode.initial_child_attrs + ["patterns"] + subjects = None + needs_length_temp = False + + child_attrs = PatternNode.child_attrs + ["patterns"] + + Pyx_sequence_check_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, + [ + PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg( + "sequence_mapping_temp", + PyrexTypes.c_ptr_type(PyrexTypes.c_uint_type), + None, + ), + ], + exception_value="-1", + ) + + def __init__(self, pos, **kwds): + super(MatchSequencePatternNode, self).__init__(pos, **kwds) + self.length_temp = AssignableTempNode(self.pos, PyrexTypes.c_py_ssize_t_type) def get_main_pattern_targets(self): targets = set() + star_count = 0 for pattern in self.patterns: + if isinstance(pattern, MatchAndAssignPatternNode) and pattern.is_star: + star_count += 1 self.update_targets_with_targets(targets, pattern.get_targets()) + if star_count > 1: + error(self.pos, "multiple starred names in sequence pattern") return targets + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + from .UtilNodes import TempResultFromStatNode, ResultRefNode + + test = None + assert getattr(self, "subject_temps", None) is not None + + seq_test = self.make_sequence_check(subject_node, sequence_mapping_temp) + if isinstance(seq_test, ExprNodes.BoolNode) and not seq_test.value: + return seq_test # no point in proceeding further! + + has_star = False + all_tests = [seq_test] + pattern_tests = [] + for n, pattern in enumerate(self.patterns): + if isinstance(pattern, MatchAndAssignPatternNode) and pattern.is_star: + has_star = True + self.needs_length_temp = True + + if self.subject_temps[n] is None: + # The subject has been identified as unneeded, so don't evaluate it + continue + p_test = pattern.get_comparison_node(self.subject_temps[n]) + + result_ref = ResultRefNode(pos=self.pos, type=PyrexTypes.c_bint_type) + subject_assignment = Nodes.SingleAssignmentNode( + self.pos, + lhs=self.subject_temps[n], # the temp node + rhs=self.subjects[n], # the regular node + ) + test_assignment = Nodes.SingleAssignmentNode( + self.pos, lhs=result_ref, rhs=p_test + ) + stats = Nodes.StatListNode( + self.pos, stats=[subject_assignment, test_assignment] + ) + pattern_tests.append(TempResultFromStatNode(result_ref, stats)) + + min_length = len(self.patterns) + if has_star: + min_length -= 1 + # check whether we need a length call... + if not (self.patterns and len(self.patterns) == 1 and has_star): + length_call = self.make_length_call_node(subject_node) + + if length_call.is_literal and ( + (has_star and min_length < length_call.constant_result) + or (not has_star and min_length != length_call.constant_result) + ): + # definitely failed! + return ExprNodes.BoolNode(self.pos, value=False) + seq_len_test = ExprNodes.PrimaryCmpNode( + self.pos, + operator=">=" if has_star else "==", + operand1=length_call, + operand2=ExprNodes.IntNode(self.pos, value=str(min_length)), + ) + all_tests.append(seq_len_test) + else: + self.needs_length_temp = False + all_tests.extend(pattern_tests) + test = generate_binop_tree_from_list(self.pos, "and", all_tests) + return LazyCoerceToBool(test.pos, arg=test) + + def generate_subjects(self, subject_node, env): + assert self.subjects is None # not called twice + + star_idx = None + for n, pattern in enumerate(self.patterns): + if isinstance(pattern, MatchAndAssignPatternNode) and pattern.is_star: + star_idx = n + if star_idx is None: + idxs = list(range(len(self.patterns))) + else: + fwd_idxs = list(range(star_idx)) + backward_idxs = list(range(star_idx - len(self.patterns) + 1, 0)) + star_idx = ( + fwd_idxs[-1] + 1 if fwd_idxs else None, + backward_idxs[0] if backward_idxs else None, + ) + idxs = fwd_idxs + [star_idx] + backward_idxs + + subjects = [] + for pattern, idx in zip(self.patterns, idxs): + indexer = self.make_indexing_node(pattern, subject_node, idx, env) + subjects.append(ExprNodes.ProxyNode(indexer) if indexer else None) + self.subjects = subjects + self.subject_temps = [ + None if p.is_irrefutable() else TrackTypeTempNode(self.pos, s) + for s, p in zip(self.subjects, self.patterns) + ] + + def generate_main_pattern_assignment_list(self, subject_node, env): + assignments = [] + self.generate_subjects(subject_node, env) + for subject_temp, subject, pattern in zip( + self.subject_temps, self.subjects, self.patterns + ): + needs_result_ref = False + if subject_temp is not None: + subject = subject_temp + else: + if subject is None: + assert not pattern.get_targets() + continue + elif not subject.is_literal or subject.is_temp: + from .UtilNodes import ResultRefNode, LetNode + + subject = ResultRefNode(subject) + needs_result_ref = True + p_assignments = pattern.generate_target_assignments(subject, env) + if needs_result_ref: + p_assignments = LetNode(subject, p_assignments) + else: + p_assignments = p_assignments + if p_assignments: + assignments.append(p_assignments) + return assignments + + def make_sequence_check(self, subject_node, sequence_mapping_temp): + # Note: the sequence check code is very quick on Python 3.10+ + # but potentially quite slow on lower versions (although should + # be medium quick for common types). It'd be nice to cache the + # results of it where it's been called on the same object + # multiple times. + # DW has decided that that's too complicated to implement + # for now. + utility_code = UtilityCode.load_cached("IsSequence", "MatchCase.c") + if sequence_mapping_temp is not None: + sequence_mapping_temp = ExprNodes.AmpersandNode( + self.pos, operand=sequence_mapping_temp + ) + else: + sequence_mapping_temp = ExprNodes.NullNode(self.pos) + call = ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_IsSequence", + self.Pyx_sequence_check_type, + utility_code=utility_code, + args=[subject_node, sequence_mapping_temp], + ) + + def type_check(type): + # type-check need not be perfect, it's an optimization + if type in [Builtin.list_type, Builtin.tuple_type]: + return True + if type.is_memoryviewslice or type.is_ctuple: + return True + if type in [ + Builtin.str_type, + Builtin.bytes_type, + Builtin.unicode_type, + Builtin.bytearray_type, + Builtin.dict_type, + Builtin.set_type, + ]: + # non-exhaustive list at this stage, but returning "False" is + # an optimization so it's allowed to be non-exchaustive + return False + if type.is_numeric or type.is_struct or type.is_enum: + # again, not exhaustive + return False + return None + + return StaticTypeCheckNode( + self.pos, arg=subject_node, fallback=call, check=type_check + ) + + def make_length_call_node(self, subject_node): + len_entry = Builtin.builtin_scope.lookup("len") + if subject_node.type.is_memoryviewslice: + len_call = ExprNodes.IndexNode( + self.pos, + base=ExprNodes.AttributeNode( + self.pos, obj=subject_node, attribute="shape" + ), + index=ExprNodes.IntNode(self.pos, value="0"), + ) + elif subject_node.type.is_ctuple: + len_call = ExprNodes.IntNode( + self.pos, value=str(len(subject_node.type.components)) + ) + else: + len_call = ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode(self.pos, name="len", entry=len_entry), + args=[subject_node], + ) + if self.needs_length_temp: + return ExprNodes.AssignmentExpressionNode( + self.pos, lhs=self.length_temp, rhs=len_call + ) + else: + return len_call + + def make_indexing_node(self, pattern, subject_node, idx, env): + if pattern.is_irrefutable() and not pattern.get_targets(): + # Nothing to do - index isn't used + return None + + def get_index_from_int(i): + if i is None: + return None + else: + int_node = ExprNodes.IntNode(pattern.pos, value=str(i)) + if i >= 0: + return int_node + else: + self.needs_length_temp = True + return ExprNodes.binop_node( + pattern.pos, + operator="+", + operand1=self.length_temp, + operand2=int_node, + ) + + if isinstance(idx, tuple): + start = get_index_from_int(idx[0]) + stop = get_index_from_int(idx[1]) + indexer = SliceToListNode( + pattern.pos, + base=subject_node, + start=start, + stop=stop, + length_node=self.length_temp if self.needs_length_temp else None, + ) + else: + indexer = CompilerDirectivesExprNode( + arg=ExprNodes.IndexNode( + pattern.pos, base=subject_node, index=get_index_from_int(idx) + ), + directives=copy_inherited_directives( + env.directives, boundscheck=False, wraparound=False + ), + ) + return indexer + + def analyse_declarations(self, env): + for p in self.patterns: + p.analyse_declarations(env) + return super(MatchSequencePatternNode, self).analyse_declarations(env) + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + for n in range(len(self.subjects)): + if self.subjects[n]: + self.subjects[n] = self.subjects[n].analyse_types(env) + for n in range(len(self.patterns)): + self.patterns[n] = self.patterns[n].analyse_pattern_expressions(env, None) + return self + + def allocate_subject_temps(self, code): + if self.needs_length_temp: + self.length_temp.allocate(code) + for temp in self.subject_temps: + if temp is not None: + temp.allocate(code) + for pattern in self.patterns: + pattern.allocate_subject_temps(code) + + def release_subject_temps(self, code): + if self.needs_length_temp: + self.length_temp.release(code) + for temp in self.subject_temps: + if temp is not None: + temp.release(code) + for pattern in self.patterns: + pattern.release_subject_temps(code) + + def dispose_of_subject_temps(self, code): + if self.needs_length_temp: + code.put_xdecref_clear(self.length_temp.result(), self.length_temp.type) + for temp in self.subject_temps: + if temp is not None: + code.put_xdecref_clear(temp.result(), temp.type) + for pattern in self.patterns: + pattern.dispose_of_subject_temps(code) + class MatchMappingPatternNode(PatternNode): """ @@ -555,7 +926,7 @@ class MatchMappingPatternNode(PatternNode): value_patterns = [] double_star_capture_target = None - initial_child_attrs = PatternNode.initial_child_attrs + [ + child_attrs = PatternNode.child_attrs + [ "keys", "value_patterns", "double_star_capture_target", @@ -584,7 +955,7 @@ class ClassPatternNode(PatternNode): keyword_pattern_names = [] keyword_pattern_patterns = [] - initial_child_attrs = PatternNode.initial_child_attrs + [ + child_attrs = PatternNode.child_attrs + [ "class_", "positional_patterns", "keyword_pattern_names", @@ -605,8 +976,347 @@ class SubstitutedIfStatListNode(Nodes.StatListNode): match_node - the enclosing match statement """ + def generate_execution_code(self, code): super(SubstitutedIfStatListNode, self).generate_execution_code(code) if not self.is_terminator: code.put_goto(self.match_node.end_label) + +class StaticTypeCheckNode(ExprNodes.ExprNode): + """ + Useful for structural pattern matching, where we + can skip the "is_seqeunce/is_mapping" checks if + we know the type in advantage (or reduce it to a + None check). + + This should optimize itself out at the analyse_expressions + stage + + arg ExprNode + fallback ExprNode Function to be called if the static + typecheck isn't optimized out + check callable Returns True, False, or None (for "can't tell") + """ + + child_attrs = ["fallback"] # arg in not included since it's in "fallback" + + def analyse_types(self, env): + check = self.check(self.arg.type) + if check: + if self.arg.may_be_none(): + return ExprNodes.PrimaryCmpNode( + self.pos, + operand1=self.arg, + operand2=ExprNodes.NoneNode(self.pos), + operator="is_not", + ).analyse_expressions(env) + else: + return ExprNodes.BoolNode(pos=self.pos, value=True).analyse_expressions( + env + ) + elif check is None: + return self.fallback.analyse_expressions(env) + else: + return ExprNodes.BoolNode(pos=self.pos, value=False).analyse_expressions( + env + ) + + +class AssignableTempNode(ExprNodes.TempNode): + lhs_of_first_assignment = True # assume it can be assigned to once + _assigned_twice = False + + def infer_type(self, env): + return self.type + + def generate_assignment_code(self, rhs, code, overloaded_assignment=False): + assert ( + not self._assigned_twice + ) # if this happens it's not a disaster but it needs a refactor + self._assigned_twice = True + if self.type.is_pyobject: + rhs.make_owned_reference(code) + if not self.lhs_of_first_assignment: + code.put_decref(self.result(), self.ctype()) + code.putln( + "%s = %s;" + % ( + self.result(), + rhs.result() if overloaded_assignment else rhs.result_as(self.ctype()), + ) + ) + rhs.generate_post_assignment_code(code) + rhs.free_temps(code) + + def generate_post_assignment_code(self, code): + code.put_incref(self.result(), self.type) + + def clone_node(self): + return self # temps break if you make a copy! + + +class TrackTypeTempNode(AssignableTempNode): + # Like a temp node, but type is set from arg + + lhs_of_first_assignment = True # assume it can be assigned to once + _assigned_twice = False + + @property + def type(self): + return getattr(self.arg, "type", None) + + def __init__(self, pos, arg): + ExprNodes.ExprNode.__init__(self, pos) # skip a level + self.arg = arg + + def infer_type(self, env): + return self.arg.infer_type(env) + + +class SliceToListNode(ExprNodes.ExprNode): + """ + Used as a brief temporary node to optimize + case [..., *_, ...]. + Always reduces to something else after analyse_types + """ + + subexprs = ["base", "start", "stop", "length_node"] + + type = Builtin.list_type + + Pyx_iterable_to_list_type = PyrexTypes.CFuncType( + Builtin.list_type, + [ + PyrexTypes.CFuncTypeArg("iterable", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + ], + ) + + def generate_via_slicing(self, env): + # for any more complicated type that doesn't have a specialized path + # we can simply slice it and copy it to list + res = CompilerDirectivesExprNode( + arg=ExprNodes.SliceIndexNode( + self.pos, base=self.base, start=self.start, stop=self.stop + ), + directives=copy_inherited_directives( + env.directives, boundcheck=False, wraparound=False + ), + ) + res = ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode( + self.pos, + name="list", + entry=Builtin.builtin_scope.lookup("list"), + ), + args=[res], + ) + return res + + def get_stop(self): + if not self.stop: + if self.length_node: + return self.length_node + else: + return ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode( + self.pos, name="len", entry=Builtin.builtin_scope.lookup("len") + ), + args=[self.base], + ) + else: + return self.stop + + def generate_for_memoryview(self, env): + # Requires Cython code generation... + # A list comprehension with indexing turns out to be a good option + from .UtilityCode import CythonUtilityCode + + suffix = self.base.type.specialization_suffix() + util_code = CythonUtilityCode.load( + "MemoryviewSliceToList", + "MatchCase_Cy.pyx", + context={ + "decl_code": self.base.type.empty_declaration_code(pyrex=True), + "suffix": suffix, + }, + ) + func_type = PyrexTypes.CFuncType( + Builtin.list_type, + [ + PyrexTypes.CFuncTypeArg("x", self.base.type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + ], + ) + env.use_utility_code( + util_code + ) # attaching it to the call node doesn't seem enough + return ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_SliceMemoryview_%s" % suffix, + func_type, + utility_code=util_code, + args=[ + self.base, + self.start if self.start else ExprNodes.IntNode(self.pos, value="0"), + self.get_stop(), + ], + ) + + def generate_for_pyobject(self): + util_code_name = None + func_name = None + if self.base.type is Builtin.tuple_type: + util_code_name = "TupleSliceToList" + elif self.base.type is Builtin.list_type: + func_name = "PyList_GetSlice" + elif ( + self.base.type.is_pyobject + and not self.base.type is PyrexTypes.py_object_type + ): + # some specialized type that almost certainly isn't a list. Just go straight + # to the "other" version of it + util_code_name = "OtherSequenceSliceToList" + else: + util_code_name = "UnknownTypeSliceToList" + if not func_name: + func_name = "__Pyx_MatchCase_%s" % util_code_name + if util_code_name: + util_code = UtilityCode.load_cached( + util_code_name, + "MatchCase.c" + ) + else: + util_code = None + start = self.start if self.start else ExprNodes.IntNode(self.pos, value="0") + stop = self.get_stop() + return ExprNodes.PythonCapiCallNode( + self.pos, + func_name, + self.Pyx_iterable_to_list_type, + utility_code=util_code, + args=[self.base, start, stop], + ) + + def analyse_types(self, env): + self.base = self.base.analyse_types(env) + if self.base.type.is_memoryviewslice: + result = self.generate_for_memoryview(env) + elif self.base.type.is_pyobject: + result = self.generate_for_pyobject() + else: + # Some other type (probably a ctuple). + # Just slice it, copy it to a list and hope it works + result = self.generate_via_slicing(env) + return result.analyse_types(env) + + +class CompilerDirectivesExprNode(ExprNodes.ProxyNode): + # Like compiler directives node, but for an expression + # directives {string:value} A dictionary holding the right value for + # *all* possible directives. + # arg ExprNode + + def __init__(self, arg, directives): + super(CompilerDirectivesExprNode, self).__init__(arg) + self.directives = directives + + @contextmanager + def _apply_directives(self, obj): + old = obj.directives + obj.directives = self.directives + yield + obj.directives = old + + @property + def is_temp(self): + return self.arg.is_temp + + def infer_type(self, env): + with self._apply_directives(env): + return super(CompilerDirectivesExprNode, self).infer_type(env) + + def analyse_declarations(self, env): + with self._apply_directives(env): + self.arg.analyse_declarations(env) + + def analyse_types(self, env): + with self._apply_directives(env): + return super(CompilerDirectivesExprNode, self).analyse_types(env) + + def generate_result_code(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).generate_result_code(code) + + def generate_evaluation_code(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).generate_evaluation_code(code) + + def generate_disposal_code(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).generate_disposal_code(code) + + def free_temps(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).free_temps(code) + + def annotate(self, code): + with self._apply_directives(code.globalstate): + self.arg.annotate(code) + + +class LazyCoerceToPyObject(ExprNodes.ExprNode): + """ + Just calls "self.arg.coerce_to_pyobject" when it's analysed, + so doesn't need 'env' when it's created + arg - ExprNode + """ + subexprs = ["arg"] + type = PyrexTypes.py_object_type + + def analyse_types(self, env): + return self.arg.analyse_types(env).coerce_to_pyobject(env) + + +class LazyCoerceToBool(ExprNodes.ExprNode): + """ + Just calls "self.arg.coerce_to_bool" when it's analysed, + so doesn't need 'env' when it's created + arg - ExprNode + """ + subexprs = ["arg"] + type = PyrexTypes.c_bint_type + + def analyse_types(self, env): + return self.arg.analyse_boolean_expression(env) + +def generate_binop_tree_from_list(pos, operator, list_of_tests): + """ + Given a list of operands generates a roughly balanced tree: + (test1 op test2) op (test3 op test4) + This is better than (((test1 op test2) op test3) op test4) + because it generates a shallower tree of nodes so is + less likely to overflow the compiler + """ + len_tests = len(list_of_tests) + if len_tests == 1: + return list_of_tests[0] + else: + split_idx = len_tests // 2 + operand1 = generate_binop_tree_from_list( + pos, operator, list_of_tests[:split_idx] + ) + operand2 = generate_binop_tree_from_list( + pos, operator, list_of_tests[split_idx:] + ) + return ExprNodes.binop_node( + pos, + operator=operator, + operand1=operand1, + operand2=operand2 + )
\ No newline at end of file diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 981e4b174..301d93335 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -945,6 +945,9 @@ class InterpretCompilerDirectives(CythonTransform): self.directives = old_directives return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + # The following four functions track imports and cimports that # begin with "cython" def is_cython_directive(self, name): @@ -2025,6 +2028,9 @@ class ForwardDeclareTypes(CythonTransform): env.directives = old return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + def visit_ModuleNode(self, node): self.module_scope = node.scope self.module_scope.directives = node.directives @@ -2898,6 +2904,9 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations): self.directives = old_directives return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + def visit_DefNode(self, node): modifiers = [] if 'inline' in self.directives: diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 92e2eb9c0..e9545d76e 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -318,6 +318,9 @@ class CythonTransform(VisitorTransform): self.current_directives = old return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + def visit_Node(self, node): self._process_children(node) return node diff --git a/Cython/Utility/MatchCase.c b/Cython/Utility/MatchCase.c new file mode 100644 index 000000000..cea40d8b2 --- /dev/null +++ b/Cython/Utility/MatchCase.c @@ -0,0 +1,390 @@ +///////////////////////////// ABCCheck ////////////////////////////// + +#if PY_VERSION_HEX < 0x030A0000 +static CYTHON_INLINE int __Pyx_MatchCase_IsExactSequence(PyObject *o) { + // is one of the small list of builtin types known to be a sequence + if (PyList_CheckExact(o) || PyTuple_CheckExact(o) || + PyType_CheckExact(o, PyRange_Type) || PyType_CheckExact(o, PyMemoryView_Type)) { + // Use exact type match for these checks. I in the event of inheritence we need to make sure + // that it isn't a mapping too + return 1; + } + return 0; +} + +static CYTHON_INLINE int __Pyx_MatchCase_IsExactMapping(PyObject *o) { + // Py_Dict is the only regularly used mapping type + // "types.MappingProxyType" also exists but is correctly covered by + // the isinstance(o, Mapping) check + return PyDict_CheckExact(o); +} + +static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) { + if (PyType_GetFlags(Py_TYPE(o)) & (Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS)) || + PyByteArray_Check(o)) { + return 1; // these types are deliberately excluded from the sequence test + // even though they look like sequences for most other purposes. + // Leave them as inexact checks since they do pass + // "isinstance(o, collections.abc.Sequence)" so it's very hard to + // reason about their subclasses + } + if (o == Py_None || PyLong_CheckExact(o) || PyFloat_CheckExact(o)) { + return 1; + } + #if PY_MAJOR_VERSION < 3 + if (PyInt_CheckExact(o)) { + return 1; + } + #endif + + return 0; +} + +// sequence_mapping_temp: For Python 3.10 testing sequences and mappings are +// really quick and this is ignored. For lower versions of Python they're +// slow, especially in the "fail" case. +// Therefore, we store an int temp to avoid duplicating tests. +// The bits of it in order are: +// 0. definitely a sequence +// 1. definitely a mapping +// - note that both of the above and be true when +// the type is registered with both abc types (not via inheritance) +// and in this case we return true for both IsSequence or IsMapping +// (which seems like the best handling of an ambiguous situation) +// 2. definitely not a sequence +// 3. definitely not a mapping + +#if PY_VERSION_HEX < 0x030A0000 +#define __PYX_DEFINITELY_SEQUENCE_FLAG 1U +#define __PYX_DEFINITELY_MAPPING_FLAG (1U<<1) +#define __PYX_DEFINITELY_NOT_SEQUENCE_FLAG (1U<<2) +#define __PYX_DEFINITELY_NOT_MAPPING_FLAG (1U<<3) +#define __PYX_SEQUENCE_MAPPING_ERROR (1U<<4) // only used by the ABCCheck function +#endif + +static int __Pyx_MatchCase_InitAndIsInstanceAbc(PyObject *o, PyObject *abc_module, + PyObject **abc_type, PyObject *name) { + assert(!abc_type); + abc_type = PyObject_GetAttr(abc_module, name); + if (!abc_type) { + return -1; + } + return PyObject_IsInstance(o, abc_type); +} + +// the result is defined using the specification for sequence_mapping_temp +// (detailed in "is_sequence") +static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, int definitely_not_sequence, int definitely_not_mapping) { + // in Python 3.10 objects can have their sequence bit set or their mapping bit set + // but not both. Practically this translates to "which type is registered first". + // In Python < 3.10 we can only determine this if they're direct bases (by looking + // at the MRO order). If they're registered manually then we can't tell + + PyObject *abc_module=NULL, *sequence_type=NULL, *mapping_type=NULL; + PyObject *mro; + int sequence_result=0, mapping_result=0; + unsigned int result = 0; + + abc_module = PyImport_ImportModule( +#if PY_VERSION_HEX > 0x03030000 + "collections.abc" +#else + "collections" +#endif + ); + if (!abc_module) { + return __PYX_SEQUENCE_MAPPING_ERROR; + } + if (sequence_first) { + if (definitely_not_sequence) { + result = __PYX_DEFINITELY_SEQUENCE_FLAG; + goto end; + } + sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence")); + if (sequence_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; + goto end; + } else if (sequence_result == 0) { + result |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG; + goto end; + } + // else wait to see what mapping is + } + if (!definitely_not_mapping) { + mapping_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &mapping_type, PYIDENT("Mapping")); + if (mapping_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; + goto end; + } else if (mapping_result == 0) { + result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG; + if (sequence_first) { + assert(sequence_result); + result |= __PYX_DEFINITELY_SEQUENCE_FLAG; + } + goto end; + } else /* mapping_result == 1 */ { + if (sequence_first && !sequence_result) { + result |= __PYX_DEFINITELY_MAPPING_FLAG; + goto end; + } + } + } + if (!sequence_first) { + // here we know mapping_result is true because we'd have returned otherwise + assert(mapping_result); + if (!definitely_not_sequence) { + sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence")); + } + if (sequence_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; + goto end; + } else if (sequence_result == 0) { + result |= (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_MAPPING_FLAG); + goto end; + } /* else sequence_result == 1, continue to check both */ + } + + // It's an instance of both types. Look up the MRO order. + // In event of failure treat it as "could be either" + result = __PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_MAPPING_FLAG; + mro = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__mro__"); + Py_ssize_t i; + if (!mro) { + PyErr_Clear(); + goto end; + } + if (!PyTuple_Check(mro)) { + Py_DECREF(mro); + goto end; + } + for (i=1; i < PyTuple_GET_SIZE(mro); ++i) { + int is_subclass_sequence, is_subclass_mapping; + PyObject *mro_item = PyTuple_GET_ITEM(mro, i); + is_subclass_sequence = PyObject_IsSubclass(mro_item, sequence_type); + if (is_subclass_sequence < 0) goto loop_error; + is_subclass_mapping = PyObject_IsSubclass(mro_item, mapping_type); + if (is_subclass_mapping < 0) goto loop_error; + if (is_subclass_sequence && !is_subclass_mapping) { + result = (__PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + break; + } else if (is_subclass_mapping && !is_subclass_sequence) { + result = (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_MAPPING_FLAG); + break; + } + } + // If we get to the end of the loop without breaking then neither type is in + // the MRO, so they've both been registered manually. We don't know which was + // registered first so accept the object as either as a compromise + if (0) { + loop_error: + PyErr_Clear(); + } + Py_DECREF(mro); + + end: + Py_XDECREF(abc_module); + Py_XDECREF(sequence_type); + Py_XDECREF(mapping_type); + return result; +} +#endif + +///////////////////////////// IsSequence.proto ////////////////////// + +static int __Pyx_MatchCase_IsSequence(PyObject *o, unsigned int *sequence_mapping_temp); /* proto */ + +//////////////////////////// IsSequence ///////////////////////// +//@requires: ABCCheck + +static int __Pyx_MatchCase_IsSequence(PyObject *o, unsigned int *sequence_mapping_temp) { +#if PY_VERSION_HEX >= 0x030A0000 + return __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_SEQUENCE); +#else + // Py_TPFLAGS_SEQUENCE doesn't exit. + PyObject *o_module_name; + unsigned int abc_result, dummy=0; + + if (sequence_mapping_temp) { + // maybe we already know the answer + if (*sequence_mapping_temp & __PYX_DEFINITELY_SEQUENCE_FLAG) { + return 1; + } + if (*sequence_mapping_temp & __PYX_DEFINITELY_NOT_SEQUENCE_FLAG) { + return 0; + } + } else { + // Probably quicker to just assign it and not check from here + sequence_mapping_temp = &dummy; + } + + // Start by check a known list of types + if (__Pyx_MatchCase_IsExactSequence(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + return 1; + } + if (__Pyx_MatchCase_IsExactMapping(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_MAPPING_FLAG | __PYX_DEFINITELY_NOT_SEQUENCE_FLAG); + return 0; + } + if (__Pyx_MatchCase_IsExactNeitherSequenceNorMapping(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + return 0; + } + + abc_result = __Pyx_MatchCase_ABCCheck( + o, 1, + *sequence_mapping_temp & __PYX_DEFINITELY_NOT_SEQUENCE_FLAG, + *sequence_mapping_temp & __PYX_DEFINITELY_NOT_MAPPING_FLAG + ); + if (abc_result & __PYX_SEQUENCE_MAPPING_ERROR) { + return -1; + } + *sequence_mapping_temp = abc_result; + if (*sequence_mapping_temp & __PYX_DEFINITELY_SEQUENCE_FLAG) { + return 1; + } + + // array.array is a more complicated check (and unfortunately isn't covered by + // collections.abc.Sequence on Python <3.10). + // Do the test by checking the module name, and then importing/testing the class + // It also doesn't give perfect results for classes that inherit from both array.array + // and a mapping + o_module_name = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__module__"); + if (!o_module_name) { + return -1; + } +#if PY_MAJOR_VERSION >= 3 + if (PyUnicode_Check(o_module_name) && PyUnicode_CompareWithASCIIString(o_module_name, "array") == 0) +#else + if (PyBytes_Check(o_module_name) && PyBytes_AS_STRING(o_module_name)[0] == 'a' && + PyBytes_AS_STRING(o_module_name)[1] == 'r' && PyBytes_AS_STRING(o_module_name)[2] == 'r' && + PyBytes_AS_STRING(o_module_name)[3] == 'a' && PyBytes_AS_STRING(o_module_name)[4] == 'y' && + PyBytes_AS_STRING(o_module_name)[5] == '\0') +#endif + { + int is_array; + PyObject *array_module, *array_object; + Py_DECREF(o_module_name); + array_module = PyImport_ImportModule("array"); + if (!array_module) { + PyErr_Clear(); + return 0; // treat these tests as "soft" and don't cause an exception + } + array_object = PyObject_GetAttrString(array_module, "array"); + Py_DECREF(array_module); + if (!array_object) { + PyErr_Clear(); + return 0; + } + is_array = PyObject_IsInstance(o, array_object); + Py_DECREF(array_object); + if (is_array) { + *sequence_mapping_temp |= __PYX_DEFINITELY_SEQUENCE_FLAG; + return 1; + } + PyErr_Clear(); + } else { + Py_DECREF(o_module_name); + } + *sequence_mapping_temp |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG; + return 0; +#endif +} + +////////////////////// OtherSequenceSliceToList.proto ////////////////////// + +static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end); /* proto */ + +////////////////////// OtherSequenceSliceToList ////////////////////////// + +// This is substantially based off ceval unpack_iterable. +// It's also pretty similar to itertools.islice +// Indices must be postive - there's no wraparound or boundschecking + +static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end) { + int total = end-start; + int i; + PyObject *list; + ssizeargfunc slot; + PyTypeObject *type = Py_TYPE(x); + + list = PyList_New(total); + if (!list) { + return NULL; + } + +#if CYTHON_USE_TYPE_SLOTS || PY_MAJOR_VERSION < 3 || CYTHON_COMPILING_IN_PYPY + slot = type->tp_as_sequence ? type->tp_as_sequence->sq_item : NULL; +#else + if ((PY_VERSION_HEX >= 0x030A0000) || __Pyx_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE)) { + // PyType_GetSlot only works on heap types in Python <3.10 + slot = (ssizeargfunc) PyType_GetSlot(type, Py_sq_item); + } +#endif + if (!slot) { + #if !defined(Py_LIMITED_API) && !defined(PySequence_ITEM) + // PyPy (and maybe others?) implements PySequence_ITEM as a function. In this case + // it's slightly more efficient than using PySequence_GetItem since it skips negative indices + slot = PySequence_ITEM; + #else + slot = PySequence_GetItem; + #endif + } + + for (i=start; i<end; ++i) { + PyObject *obj = slot(x, i); + if (!obj) { + Py_DECREF(list); + return NULL; + } + PyList_SET_ITEM(list, i-start, obj); + } + return list; +} + +////////////////////// TupleSliceToList.proto ////////////////////// + +static PyObject *__Pyx_MatchCase_TupleSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end); /* proto */ + +////////////////////// TupleSliceToList ////////////////////////// +//@requires: OtherSequenceSliceToList +//@requires: ObjectHandling.c::TupleAndListFromArray + +// Note that this should also work fine on lists (if needed) +// Indices must be postive - there's no wraparound or boundschecking + +static PyObject *__Pyx_MatchCase_TupleSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end) { +#if !CYTHON_COMPILING_IN_CPYTHON + return __Pyx_MatchCase_OtherSequenceSliceToList(x, start, end); +#else + PyObject **array; + + (void)__Pyx_MatchCase_OtherSequenceSliceToList; // clear unused warning + + array = PySequence_Fast_ITEMS(x); + return __Pyx_PyList_FromArray(array+start, end-start); +#endif +} + +////////////////////////// UnknownTypeSliceToList.proto ////////////////////// + +static PyObject *__Pyx_MatchCase_UnknownTypeSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end); /* proto */ + +////////////////////////// UnknownTypeSliceToList.proto ////////////////////// +//@requires: TupleSliceToList +//@requires: OtherSequenceSliceToList + +static PyObject *__Pyx_MatchCase_UnknownTypeSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end) { + if (PyList_CheckExact(x)) { + return PyList_GetSlice(x, start, end); + } +#if !CYTHON_COMPILING_IN_CPYTHON + // since __Pyx_MatchCase_TupleToList only does anything special in CPython, skip the check otherwise + if (PyTuple_CheckExact(x)) { + return __Pyx_MatchCase_TupleSliceToList(x, start, end); + } +#else + (void)__Pyx_MatchCase_TupleSliceToList; +#endif + return __Pyx_MatchCase_OtherSequenceSliceToList(x, start, end); +} diff --git a/Cython/Utility/MatchCase_Cy.pyx b/Cython/Utility/MatchCase_Cy.pyx new file mode 100644 index 000000000..dbb478ffe --- /dev/null +++ b/Cython/Utility/MatchCase_Cy.pyx @@ -0,0 +1,12 @@ +################### MemoryviewSliceToList ####################### + +cimport cython + +@cname("__Pyx_MatchCase_SliceMemoryview_{{suffix}}") +cdef list slice_to_list({{decl_code}} x, Py_ssize_t start, Py_ssize_t stop): + if stop < 0: + # use -1 as a flag for "end" + stop = x.shape[0] + # This code performs slightly better than [ xi for xi in x ] + with cython.boundscheck(False), cython.wraparound(False): + return [ x[i] for i in range(start, stop) ] diff --git a/tests/run/extra_patma.pyx b/tests/run/extra_patma.pyx index b2303f45b..6ff6a48a2 100644 --- a/tests/run/extra_patma.pyx +++ b/tests/run/extra_patma.pyx @@ -1,5 +1,15 @@ # mode: run +# Extra pattern matching test for Cython-specific features, optimizations, etc. + +cimport cython + +import array +import sys + +__doc__ = "" + + cdef bint is_null(int* x): return False # disabled - currently just a parser test match x: @@ -8,6 +18,7 @@ cdef bint is_null(int* x): case _: return False + def test_is_null(): """ >>> test_is_null() @@ -16,3 +27,78 @@ def test_is_null(): return # disabled - currently just a parser test assert is_null(&some_int) == False assert is_null(NULL) == True + + +if sys.version_info[0] > 2: + __doc__ += """ + array.array doesn't have the buffer protocol in Py2 and + this doesn't really feel worth working around to test + >>> print(test_memoryview(array.array('i', [0, 1, 2]))) + a 1 + >>> print(test_memoryview(array.array('i', []))) + b + >>> print(test_memoryview(array.array('i', [5]))) + c [5] + """ + +# goes via .shape instead +@cython.test_fail_if_path_exists("//CallNode//NameNode[@name = 'len']") +# No need for "is Sequence check" +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +def test_memoryview(int[:] x): + """ + >>> print(test_memoryview(None)) + no! + """ + match x: + case [0, y, 2]: + assert cython.typeof(y) == "int", cython.typeof(y) # type inference works + return f"a {y}" + case []: + return "b" + case [*z]: + return f"c {z}" + return "no!" + +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +def test_list_to_sequence(list x): + """ + >>> test_list_to_sequence([1,2,3]) + True + >>> test_list_to_sequence(None) + False + """ + match x: + case [*_]: + return True + case _: + return False + + +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +@cython.test_fail_if_path_exists("//CmpNode") # There's nothing to compare - it always succeeds! +def test_list_not_None_to_sequence(list x not None): + """ + >>> test_list_not_None_to_sequence([1,2,3]) + True + """ + match x: + case [*_]: + return True + case _: + return False + +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +@cython.test_fail_if_path_exists("//CmpNode") # There's nothing to compare - it always succeeds! +def test_ctuple_to_sequence((int, int) x): + """ + >>> test_ctuple_to_sequence((1, 2)) + (1, 2) + """ + match x: + case [a, b, c]: # can't possibly succeed! + return a, b, c + case [a, b]: + assert cython.typeof(a) == "int", cython.typeof(a) # test that types have inferred + return a, b + diff --git a/tests/run/extra_patma_py.py b/tests/run/extra_patma_py.py new file mode 100644 index 000000000..ea4abb5b3 --- /dev/null +++ b/tests/run/extra_patma_py.py @@ -0,0 +1,20 @@ +# mode: run +# tag: pure3.10 + +import array + +def test_array_is_sequence(x): + """ + Because this has to be specifically special-cased on early Python versions + >>> test_array_is_sequence(array.array('i', [0, 1, 2])) + 1 + >>> test_array_is_sequence(array.array('i', [0, 1, 2, 3, 4])) + [0, 1, 2, 3, 4] + """ + match x: + case [0, y, 2]: + return y + case [*z]: + return z + case _: + return "Not a sequence" diff --git a/tests/run/test_patma.py b/tests/run/test_patma.py index 511fc48cd..aae22c18c 100644 --- a/tests/run/test_patma.py +++ b/tests/run/test_patma.py @@ -279,14 +279,12 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 2) def test_patma_010(self): - return # disabled match (): case []: x = 0 self.assertEqual(x, 0) def test_patma_011(self): - return # disabled match (0, 1, 2): case [*x]: y = 0 @@ -294,7 +292,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_012(self): - return # disabled match (0, 1, 2): case [0, *x]: y = 0 @@ -302,7 +299,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_013(self): - return # disabled match (0, 1, 2): case [0, 1, *x,]: y = 0 @@ -310,7 +306,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_014(self): - return # disabled match (0, 1, 2): case [0, 1, 2, *x]: y = 0 @@ -318,7 +313,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_015(self): - return # disabled match (0, 1, 2): case [*x, 2,]: y = 0 @@ -326,7 +320,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_016(self): - return # disabled match (0, 1, 2): case [*x, 1, 2]: y = 0 @@ -334,7 +327,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_017(self): - return # disabled match (0, 1, 2): case [*x, 0, 1, 2,]: y = 0 @@ -342,7 +334,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_018(self): - return # disabled match (0, 1, 2): case [0, *x, 2]: y = 0 @@ -350,7 +341,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_019(self): - return # disabled match (0, 1, 2): case [0, 1, *x, 2,]: y = 0 @@ -358,7 +348,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_020(self): - return # disabled match (0, 1, 2): case [0, *x, 1, 2]: y = 0 @@ -366,7 +355,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_021(self): - return # disabled match (0, 1, 2): case [*x,]: y = 0 @@ -608,7 +596,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_044(self): - return # disabled x = () match x: case []: @@ -617,7 +604,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_045(self): - return # disabled x = () match x: case (): @@ -626,7 +612,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_046(self): - return # disabled x = (0,) match x: case [0]: @@ -635,7 +620,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_047(self): - return # disabled x = ((),) match x: case [[]]: @@ -686,7 +670,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_052(self): - return # disabled x = [1, 0] match x: case [0]: @@ -699,7 +682,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 2) def test_patma_053(self): - return # disabled x = {0} y = None match x: @@ -709,7 +691,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_054(self): - return # disabled x = set() y = None match x: @@ -719,7 +700,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_055(self): - return # disabled x = iter([1, 2, 3]) y = None match x: @@ -729,7 +709,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_056(self): - return # disabled x = {} y = None match x: @@ -739,7 +718,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_057(self): - return # disabled x = {0: False, 1: True} y = None match x: @@ -919,7 +897,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_075(self): - return # disabled x = "x" match x: case ["x"]: @@ -930,7 +907,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 1) def test_patma_076(self): - return # disabled x = b"x" match x: case [b"x"]: @@ -945,7 +921,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 4) def test_patma_077(self): - return # disabled x = bytearray(b"x") y = None match x: @@ -957,7 +932,6 @@ class TestPatma(unittest.TestCase): self.assertIs(y, None) def test_patma_078(self): - return # disabled x = "" match x: case []: @@ -970,7 +944,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 2) def test_patma_079(self): - return # disabled x = "xxx" match x: case ["x", "x", "x"]: @@ -983,7 +956,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 2) def test_patma_080(self): - return # disabled x = b"xxx" match x: case [120, 120, 120]: @@ -1345,7 +1317,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(x, 0) def test_patma_118(self): - return # disabled x = [] match x: case [*_, _]: @@ -1379,14 +1350,12 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, {}) def test_patma_121(self): - return # disabled match (): case (): x = 0 self.assertEqual(x, 0) def test_patma_122(self): - return # disabled match (0, 1, 2): case (*x,): y = 0 @@ -1394,7 +1363,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_123(self): - return # disabled match (0, 1, 2): case 0, *x: y = 0 @@ -1402,7 +1370,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_124(self): - return # disabled match (0, 1, 2): case (0, 1, *x,): y = 0 @@ -1410,7 +1377,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_125(self): - return # disabled match (0, 1, 2): case 0, 1, 2, *x: y = 0 @@ -1418,7 +1384,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_126(self): - return # disabled match (0, 1, 2): case *x, 2,: y = 0 @@ -1426,7 +1391,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_127(self): - return # disabled match (0, 1, 2): case (*x, 1, 2): y = 0 @@ -1434,7 +1398,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_128(self): - return # disabled match (0, 1, 2): case *x, 0, 1, 2,: y = 0 @@ -1442,7 +1405,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_129(self): - return # disabled match (0, 1, 2): case (0, *x, 2): y = 0 @@ -1450,7 +1412,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_130(self): - return # disabled match (0, 1, 2): case 0, 1, *x, 2,: y = 0 @@ -1458,7 +1419,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_131(self): - return # disabled match (0, 1, 2): case (0, *x, 1, 2): y = 0 @@ -1466,7 +1426,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_132(self): - return # disabled match (0, 1, 2): case *x,: y = 0 @@ -1676,7 +1635,6 @@ class TestPatma(unittest.TestCase): self.assertIs(z, x) def test_patma_151(self): - return # disabled x = 0 match x,: case y,: @@ -1686,7 +1644,6 @@ class TestPatma(unittest.TestCase): self.assertIs(z, 0) def test_patma_152(self): - return # disabled w = 0 x = 0 match w, x: @@ -1699,7 +1656,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(v, 0) def test_patma_153(self): - return # disabled x = 0 match w := x,: case y as v,: @@ -2078,7 +2034,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(whereis(Point("X", "x")), "Not on the diagonal") def test_patma_184(self): - return # disabled class Seq(collections.abc.Sequence): __getitem__ = None def __len__(self): @@ -2089,7 +2044,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_185(self): - return # disabled class Seq(collections.abc.Sequence): __getitem__ = None def __len__(self): @@ -2100,7 +2054,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 0) def test_patma_186(self): - return # disabled class Seq(collections.abc.Sequence): def __getitem__(self, i): return i @@ -2114,7 +2067,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_187(self): - return # disabled w = range(10) match w: case [x, y, *rest]: @@ -2126,7 +2078,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(rest, list(range(2, 10))) def test_patma_188(self): - return # disabled w = range(100) match w: case (x, y, *rest): @@ -2138,7 +2089,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(rest, list(range(2, 100))) def test_patma_189(self): - return # disabled w = range(1000) match w: case x, y, *rest: @@ -2150,7 +2100,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(rest, list(range(2, 1000))) def test_patma_190(self): - return # disabled w = range(1 << 10) match w: case [x, y, *_]: @@ -2161,7 +2110,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_191(self): - return # disabled w = range(1 << 20) match w: case (x, y, *_): @@ -2172,7 +2120,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_192(self): - return # disabled w = range(1 << 30) match w: case x, y, *_: @@ -2431,7 +2378,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(f((1, 2)), {}) def test_patma_210(self): - return # disabled def f(w): match w: case (x, y, z): @@ -2472,7 +2418,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(f(Point(42, "hello")), {"xx": 42}) def test_patma_213(self): - return # disabled def f(w): match w: case (p, q) as x: @@ -2513,7 +2458,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(set(f()), {"abc"}) def test_patma_218(self): - return # disabled def f(): match ..., ...: case a, b: @@ -2616,14 +2560,12 @@ class TestPatma(unittest.TestCase): self.assertIs(f(3), None) def test_patma_228(self): - return # disabled match(): case(): x = 0 self.assertEqual(x, 0) def test_patma_229(self): - return # disabled x = 0 match(x): case(x): @@ -2708,7 +2650,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_238(self): - return # disabled x = ((0, 1), (2, 3)) match x: case ((a as b, c as d) as e) as w, ((f as g, h) as i) as z: @@ -2757,7 +2698,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_242(self): - return # disabled x = range(3) match x: case [y, *_, z]: @@ -2768,7 +2708,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 2) def test_patma_243(self): - return # disabled x = range(3) match x: case [_, *_, y]: @@ -2778,7 +2717,6 @@ class TestPatma(unittest.TestCase): self.assertEqual(z, 0) def test_patma_244(self): - return # disabled x = range(3) match x: case [*_, y]: |