diff options
author | da-woods <dw-git@d-woods.co.uk> | 2022-12-08 21:37:49 +0000 |
---|---|---|
committer | da-woods <dw-git@d-woods.co.uk> | 2022-12-08 21:37:49 +0000 |
commit | f0fab87c193ba3acd17010cc07183583b50987b6 (patch) | |
tree | b8fc83f31fa6385e040cd3635130225983eb266f | |
parent | 4612175c2fdaa62a184be97a0b4c5501718a6ee3 (diff) | |
parent | 79969ec1d213a6d24ce9e76fdee6d9be9dc8422b (diff) | |
download | cython-f0fab87c193ba3acd17010cc07183583b50987b6.tar.gz |
Merge branch 'match-or' into patma-preview
-rw-r--r-- | Cython/Compiler/MatchCaseNodes.py | 373 | ||||
-rw-r--r-- | Cython/Compiler/Nodes.py | 4 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.pxd | 1 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.py | 13 | ||||
-rw-r--r-- | Cython/Compiler/Parsing.pxd | 2 | ||||
-rw-r--r-- | Cython/Compiler/Parsing.py | 271 | ||||
-rw-r--r-- | Cython/TestUtils.py | 24 | ||||
-rw-r--r-- | Cython/Utility/MatchCase.c | 475 | ||||
-rw-r--r-- | Tools/ci-run.sh | 2 | ||||
-rw-r--r-- | test-requirements-pypy27.txt | 1 | ||||
-rw-r--r-- | tests/run/extra_patma.pyx | 45 | ||||
-rw-r--r-- | tests/run/extra_patma_py.py | 35 | ||||
-rw-r--r-- | tests/run/test_patma.py | 105 |
13 files changed, 811 insertions, 540 deletions
diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py index 833520642..a7e246bcb 100644 --- a/Cython/Compiler/MatchCaseNodes.py +++ b/Cython/Compiler/MatchCaseNodes.py @@ -1,9 +1,8 @@ # Nodes for structural pattern matching. # -# In a separate file because they're unlikely to be useful -# for much else +# In a separate file because they're unlikely to be useful for much else. -from .Nodes import Node, StatNode +from .Nodes import Node, StatNode, ErrorNode from .Errors import error, local_errors, report_error from . import Nodes, ExprNodes, PyrexTypes, Builtin from .Code import UtilityCode, TempitaUtilityCode @@ -28,7 +27,11 @@ class MatchNode(StatNode): def validate_irrefutable(self): found_irrefutable_case = None - for c in self.cases: + for case in self.cases: + if isinstance(case, ErrorNode): + # This validation happens before error nodes have been + # transformed into actual errors, so we need to ignore them + continue if found_irrefutable_case: error( found_irrefutable_case.pos, @@ -38,9 +41,9 @@ class MatchNode(StatNode): ), ) break - if c.is_irrefutable(): - found_irrefutable_case = c - c.validate_irrefutable() + if case.is_irrefutable(): + found_irrefutable_case = case + case.validate_irrefutable() def refactor_cases(self): # An early transform - changes cases that can be represented as @@ -158,6 +161,8 @@ class MatchCaseNode(Node): child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"] def is_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return True # value doesn't really matter return self.pattern.is_irrefutable() and not self.guard def is_simple_value_comparison(self): @@ -166,9 +171,13 @@ class MatchCaseNode(Node): return self.pattern.is_simple_value_comparison() def validate_targets(self): + if isinstance(self.pattern, ErrorNode): + return self.pattern.get_targets() def validate_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return self.pattern.validate_irrefutable() 
def is_sequence_or_mapping(self): @@ -263,7 +272,7 @@ class SubstitutedMatchCaseNode(MatchCaseBaseNode): class PatternNode(Node): """ - DW decided that PatternNode shouldn't be an expression because + PatternNode is not an expression because it does several things (evalutating a boolean expression, assignment of targets), and they need to be done at different times. @@ -305,9 +314,9 @@ class PatternNode(Node): child_attrs = ["as_targets"] def __init__(self, pos, **kwds): + if "as_targets" not in kwds: + kwds["as_targets"] = [] super(PatternNode, self).__init__(pos, **kwds) - if not hasattr(self, "as_targets"): - self.as_targets = [] def is_irrefutable(self): return False @@ -322,14 +331,13 @@ class PatternNode(Node): def get_targets(self): targets = self.get_main_pattern_targets() - for t in self.as_targets: - self.add_target_to_targets(targets, t.name) + for target in self.as_targets: + self.add_target_to_targets(targets, target.name) return targets def update_targets_with_targets(self, targets, other_targets): - intersection = targets.intersection(other_targets) - for i in intersection: - error(self.pos, "multiple assignments to name '%s' in pattern" % i) + for name in targets.intersection(other_targets): + error(self.pos, "multiple assignments to name '%s' in pattern" % name) targets.update(other_targets) def add_target_to_targets(self, targets, target): @@ -362,7 +370,7 @@ class PatternNode(Node): def validate_irrefutable(self): for attr in self.child_attrs: child = getattr(self, attr) - if isinstance(child, PatternNode): + if child is not None and isinstance(child, PatternNode): child.validate_irrefutable() def analyse_pattern_expressions(self, env, sequence_mapping_temp): @@ -516,9 +524,9 @@ class OrPatternNode(PatternNode): child_attrs = PatternNode.child_attrs + ["alternatives"] def get_first_irrefutable(self): - for a in self.alternatives: - if a.is_irrefutable(): - return a + for alternative in self.alternatives: + if alternative.is_irrefutable(): + 
return alternative return None def is_irrefutable(self): @@ -537,17 +545,17 @@ class OrPatternNode(PatternNode): def get_main_pattern_targets(self): child_targets = None - for ch in self.alternatives: - ch_targets = ch.get_targets() - if child_targets is not None and child_targets != ch_targets: + for alternative in self.alternatives: + alternative_targets = alternative.get_targets() + if child_targets is not None and child_targets != alternative_targets: error(self.pos, "alternative patterns bind different names") - child_targets = ch_targets + child_targets = alternative_targets return child_targets def validate_irrefutable(self): super(OrPatternNode, self).validate_irrefutable() found_irrefutable_case = None - for a in self.alternatives: + for alternative in self.alternatives: if found_irrefutable_case: error( found_irrefutable_case.pos, @@ -557,9 +565,9 @@ class OrPatternNode(PatternNode): ), ) break - if a.is_irrefutable(): - found_irrefutable_case = a - a.validate_irrefutable() + if alternative.is_irrefutable(): + found_irrefutable_case = alternative + alternative.validate_irrefutable() def is_simple_value_comparison(self): return all( @@ -739,10 +747,10 @@ class MatchSequencePatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() star_count = 0 - for p in self.patterns: - if p.is_match_and_assign_pattern and p.is_star: + for pattern in self.patterns: + if pattern.is_match_and_assign_pattern and pattern.is_star: star_count += 1 - self.update_targets_with_targets(targets, p.get_targets()) + self.update_targets_with_targets(targets, pattern.get_targets()) if star_count > 1: error(self.pos, "multiple starred names in sequence pattern") return targets @@ -1060,30 +1068,32 @@ class MatchMappingPatternNode(PatternNode): ], exception_value="-1", ) + # lie about the types of keys for simplicity Pyx_mapping_check_duplicates_type = PyrexTypes.CFuncType( PyrexTypes.c_int_type, [ - PyrexTypes.CFuncTypeArg("fixed_keys", PyrexTypes.py_object_type, 
None), - PyrexTypes.CFuncTypeArg("var_keys", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None), ], exception_value="-1", ) + # lie about the types of keys and subjects for simplicity Pyx_mapping_extract_subjects_type = PyrexTypes.CFuncType( PyrexTypes.c_bint_type, [ - PyrexTypes.CFuncTypeArg("map", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("fixed_keys", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("var_keys", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("mapping", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("subjects", PyrexTypes.c_void_ptr_ptr_type, None), ], exception_value="-1", - has_varargs=True, ) Pyx_mapping_doublestar_type = PyrexTypes.CFuncType( Builtin.dict_type, [ - PyrexTypes.CFuncTypeArg("map", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("fixed_keys", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("var_keys", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("mapping", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None), ], ) @@ -1092,8 +1102,8 @@ class MatchMappingPatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() - for p in self.value_patterns: - self.update_targets_with_targets(targets, p.get_targets()) + for pattern in self.value_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) if self.double_star_capture_target: self.add_target_to_targets(targets, self.double_star_capture_target.name) return targets @@ -1203,8 +1213,10 @@ class MatchMappingPatternNode(PatternNode): self.pos, arg=subject_node, fallback=call, 
check=self.is_dict_type_check ) - def make_duplicate_keys_check(self, static_keys_tuple, var_keys_tuple): + def make_duplicate_keys_check(self, n_fixed_keys): utility_code = UtilityCode.load_cached("MappingKeyCheck", "MatchCase.c") + if n_fixed_keys == len(self.keys): + return None # nothing to check return Nodes.ExprStatNode( self.pos, @@ -1213,11 +1225,15 @@ class MatchMappingPatternNode(PatternNode): "__Pyx_MatchCase_CheckMappingDuplicateKeys", self.Pyx_mapping_check_duplicates_type, utility_code=utility_code, - args=[static_keys_tuple.clone_node(), var_keys_tuple], + args=[ + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode(self.pos, value=str(n_fixed_keys)), + ExprNodes.IntNode(self.pos, value=str(len(self.keys))) + ], ), ) - def check_all_keys(self, subject_node, const_keys_tuple, var_keys_tuple): + def check_all_keys(self, subject_node): # It's debatable here whether to go for individual unpacking or a function. # Current implementation is a function that's loosely copied from CPython. # For small numbers of keys it might be better to generate the code instead. 
@@ -1243,24 +1259,23 @@ class MatchMappingPatternNode(PatternNode): util_code = UtilityCode.load_cached("ExtractGeneric", "MatchCase.c") func_name = "__Pyx_MatchCase_Mapping_Extract" - subject_derefs = [ - ExprNodes.NullNode(self.pos) - if t is None - else AddressOfPyObjectNode(self.pos, obj=t) - for t in self.subject_temps - ] return ExprNodes.PythonCapiCallNode( self.pos, func_name, self.Pyx_mapping_extract_subjects_type, utility_code=util_code, - args=[subject_node, const_keys_tuple.clone_node(), var_keys_tuple] - + subject_derefs, + args=[ + subject_node, + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode( + self.pos, + value=str(len(self.keys)) + ), + MappingOrClassComparisonNode.make_subjects_node(self.pos), + ], ) - def make_double_star_capture( - self, subject_node, const_tuple, var_tuple, test_result - ): + def make_double_star_capture(self, subject_node, test_result): # test_result being the variable that holds "case check passed until now" is_dict = self.is_dict_type_check(subject_node.type) if is_dict: @@ -1277,7 +1292,11 @@ class MatchMappingPatternNode(PatternNode): "__Pyx_MatchCase_DoubleStarCapture" + tag, self.Pyx_mapping_doublestar_type, utility_code=utility_code, - args=[subject_node, const_tuple, var_tuple], + args=[ + subject_node, + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode(self.pos, value=str(len(self.keys))) + ], ) assignment = Nodes.SingleAssignmentNode( self.double_star_capture_target.pos, lhs=self.double_star_temp, rhs=func @@ -1294,22 +1313,17 @@ class MatchMappingPatternNode(PatternNode): def get_comparison_node(self, subject_node, sequence_mapping_temp=None): from . 
import UtilNodes - const_keys = [] var_keys = [] + n_literal_keys = 0 for k in self.keys: - if not k.arg.is_literal: - k = UtilNodes.ResultRefNode(k, is_temp=False) + if not k.is_literal: var_keys.append(k) else: - const_keys.append(k.arg.clone_node()) - const_keys_tuple = ExprNodes.TupleNode(self.pos, args=const_keys) - var_keys_tuple = ExprNodes.TupleNode(self.pos, args=var_keys) - if var_keys: - var_keys_tuple = UtilNodes.ResultRefNode(var_keys_tuple, is_temp=True) + n_literal_keys += 1 all_tests = [] all_tests.append(self.make_mapping_check(subject_node, sequence_mapping_temp)) - all_tests.append(self.check_all_keys(subject_node, const_keys_tuple, var_keys_tuple)) + all_tests.append(self.check_all_keys(subject_node)) if any(isinstance(test, ExprNodes.BoolNode) and not test.value for test in all_tests): # identify automatic-failure @@ -1324,10 +1338,10 @@ class MatchMappingPatternNode(PatternNode): all_tests = generate_binop_tree_from_list(self.pos, "and", all_tests) test_result = UtilNodes.ResultRefNode(pos=self.pos, type=PyrexTypes.c_bint_type) + duplicate_check = self.make_duplicate_keys_check(n_literal_keys) body = Nodes.StatListNode( self.pos, - stats=[ - self.make_duplicate_keys_check(const_keys_tuple, var_keys_tuple), + stats=([duplicate_check] if duplicate_check else []) + [ Nodes.SingleAssignmentNode(self.pos, lhs=test_result, rhs=all_tests), ], ) @@ -1335,21 +1349,21 @@ class MatchMappingPatternNode(PatternNode): assert self.double_star_temp body.stats.append( # make_double_star_capture wraps itself in an if - self.make_double_star_capture( - subject_node, const_keys_tuple, var_keys_tuple, test_result - ) + self.make_double_star_capture(subject_node, test_result) ) - if var_keys or self.double_star_capture_target: + if duplicate_check or self.double_star_capture_target: body = UtilNodes.TempResultFromStatNode(test_result, body) - if var_keys: - body = UtilNodes.EvalWithTempExprNode(var_keys_tuple, body) - for k in var_keys: - if isinstance(k, 
UtilNodes.ResultRefNode): - body = UtilNodes.EvalWithTempExprNode(k, body) - return LazyCoerceToBool(body.pos, arg=body) else: - return LazyCoerceToBool(all_tests.pos, arg=all_tests) + body = all_tests + if self.keys or self.double_star_capture_target: + body = MappingOrClassComparisonNode( + body.pos, + arg=LazyCoerceToBool(body.pos, arg=body), + keys_array=self.keys, + subjects_array=self.subject_temps + ) + return LazyCoerceToBool(body.pos, arg=body) def analyse_pattern_expressions(self, env, sequence_mapping_temp): def to_temp_or_literal(node): @@ -1359,7 +1373,7 @@ class MatchMappingPatternNode(PatternNode): return node.coerce_to_temp(env) self.keys = [ - ExprNodes.ProxyNode(to_temp_or_literal(k.analyse_expressions(env))) + to_temp_or_literal(k.analyse_expressions(env)) for k in self.keys ] @@ -1410,16 +1424,19 @@ class ClassPatternNode(PatternNode): keyword_pattern_names = [] keyword_pattern_patterns = [] + # as with the mapping functions, lie a little about some of the types for + # ease of declaration Pyx_positional_type = PyrexTypes.CFuncType( PyrexTypes.c_bint_type, [ PyrexTypes.CFuncTypeArg("subject", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None), - PyrexTypes.CFuncTypeArg("keysnames_tuple", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("fixed_names", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("n_fixed", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("match_self", PyrexTypes.c_int_type, None), - PyrexTypes.CFuncTypeArg("num_args", PyrexTypes.c_int_type, None), + PyrexTypes.CFuncTypeArg("subjects", PyrexTypes.c_void_ptr_ptr_type, None), + PyrexTypes.CFuncTypeArg("n_subjects", PyrexTypes.c_int_type, None), ], - has_varargs=True, exception_value="-1", ) @@ -1476,11 +1493,8 @@ class ClassPatternNode(PatternNode): def get_main_pattern_targets(self): targets = set() - for p in self.keyword_pattern_patterns: - self.update_targets_with_targets(targets, p.get_targets()) - - 
for p in self.positional_patterns: - self.update_targets_with_targets(targets, p.get_targets()) + for pattern in self.positional_patterns + self.keyword_pattern_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) return targets def generate_main_pattern_assignment_list(self, subject_node, env): @@ -1589,13 +1603,10 @@ class ClassPatternNode(PatternNode): def make_positional_args_call(self, subject_node, class_node): assert self.positional_patterns util_code = UtilityCode.load_cached("ClassPositionalPatterns", "MatchCase.c") - keynames = ExprNodes.TupleNode( - self.pos, - args=[ - ExprNodes.StringNode(n.pos, value=n.name) - for n in self.keyword_pattern_names - ], - ) + keynames = [ + ExprNodes.StringNode(n.pos, value=n.name) + for n in self.keyword_pattern_names + ] # -1 is "unknown" match_self = ( -1 @@ -1628,21 +1639,28 @@ class ClassPatternNode(PatternNode): match_self = 0 # I think... Relies on knowing the bases match_self = ExprNodes.IntNode(self.pos, value=str(match_self)) - len_ = ExprNodes.IntNode(self.pos, value=str(len(self.positional_patterns))) - subject_derefs = [ - ExprNodes.NullNode(self.pos) - if t is None - else AddressOfPyObjectNode(self.pos, obj=t) - for t in self.positional_subject_temps - ] - return ExprNodes.PythonCapiCallNode( + n_subjects = ExprNodes.IntNode(self.pos, value=str(len(self.positional_patterns))) + return MappingOrClassComparisonNode( self.pos, - "__Pyx_MatchCase_ClassPositional", - self.Pyx_positional_type, - utility_code=util_code, - args=[subject_node, class_node, keynames, match_self, len_] - + subject_derefs, + arg=ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_ClassPositional", + self.Pyx_positional_type, + utility_code=util_code, + args=[ + subject_node, + class_node, + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode(self.pos, value=str(len(keynames))), + match_self, + MappingOrClassComparisonNode.make_subjects_node(self.pos), + n_subjects, + ] + ), + 
subjects_array=self.positional_subject_temps, + keys_array=keynames, ) + return def make_subpattern_checks(self): patterns = self.keyword_pattern_patterns + self.positional_patterns @@ -2057,27 +2075,6 @@ class CompilerDirectivesExprNode(ExprNodes.ProxyNode): self.arg.annotate(code) -class AddressOfPyObjectNode(ExprNodes.ExprNode): - """ - obj - some temp node - """ - - type = PyrexTypes.c_void_ptr_ptr_type - is_temp = False - subexprs = ["obj"] - - def analyse_types(self, env): - self.obj = self.obj.analyse_types(env) - assert self.obj.type.is_pyobject, repr(self.obj.type) - return self - - def generate_result_code(self, code): - self.obj.generate_result_code(code) - - def calculate_result_code(self): - return "&%s" % self.obj.result() - - class LazyCoerceToPyObject(ExprNodes.ExprNode): """ Just calls "self.arg.coerce_to_pyobject" when it's analysed, @@ -2127,4 +2124,122 @@ def generate_binop_tree_from_list(pos, operator, list_of_tests): operator=operator, operand1=operand1, operand2=operand2 - )
\ No newline at end of file + ) + + +class MappingOrClassComparisonNode(ExprNodes.ExprNode): + """ + Combined with MappingOrClassComparisonNodeInner this is responsible + for setting up up the arrays of subjects and keys that are used in + the function calls that handle these types of patterns + + Note that self.keys_array is owned by this but used by + MappingOrClassComparisonNodeInner - that's mainly to ensure that + it gets evaluated in the correct order + """ + subexprs = ["keys_array", "inner"] + + keys_array_cname = "__pyx_match_mapping_keys" + subjects_array_cname = "__pyx_match_mapping_subjects" + + @property + def type(self): + return self.inner.type + + @classmethod + def make_keys_node(cls, pos): + return ExprNodes.RawCNameExprNode( + pos, + type=PyrexTypes.c_void_ptr_type, + cname=cls.keys_array_cname + ) + + @classmethod + def make_subjects_node(cls, pos): + return ExprNodes.RawCNameExprNode( + pos, + type=PyrexTypes.c_void_ptr_ptr_type, + cname=cls.subjects_array_cname + ) + + def __init__(self, pos, arg, subjects_array, **kwds): + super(MappingOrClassComparisonNode, self).__init__(pos, **kwds) + self.inner = MappingOrClassComparisonNodeInner( + pos, + arg=arg, + keys_array = self.keys_array, + subjects_array = subjects_array + ) + + def analyse_types(self, env): + self.inner = self.inner.analyse_types(env) + self.keys_array = [ + key.analyse_types(env).coerce_to_simple(env) for key in self.keys_array + ] + return self + + def generate_result_code(self, code): + pass + + def calculate_result_code(self): + return self.inner.calculate_result_code() + + +class MappingOrClassComparisonNodeInner(ExprNodes.ExprNode): + """ + Sets up the arrays of subjects and keys + + Created by the constructor of MappingComparisonNode + (no need to create directly) + + has attributes: + * arg - the main comparison node + * keys_array - list of ExprNodes representing keys + * subjects_array - list of ExprNodes representing subjects + """ + subexprs = ['arg'] + + @property + 
def type(self): + return self.arg.type + + def analyse_types(self, env): + self.arg = self.arg.analyse_types(env) + for n in range(len(self.keys_array)): + key = self.keys_array[n].analyse_types(env) + key = key.coerce_to_pyobject(env) + self.keys_array[n] = key + assert self.arg.type is PyrexTypes.c_bint_type + return self + + def generate_evaluation_code(self, code): + code.putln("{") + keys_str = ", ".join(k.result() for k in self.keys_array) + if not keys_str: + # GCC gets worried about overflow if we pass + # a genuinely empty array + keys_str = "NULL" + code.putln("PyObject *%s[] = {%s};" % ( + MappingOrClassComparisonNode.keys_array_cname, + keys_str, + )) + subjects_str = ", ".join( + "&"+subject.result() if subject is not None else "NULL" for subject in self.subjects_array + ) + if not subjects_str: + # GCC gets worried about overflow if we pass + # a genuinely empty array + subjects_str = "NULL" + code.putln("PyObject **%s[] = {%s};" % ( + MappingOrClassComparisonNode.subjects_array_cname, + subjects_str + )) + super(MappingOrClassComparisonNodeInner, self).generate_evaluation_code(code) + + code.putln("}") + + def generate_result_code(self, code): + pass + + def calculate_result_code(self): + return self.arg.result()
\ No newline at end of file diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 49c99b2c3..006d1023c 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -10156,13 +10156,13 @@ class CnameDecoratorNode(StatNode): class ErrorNode(Node): """ - Node type for things that we want to get throught the parser + Node type for things that we want to get through the parser (especially for things that are being scanned in "tentative_scan" blocks), but should immediately raise and error afterwards. what str """ - pass + child_attrs = [] #------------------------------------------------------------------------------------ diff --git a/Cython/Compiler/ParseTreeTransforms.pxd b/Cython/Compiler/ParseTreeTransforms.pxd index efbb14f70..2778be4ef 100644 --- a/Cython/Compiler/ParseTreeTransforms.pxd +++ b/Cython/Compiler/ParseTreeTransforms.pxd @@ -18,6 +18,7 @@ cdef class PostParse(ScopeTrackingTransform): cdef dict specialattribute_handlers cdef size_t lambda_counter cdef size_t genexpr_counter + cdef bint in_pattern_node cdef _visit_assignment_node(self, node, list expr_list) diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 414bae04d..301d93335 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -193,6 +193,7 @@ class PostParse(ScopeTrackingTransform): self.specialattribute_handlers = { '__cythonbufferdefaults__' : self.handle_bufferdefaults } + self.in_pattern_node = False def visit_LambdaNode(self, node): # unpack a lambda expression into the corresponding DefNode @@ -399,6 +400,18 @@ class PostParse(ScopeTrackingTransform): self.visitchildren(node) return node + def visit_PatternNode(self, node): + in_pattern_node, self.in_pattern_node = self.in_pattern_node, True + self.visitchildren(node) + self.in_pattern_node = in_pattern_node + return node + + def visit_JoinedStrNode(self, node): + if self.in_pattern_node: + error(node.pos, "f-strings are 
not accepted for pattern matching") + self.visitchildren(node) + return node + class _AssignmentExpressionTargetNameFinder(TreeVisitor): def __init__(self): super(_AssignmentExpressionTargetNameFinder, self).__init__() diff --git a/Cython/Compiler/Parsing.pxd b/Cython/Compiler/Parsing.pxd index 72a855fd4..fc3e2749f 100644 --- a/Cython/Compiler/Parsing.pxd +++ b/Cython/Compiler/Parsing.pxd @@ -62,6 +62,8 @@ cdef expect_ellipsis(PyrexScanner s) cdef make_slice_nodes(pos, subscripts) cpdef make_slice_node(pos, start, stop = *, step = *) cdef p_atom(PyrexScanner s) +cdef p_atom_string(PyrexScanner s) +cdef p_atom_ident_constants(PyrexScanner s, bint bools_are_pybool = *) @cython.locals(value=unicode) cdef p_int_literal(PyrexScanner s) cdef p_name(PyrexScanner s, name) diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 02c746351..75eb194c6 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -4,7 +4,6 @@ # from __future__ import absolute_import -from ast import Expression # This should be done automatically import cython @@ -20,7 +19,7 @@ cython.declare(Nodes=object, ExprNodes=object, EncodedString=object, from io import StringIO import re import sys -from unicodedata import lookup as lookup_unicodechar, category as unicode_category, name +from unicodedata import lookup as lookup_unicodechar, category as unicode_category from functools import partial, reduce from .Scanning import PyrexScanner, FileSourceDescriptor, tentatively_scan @@ -719,36 +718,59 @@ def p_atom(s): s.next() return ExprNodes.ImagNode(pos, value = value) elif sy == 'BEGIN_STRING': - kind, bytes_value, unicode_value = p_cat_string_literal(s) - if kind == 'c': - return ExprNodes.CharNode(pos, value = bytes_value) - elif kind == 'u': - return ExprNodes.UnicodeNode(pos, value = unicode_value, bytes_value = bytes_value) - elif kind == 'b': - return ExprNodes.BytesNode(pos, value = bytes_value) - elif kind == 'f': - return ExprNodes.JoinedStrNode(pos, values = 
unicode_value) - elif kind == '': - return ExprNodes.StringNode(pos, value = bytes_value, unicode_value = unicode_value) - else: - s.error("invalid string kind '%s'" % kind) elif sy == 'IDENT': - name = s.systring - if name == "None": - result = ExprNodes.NoneNode(pos) - elif name == "True": - result = ExprNodes.BoolNode(pos, value=True) - elif name == "False": - result = ExprNodes.BoolNode(pos, value=False) - elif name == "NULL" and not s.in_python_file: - result = ExprNodes.NullNode(pos) - else: - result = p_name(s, name) - s.next() + result = p_atom_ident_constants(s) + if result is None: + result = p_name(s, s.systring) + s.next() return result else: s.error("Expected an identifier or literal") + +def p_atom_string(s): + pos = s.position() + kind, bytes_value, unicode_value = p_cat_string_literal(s) + if kind == 'c': + return ExprNodes.CharNode(pos, value=bytes_value) + elif kind == 'u': + return ExprNodes.UnicodeNode(pos, value=unicode_value, bytes_value=bytes_value) + elif kind == 'b': + return ExprNodes.BytesNode(pos, value=bytes_value) + elif kind == 'f': + return ExprNodes.JoinedStrNode(pos, values=unicode_value) + elif kind == '': + return ExprNodes.StringNode(pos, value=bytes_value, unicode_value=unicode_value) + else: + s.error("invalid string kind '%s'" % kind) + + +def p_atom_ident_constants(s, bools_are_pybool=False): + """ + Returns None if it isn't one of the special-cased named constants. + Only calls s.next() if it successfully matches. 
+ pos = s.position() + name = s.systring + result = None + if bools_are_pybool: + extra_kwds = {'type': Builtin.bool_type} + else: + extra_kwds = {} + if name == "None": + result = ExprNodes.NoneNode(pos) + elif name == "True": + result = ExprNodes.BoolNode(pos, value=True, **extra_kwds) + elif name == "False": + result = ExprNodes.BoolNode(pos, value=False, **extra_kwds) + elif name == "NULL" and not s.in_python_file: + result = ExprNodes.NullNode(pos) + if result: + s.next() + return result + + + def p_int_literal(s): pos = s.position() value = s.systring @@ -4031,12 +4053,13 @@ def p_cpp_class_attribute(s, ctx): node.decorators = decorators return node + def p_match_statement(s, ctx): assert s.sy == "IDENT" and s.systring == "match" pos = s.position() with tentatively_scan(s) as errors: s.next() - subject = p_test(s) + subject = p_namedexpr_test(s) subjects = None if s.sy == ",": subjects = [subject] @@ -4050,6 +4073,7 @@ s.expect(":") if errors: return None + # at this stage we're committed to it being a match block so continue # outside "with tentatively_scan" # (I think this deviates from the PEG parser slightly, and it'd @@ -4060,10 +4084,11 @@ while s.sy != "DEDENT": cases.append(p_case_block(s, ctx)) s.expect_dedent() - return MatchCaseNodes.MatchNode(pos, subject = subject, cases = cases) + return MatchCaseNodes.MatchNode(pos, subject=subject, cases=cases) + def p_case_block(s, ctx): - if not (s.sy=="IDENT" and s.systring == "case"): + if not (s.sy == "IDENT" and s.systring == "case"): s.error("Expected 'case'") s.next() pos = s.position() @@ -4076,8 +4101,10 @@ return MatchCaseNodes.MatchCaseNode(pos, pattern=pattern, body=body, guard=guard) + def p_patterns(s): - # note - in slight contrast to the name, returns a single pattern + # note - in slight contrast to the name (which comes from the Python grammar), + # returns a single pattern patterns = [] seq = False pos 
= s.position() @@ -4089,9 +4116,9 @@ def p_patterns(s): break # all is good provided we have at least 1 pattern else: e = errors[0] - s.error(e.args[1], pos = e.args[0]) + s.error(e.args[1], pos=e.args[0]) patterns.append(pattern) - + if s.sy == ",": seq = True s.next() @@ -4099,11 +4126,13 @@ def p_patterns(s): break # common reasons to break else: break + if seq: - return MatchCaseNodes.MatchSequencePatternNode(pos, patterns = patterns) + return MatchCaseNodes.MatchSequencePatternNode(pos, patterns=patterns) else: return patterns[0] + def p_maybe_star_pattern(s): # For match case. Either star_pattern or pattern if s.sy == "*": @@ -4115,12 +4144,13 @@ def p_maybe_star_pattern(s): else: s.next() pattern = MatchCaseNodes.MatchAndAssignPatternNode( - s.position(), target = target, is_star = True + s.position(), target=target, is_star=True ) return pattern else: - p = p_pattern(s) - return p + pattern = p_pattern(s) + return pattern + def p_pattern(s): # try "as_pattern" then "or_pattern" @@ -4133,13 +4163,15 @@ def p_pattern(s): s.next() else: break + if len(patterns) > 1: pattern = MatchCaseNodes.OrPatternNode( pos, - alternatives = patterns + alternatives=patterns ) else: pattern = patterns[0] + if s.sy == 'IDENT' and s.systring == 'as': s.next() with tentatively_scan(s) as errors: @@ -4147,17 +4179,18 @@ def p_pattern(s): if errors and s.sy == "_": s.next() # make this a specific error - return Nodes.ErrorNode(errors[0].args[0], what = errors[0].args[1]) + return Nodes.ErrorNode(errors[0].args[0], what=errors[0].args[1]) elif errors: with tentatively_scan(s): expr = p_test(s) - return Nodes.ErrorNode(expr.pos, what = "Invalid pattern target") + return Nodes.ErrorNode(expr.pos, what="Invalid pattern target") s.error(errors[0]) return pattern def p_closed_pattern(s): """ + The PEG parser specifies it as | literal_pattern | capture_pattern | wildcard_pattern @@ -4166,42 +4199,49 @@ def p_closed_pattern(s): | sequence_pattern | mapping_pattern | class_pattern + + For 
the sake of avoiding too much backtracking, we know: + * starts with "{" is a mapping_pattern + * starts with "[" is a sequence_pattern + * starts with "(" is a group_pattern or sequence_pattern + * wildcard pattern is just identifier=='_' + The rest are then tried in order with backtracking """ + if s.sy == 'IDENT' and s.systring == '_': + pos = s.position() + s.next() + return MatchCaseNodes.MatchAndAssignPatternNode(pos) + elif s.sy == '{': + return p_mapping_pattern(s) + elif s.sy == '[': + return p_sequence_pattern(s) + elif s.sy == '(': + with tentatively_scan(s) as errors: + result = p_group_pattern(s) + if not errors: + return result + return p_sequence_pattern(s) + with tentatively_scan(s) as errors: result = p_literal_pattern(s) - if not errors: - return result + if not errors: + return result with tentatively_scan(s) as errors: result = p_capture_pattern(s) - if not errors: - return result - with tentatively_scan(s) as errors: - result = p_wildcard_pattern(s) - if not errors: - return result + if not errors: + return result with tentatively_scan(s) as errors: result = p_value_pattern(s) - if not errors: - return result - with tentatively_scan(s) as errors: - result = p_group_pattern(s) - if not errors: - return result - with tentatively_scan(s) as errors: - result = p_sequence_pattern(s) - if not errors: - return result - with tentatively_scan(s) as errors: - result = p_mapping_pattern(s) - if not errors: - return result + if not errors: + return result return p_class_pattern(s) + def p_literal_pattern(s): # a lot of duplication in this function with "p_atom" next_must_be_a_number = False sign = '' - if s.sy in ['+', '-']: + if s.sy == '-': sign = s.sy sign_pos = s.position() s.next() @@ -4216,9 +4256,11 @@ elif sy == 'FLOAT': value = s.systring s.next() - res = ExprNodes.FloatNode(pos, value = value) + res = ExprNodes.FloatNode(pos, value=value) + if res and sign == "-": res = ExprNodes.UnaryMinusNode(sign_pos, operand=res) + if 
res and s.sy in ['+', '-']: sign = s.sy s.next() @@ -4231,61 +4273,43 @@ def p_literal_pattern(s): res = ExprNodes.binop_node( add_pos, sign, - operand1 = res, - operand2 = ExprNodes.ImagNode(s.position(), value = value) + operand1=res, + operand2=ExprNodes.ImagNode(s.position(), value=value) ) if not res and sy == 'IMAG': value = s.systring[:-1] s.next() - res = ExprNodes.ImagNode(pos, value = sign+value) + res = ExprNodes.ImagNode(pos, value=sign+value) if sign == "-": res = ExprNodes.UnaryMinusNode(sign_pos, operand=res) if res: - return MatchCaseNodes.MatchValuePatternNode(pos, value = res) + return MatchCaseNodes.MatchValuePatternNode(pos, value=res) + if next_must_be_a_number: + s.error("Expected a number") if sy == 'BEGIN_STRING': - if next_must_be_a_number: - s.error("Expected a number") - kind, bytes_value, unicode_value = p_cat_string_literal(s) - if kind == 'c': - res = ExprNodes.CharNode(pos, value = bytes_value) - elif kind == 'u': - res = ExprNodes.UnicodeNode(pos, value = unicode_value, bytes_value = bytes_value) - elif kind == 'b': - res = ExprNodes.BytesNode(pos, value = bytes_value) - elif kind == 'f': - res = Nodes.ErrorNode(pos, what = "f-strings are not accepted for pattern matching") - elif kind == '': - res = ExprNodes.StringNode(pos, value = bytes_value, unicode_value = unicode_value) - else: - s.error("invalid string kind '%s'" % kind) - return MatchCaseNodes.MatchValuePatternNode(pos, value = res) + res = p_atom_string(s) + # f-strings not being accepted is validated in PostParse + return MatchCaseNodes.MatchValuePatternNode(pos, value=res) elif sy == 'IDENT': - name = s.systring - result = None - if name == "None": - result = ExprNodes.NoneNode(pos) - elif name == "True": - result = ExprNodes.BoolNode(pos, value=True, type=Builtin.bool_type) - elif name == "False": - result = ExprNodes.BoolNode(pos, value=False, type=Builtin.bool_type) - elif name == "NULL" and not s.in_python_file: - # Included Null as an exactly matched constant here - 
result = ExprNodes.NullNode(pos) + # Note that p_atom_ident_constants includes NULL. + # This is a deliberate Cython addition to the pattern matching specification + result = p_atom_ident_constants(s, bools_are_pybool=True) if result: - s.next() - return MatchCaseNodes.MatchValuePatternNode(pos, value = result, is_is_check = True) + return MatchCaseNodes.MatchValuePatternNode(pos, value=result, is_is_check=True) s.error("Failed to match literal") + def p_capture_pattern(s): return MatchCaseNodes.MatchAndAssignPatternNode( s.position(), - target = p_pattern_capture_target(s) + target=p_pattern_capture_target(s) ) + def p_value_pattern(s): if s.sy != "IDENT": s.error("Expected identifier") @@ -4298,10 +4322,11 @@ def p_value_pattern(s): attr_pos = s.position() s.next() attr = p_ident(s) - res = ExprNodes.AttributeNode(attr_pos, obj = res, attribute=attr) + res = ExprNodes.AttributeNode(attr_pos, obj=res, attribute=attr) if s.sy in ['(', '=']: s.error("Unexpected symbol '%s'" % s.sy) - return MatchCaseNodes.MatchValuePatternNode(pos, value = res) + return MatchCaseNodes.MatchValuePatternNode(pos, value=res) + def p_group_pattern(s): s.expect("(") @@ -4309,12 +4334,6 @@ def p_group_pattern(s): s.expect(")") return pattern -def p_wildcard_pattern(s): - if s.sy != "IDENT" or s.systring != "_": - s.error("Expected '_'") - pos = s.position() - s.next() - return MatchCaseNodes.MatchAndAssignPatternNode(pos) def p_sequence_pattern(s): opener = s.sy @@ -4334,28 +4353,32 @@ def p_sequence_pattern(s): if s.sy == closer: break else: - if opener == ')' and len(patterns)==1: + if opener == ')' and len(patterns) == 1: s.error("tuple-like pattern of length 1 must finish with ','") break s.expect(closer) return MatchCaseNodes.MatchSequencePatternNode(pos, patterns=patterns) else: - s.error("Expected '[' or '('") + s.error("Expected '[' or '('") + def p_mapping_pattern(s): pos = s.position() s.expect('{') if s.sy == '}': + # trivial empty mapping s.next() return 
MatchCaseNodes.MatchMappingPatternNode(pos) + double_star_capture_target = None items_patterns = [] double_star_set_twice = None pattern_after_double_star = None + star_star_arg_pos = None while True: + if double_star_capture_target and not star_star_arg_pos: + star_star_arg_pos = s.position() if s.sy == '**': - if double_star_capture_target: - double_star_set_twice = s.position() s.next() double_star_capture_target = p_pattern_capture_target(s) else: @@ -4384,6 +4407,11 @@ def p_mapping_pattern(s): return Nodes.ErrorNode(double_star_set_twice, what = "Double star capture set twice") if pattern_after_double_star: return Nodes.ErrorNode(pattern_after_double_star, what = "pattern follows ** capture") + if star_star_arg_pos is not None: + return Nodes.ErrorNode( + star_star_arg_pos, + what = "** pattern must be the final part of a mapping pattern." + ) return MatchCaseNodes.MatchMappingPatternNode( pos, keys = [kv[0] for kv in items_patterns], @@ -4391,8 +4419,9 @@ def p_mapping_pattern(s): double_star_capture_target = double_star_capture_target ) + def p_class_pattern(s): - # name_or_attr + # start by parsing the class as name_or_attr pos = s.position() res = p_name(s, s.systring) s.next() @@ -4400,12 +4429,16 @@ def p_class_pattern(s): attr_pos = s.position() s.next() attr = p_ident(s) - res = ExprNodes.AttributeNode(attr_pos, obj = res, attribute=attr) + res = ExprNodes.AttributeNode(attr_pos, obj=res, attribute=attr) class_ = res + s.expect("(") if s.sy == ")": + # trivial case with no arguments matched s.next() return MatchCaseNodes.ClassPatternNode(pos, class_=class_) + + # parse the arguments positional_patterns = [] keyword_patterns = [] keyword_patterns_error = None @@ -4418,17 +4451,17 @@ def p_class_pattern(s): else: with tentatively_scan(s) as errors: keyword_patterns.append(p_keyword_pattern(s)) - if s.sy == ",": - s.next() - if s.sy == ")": - break - else: + if s.sy != ",": break + s.next() + if s.sy == ")": + break # Allow trailing comma. 
s.expect(")") + if keyword_patterns_error is not None: return Nodes.ErrorNode( keyword_patterns_error, - what = "Positional patterns follow keyword patterns" + what="Positional patterns follow keyword patterns" ) return MatchCaseNodes.ClassPatternNode( pos, class_ = class_, @@ -4437,6 +4470,7 @@ def p_class_pattern(s): keyword_pattern_patterns = [kv[1] for kv in keyword_patterns], ) + def p_keyword_pattern(s): if s.sy != "IDENT": s.error("Expected identifier") @@ -4446,6 +4480,7 @@ def p_keyword_pattern(s): value = p_pattern(s) return arg, value + def p_pattern_capture_target(s): # any name but '_', and with some constraints on what follows if s.sy != 'IDENT': diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index 8bcd26b6f..45a8e6f59 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -12,9 +12,10 @@ from functools import partial from .Compiler import Errors from .CodeWriter import CodeWriter -from .Compiler.TreeFragment import TreeFragment, strip_common_indent +from .Compiler.TreeFragment import TreeFragment, strip_common_indent, StringParseContext from .Compiler.Visitor import TreeVisitor, VisitorTransform from .Compiler import TreePath +from .Compiler.ParseTreeTransforms import PostParse class NodeTypeWriter(TreeVisitor): @@ -357,3 +358,24 @@ def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None while other_time is None or other_time >= os.path.getmtime(file_path): write_file(file_path, content, dedent=dedent, encoding=encoding) + + +def py_parse_code(code): + """ + Compiles code far enough to get errors from the parser and post-parse stage. + + Is useful for checking for syntax errors, however it doesn't generate runable + code. 
+ """ + context = StringParseContext("test") + # all the errors we care about are in the parsing or postparse stage + try: + with Errors.local_errors() as errors: + result = TreeFragment(code, pipeline=[PostParse(context)]) + result = result.substitute() + if errors: + raise errors[0] # compile error, which should get caught + else: + return result + except Errors.CompileError as e: + raise SyntaxError(e.message_only) diff --git a/Cython/Utility/MatchCase.c b/Cython/Utility/MatchCase.c index fea08e0c1..55c8b99b4 100644 --- a/Cython/Utility/MatchCase.c +++ b/Cython/Utility/MatchCase.c @@ -1,28 +1,14 @@ ///////////////////////////// ABCCheck ////////////////////////////// #if PY_VERSION_HEX < 0x030A0000 -static int __Pyx_MatchCase_IsExactSequence(PyObject *o) { +static CYTHON_INLINE int __Pyx_MatchCase_IsExactSequence(PyObject *o) { // is one of the small list of builtin types known to be a sequence - if (PyList_CheckExact(o) || PyTuple_CheckExact(o)) { + if (PyList_CheckExact(o) || PyTuple_CheckExact(o) || + PyType_CheckExact(o, PyRange_Type) || PyType_CheckExact(o, PyMemoryView_Type)) { // Use exact type match for these checks. 
I in the event of inheritence we need to make sure // that it isn't a mapping too return 1; } - if (PyRange_Check(o) || PyMemoryView_Check(o)) { - // Exact check isn't possible so do exact check in another way - PyObject *mro = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__mro__"); - if (mro) { - Py_ssize_t len = PyObject_Length(mro); - Py_DECREF(mro); - if (len < 0) { - PyErr_Clear(); // doesn't really matter, just proceed with other checks - } else if (len == 2) { - return 1; // the type and "object" and no other bases - } - } else { - PyErr_Clear(); // doesn't really matter, just proceed with other checks - } - } return 0; } @@ -34,10 +20,13 @@ static CYTHON_INLINE int __Pyx_MatchCase_IsExactMapping(PyObject *o) { } static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) { - if (PyUnicode_Check(o) || PyBytes_Check(o) || PyByteArray_Check(o)) { + if (PyType_GetFlags(Py_TYPE(o)) & (Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS)) || + PyByteArray_Check(o)) { return 1; // these types are deliberately excluded from the sequence test // even though they look like sequences for most other purposes. 
- // They're therefore "inexact" checks + // Leave them as inexact checks since they do pass + // "isinstance(o, collections.abc.Sequence)" so it's very hard to + // reason about their subclasses } if (o == Py_None || PyLong_CheckExact(o) || PyFloat_CheckExact(o)) { return 1; @@ -73,6 +62,16 @@ static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) { #define __PYX_SEQUENCE_MAPPING_ERROR (1U<<4) // only used by the ABCCheck function #endif +static int __Pyx_MatchCase_InitAndIsInstanceAbc(PyObject *o, PyObject *abc_module, + PyObject **abc_type, PyObject *name) { + assert(!abc_type); + abc_type = PyObject_GetAttr(abc_module, name); + if (!abc_type) { + return -1; + } + return PyObject_IsInstance(o, abc_type); +} + // the result is defined using the specification for sequence_mapping_temp // (detailed in "is_sequence") static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, int definitely_not_sequence, int definitely_not_mapping) { @@ -101,12 +100,7 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in result = __PYX_DEFINITELY_SEQUENCE_FLAG; goto end; } - sequence_type = PyObject_GetAttr(abc_module, PYIDENT("Sequence")); - if (!sequence_type) { - result = __PYX_SEQUENCE_MAPPING_ERROR; - goto end; - } - sequence_result = PyObject_IsInstance(o, sequence_type); + sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence")); if (sequence_result < 0) { result = __PYX_SEQUENCE_MAPPING_ERROR; goto end; @@ -114,41 +108,32 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in result |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG; goto end; } - // else wait to see what mapping is + // else wait to see what mapping is } if (!definitely_not_mapping) { - mapping_type = PyObject_GetAttr(abc_module, PYIDENT("Mapping")); - if (!mapping_type) { + mapping_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &mapping_type, PYIDENT("Mapping")); 
+ if (mapping_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; goto end; - } - mapping_result = PyObject_IsInstance(o, mapping_type); - } - if (mapping_result < 0) { - result = __PYX_SEQUENCE_MAPPING_ERROR; - goto end; - } else if (mapping_result == 0) { - result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG; - if (sequence_first) { - assert(sequence_result); - result |= __PYX_DEFINITELY_SEQUENCE_FLAG; - } - goto end; - } else /* mapping_result == 1 */ { - if (sequence_first && !sequence_result) { - result |= __PYX_DEFINITELY_MAPPING_FLAG; + } else if (mapping_result == 0) { + result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG; + if (sequence_first) { + assert(sequence_result); + result |= __PYX_DEFINITELY_SEQUENCE_FLAG; + } goto end; + } else /* mapping_result == 1 */ { + if (sequence_first && !sequence_result) { + result |= __PYX_DEFINITELY_MAPPING_FLAG; + goto end; + } } } if (!sequence_first) { // here we know mapping_result is true because we'd have returned otherwise assert(mapping_result); if (!definitely_not_sequence) { - sequence_type = PyObject_GetAttr(abc_module, PYIDENT("Sequence")); - if (!sequence_type) { - result = __PYX_SEQUENCE_MAPPING_ERROR; - goto end; - } - sequence_result = PyObject_IsInstance(o, sequence_type); + sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence")); } if (sequence_result < 0) { result = __PYX_SEQUENCE_MAPPING_ERROR; @@ -167,7 +152,7 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in if (!mro) { PyErr_Clear(); goto end; - } + } if (!PyTuple_Check(mro)) { Py_DECREF(mro); goto end; @@ -322,7 +307,7 @@ static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_ PyObject *list; ssizeargfunc slot; PyTypeObject *type = Py_TYPE(x); - + list = PyList_New(total); if (!list) { return NULL; @@ -454,17 +439,15 @@ static int __Pyx_MatchCase_IsMapping(PyObject *o, unsigned int *sequence_mapping #endif } -//////////////////////// 
DuplicateKeyCheck.proto /////////////////////// +//////////////////////// MappingKeyCheck.proto ///////////////////////// -// Returns an new reference to any duplicate key. -// NULL can indicate no duplicate keys or an error (so use PyErr_Occurred) -static PyObject* __Pyx_MatchCase_CheckDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys, Py_ssize_t n_var_keys); /*proto*/ +static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *keys[], Py_ssize_t nFixedKeys, Py_ssize_t nKeys); -//////////////////////// DuplicateKeyCheck ///////////////////////////// +//////////////////////// MappingKeyCheck /////////////////////////////// -static PyObject* __Pyx_MatchCase_CheckDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys, Py_ssize_t n_var_keys) { - // Inputs are tuples, and typically fairly small. It may be more efficient to - // loop over the tuple than create a set. +static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *keys[], Py_ssize_t nFixedKeys, Py_ssize_t nKeys) { + // Inputs are arrays, and typically fairly small. It may be more efficient to + // loop over the array than create a set. // The CPython implementation (match_keys in ceval.c) does this concurrently with // taking the keys out of the dictionary. I'm choosing to do it separately since the @@ -472,71 +455,55 @@ static PyObject* __Pyx_MatchCase_CheckDuplicateKeys(PyObject *fixed_keys, PyObje // this step completely. PyObject *var_keys_set; - PyObject *key = NULL; + PyObject *key; Py_ssize_t n; int contains; var_keys_set = PySet_New(NULL); - if (!var_keys_set) return NULL; + if (!var_keys_set) return -1; - n_var_keys = (n_var_keys < 0) ? 
PyTuple_GET_SIZE(var_keys) : n_var_keys; - for (n=0; n < n_var_keys; ++n) { - key = PyTuple_GET_ITEM(var_keys, n); + for (n=nFixedKeys; n < nKeys; ++n) { + key = keys[n]; contains = PySet_Contains(var_keys_set, key); if (contains < 0) { - key = NULL; - goto end; + goto bad; } else if (contains == 1) { - Py_INCREF(key); - goto end; + goto raise_error; } else { if (PySet_Add(var_keys_set, key)) { - key = NULL; - goto end; + goto bad; } } } - for (n=0; n < PyTuple_GET_SIZE(fixed_keys); ++n) { - key = PyTuple_GET_ITEM(fixed_keys, n); + for (n=0; n < nFixedKeys; ++n) { + key = keys[n]; contains = PySet_Contains(var_keys_set, key); if (contains < 0) { - key = NULL; - goto end; + goto bad; } else if (contains == 1) { - Py_INCREF(key); - goto end; + goto raise_error; } } - key = NULL; - end: Py_DECREF(var_keys_set); - return key; -} - -//////////////////////// MappingKeyCheck.proto ///////////////////////// - -static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys); /* proto */ - -//////////////////////// MappingKeyCheck /////////////////////////////// -//@requires: DuplicateKeyCheck - -static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys) { - PyObject *key = __Pyx_MatchCase_CheckDuplicateKeys(fixed_keys, var_keys, -1); - if (key) { - PyErr_Format(PyExc_ValueError, "mapping pattern checks duplicate key (%R)", key); - Py_DECREF(key); - return -1; - } else if (PyErr_Occurred()) { - return -1; - } else { - return 0; - } + return 0; + + raise_error: + #if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_ValueError, + "mapping pattern checks duplicate key (%R)", key); + #else + // DW really can't be bothered working around features that don't exist in + // Python 2, so just provide less information! 
+ PyErr_SetString(PyExc_ValueError, + "mapping pattern checks duplicate key"); + #endif + bad: + Py_DECREF(var_keys_set); + return -1; } /////////////////////////// ExtractExactDict.proto //////////////// -#include <stdarg.h> - // the variadic arguments are a list of PyObject** to subjects to be filled. They may be NULL // in which case they're ignored. // @@ -544,48 +511,31 @@ static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *fixed_keys, PyObj #if CYTHON_REFNANNY #define __Pyx_MatchCase_Mapping_ExtractDict(...) __Pyx__MatchCase_Mapping_ExtractDict(__pyx_refnanny, __VA_ARGS__) -#define __Pyx_MatchCase_Mapping_ExtractDictV(...) __Pyx__MatchCase_Mapping_ExtractDictV(__pyx_refnanny, __VA_ARGS__) #else #define __Pyx_MatchCase_Mapping_ExtractDict(...) __Pyx__MatchCase_Mapping_ExtractDict(NULL, __VA_ARGS__) -#define __Pyx_MatchCase_Mapping_ExtractDictV(...) __Pyx__MatchCase_Mapping_ExtractDictV(NULL, __VA_ARGS__) #endif -static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, ...); /* proto */ -static int __Pyx__MatchCase_Mapping_ExtractDictV(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, va_list subjects); /* proto */ +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */ /////////////////////////// ExtractExactDict //////////////// -static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, ...) 
{ - int result; - va_list subjects; - - va_start(subjects, var_keys); - result = __Pyx_MatchCase_Mapping_ExtractDictV(dict, fixed_keys, var_keys, subjects); - va_end(subjects); - return result; -} +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) { + Py_ssize_t i; -static int __Pyx__MatchCase_Mapping_ExtractDictV(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, va_list subjects) { - PyObject *keys[] = {fixed_keys, var_keys}; - Py_ssize_t i, j; - - for (i=0; i<2; ++i) { - PyObject *tuple = keys[i]; - for (j=0; j<PyTuple_GET_SIZE(tuple); ++j) { - PyObject *key = PyTuple_GET_ITEM(tuple, j); - PyObject **subject = va_arg(subjects, PyObject**); - if (!subject) { - int contains = PyDict_Contains(dict, key); - if (contains <= 0) { - return -1; // any subjects that were already set will be cleaned up externally - } - } else { - PyObject *value = __Pyx_PyDict_GetItemStrWithError(dict, key); - if (!value) { - return (PyErr_Occurred()) ? -1 : 0; // any subjects that were already set will be cleaned up externally - } - __Pyx_XDECREF_SET(*subject, value); - __Pyx_INCREF(*subject); // capture this incref with refnanny! + for (i=0; i<nKeys; ++i) { + PyObject *key = keys[i]; + PyObject **subject = subjects[i]; + if (!subject) { + int contains = PyDict_Contains(dict, key); + if (contains <= 0) { + return -1; // any subjects that were already set will be cleaned up externally + } + } else { + PyObject *value = __Pyx_PyDict_GetItemStrWithError(dict, key); + if (!value) { + return (PyErr_Occurred()) ? -1 : 0; // any subjects that were already set will be cleaned up externally } + __Pyx_XDECREF_SET(*subject, value); + __Pyx_INCREF(*subject); // capture this incref with refnanny! 
} } return 1; // success @@ -598,74 +548,74 @@ static int __Pyx__MatchCase_Mapping_ExtractDictV(void *__pyx_refnanny, PyObject // // This is a specialized version for the rarer case when the type isn't an exact dict. -#include <stdarg.h> - #if CYTHON_REFNANNY #define __Pyx_MatchCase_Mapping_ExtractNonDict(...) __Pyx__MatchCase_Mapping_ExtractNonDict(__pyx_refnanny, __VA_ARGS__) -#define __Pyx_MatchCase_Mapping_ExtractNonDictV(...) __Pyx__MatchCase_Mapping_ExtractNonDictV(__pyx_refnanny, __VA_ARGS__) #else #define __Pyx_MatchCase_Mapping_ExtractNonDict(...) __Pyx__MatchCase_Mapping_ExtractNonDict(NULL, __VA_ARGS__) -#define __Pyx_MatchCase_Mapping_ExtractNonDictV(...) __Pyx__MatchCase_Mapping_ExtractNonDictV(NULL, __VA_ARGS__) #endif -static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *fixed_keys, PyObject *var_keys, ...); /* proto */ -static int __Pyx__MatchCase_Mapping_ExtractNonDictV(void *__pyx_refnanny, PyObject *mapping, PyObject *fixed_keys, PyObject *var_keys, va_list subjects); /* proto */ +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */ ///////////////////////// ExtractNonDict ////////////////////////////////////// //@requires: ObjectHandling.c::PyObjectCall2Args // largely adapted from match_keys in CPython ceval.c -static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, ...) 
{ - int result; - va_list subjects; - - va_start(subjects, var_keys); - result = __Pyx_MatchCase_Mapping_ExtractNonDictV(map, fixed_keys, var_keys, subjects); - va_end(subjects); - return result; -} - -static int __Pyx__MatchCase_Mapping_ExtractNonDictV(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, va_list subjects) { +static int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) { PyObject *dummy=NULL, *get=NULL; - PyObject *keys[] = {fixed_keys, var_keys}; - Py_ssize_t i, j; + Py_ssize_t i; int result = 0; +#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL + PyObject *get_method = NULL, *get_self = NULL; +#endif dummy = PyObject_CallObject((PyObject *)&PyBaseObject_Type, NULL); if (!dummy) { return -1; } - get = PyObject_GetAttrString(map, "get"); + get = PyObject_GetAttrString(mapping, "get"); if (!get) { result = -1; goto end; } +#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL + if (likely(PyMethod_Check(get))) { + // both of these are borrowed + get_method = PyMethod_GET_FUNCTION(get); + get_self = PyMethod_GET_SELF(get); + } +#endif - for (i=0; i<2; ++i) { - PyObject *tuple = keys[i]; - for (j=0; j<PyTuple_GET_SIZE(tuple); ++j) { - PyObject **subject; - PyObject *value = NULL; - PyObject *key = PyTuple_GET_ITEM(tuple, j); - - // TODO - there's an optimization here (although it deviates from the strict definition of pattern matching). - // If we don't need the values then we can call PyObject_Contains instead of "get". If we don't need *any* - // of the values then we can skip initialization "get" and "dummy" + for (i=0; i<nKeys; ++i) { + PyObject **subject; + PyObject *value = NULL; + PyObject *key = keys[i]; + + // TODO - there's an optimization here (although it deviates from the strict definition of pattern matching). + // If we don't need the values then we can call PyObject_Contains instead of "get". 
If we don't need *any* + // of the values then we can skip initialization "get" and "dummy" +#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL + if (likely(get_method)) { + PyObject *args[] = { get_self, key, dummy }; + value = _PyObject_Vectorcall(get_method, args, 3, NULL); + } + else +#endif + { value = __Pyx_PyObject_Call2Args(get, key, dummy); - if (!value) { - result = -1; - goto end; - } else if (value == dummy) { - Py_DECREF(value); - goto end; // failed + } + if (!value) { + result = -1; + goto end; + } else if (value == dummy) { + Py_DECREF(value); + goto end; // failed + } else { + subject = subjects[i]; + if (subject) { + __Pyx_XDECREF_SET(*subject, value); + __Pyx_GOTREF(*subject); } else { - subject = va_arg(subjects, PyObject**); - if (subject) { - __Pyx_XDECREF_SET(*subject, value); - __Pyx_GOTREF(*subject); - } else { - Py_DECREF(value); - } + Py_DECREF(value); } } } @@ -679,36 +629,28 @@ static int __Pyx__MatchCase_Mapping_ExtractNonDictV(void *__pyx_refnanny, PyObje ///////////////////////// ExtractGeneric.proto //////////////////////////////// -#include <stdarg.h> - #if CYTHON_REFNANNY #define __Pyx_MatchCase_Mapping_Extract(...) __Pyx__MatchCase_Mapping_Extract(__pyx_refnanny, __VA_ARGS__) #else #define __Pyx_MatchCase_Mapping_Extract(...) __Pyx__MatchCase_Mapping_Extract(NULL, __VA_ARGS__) #endif -static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, ...); /* proto */ +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */ ////////////////////// ExtractGeneric ////////////////////////////////////// //@requires: ExtractExactDict //@requires: ExtractNonDict -static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, ...) 
{ - va_list subjects; - int result; - - va_start(subjects, var_keys); - if (PyDict_CheckExact(map)) { - result = __Pyx_MatchCase_Mapping_ExtractDictV(map, fixed_keys, var_keys, subjects); +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) { + if (PyDict_CheckExact(mapping)) { + return __Pyx_MatchCase_Mapping_ExtractDict(mapping, keys, nKeys, subjects); } else { - result = __Pyx_MatchCase_Mapping_ExtractNonDictV(map, fixed_keys, var_keys, subjects); + return __Pyx_MatchCase_Mapping_ExtractNonDict(mapping, keys, nKeys, subjects); } - va_end(subjects); - return result; } ///////////////////////////// DoubleStarCapture.proto ////////////////////// -static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObject *const_temps, PyObject *var_temps); /* proto */ +static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys); /* proto */ //////////////////////////// DoubleStarCapture ////////////////////////////// @@ -717,31 +659,30 @@ static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObjec // https://github.com/python/cpython/blob/145bf269df3530176f6ebeab1324890ef7070bf8/Python/ceval.c#L3977 // (now removed in favour of building the same thing from a combination of opcodes) // The differences are: -// 1. We loop over separate tuples for constant and runtime keys -// 2. We add a shortcut for when there will be no left over keys (because I'm guess it's pretty common) +// 1. We use an array of keys rather than a tuple of keys +// 2. 
We add a shortcut for when there will be no left over keys (because I guess it's pretty common) // // Tempita variable 'tag' can be "NonDict", "ExactDict" or empty -static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObject *const_temps, PyObject *var_temps) { +static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys) { PyObject *dict_out; - PyObject *tuples[] = { const_temps, var_temps }; - Py_ssize_t i, j; + Py_ssize_t i; {{if tag != "NonDict"}} // shortcut for when there are no left-over keys - if ({{if tag=="ExactDict"}}(1){{else}}PyDict_CheckExact(map){{endif}}) { - Py_ssize_t s = PyDict_Size(map); + if ({{if tag=="ExactDict"}}(1){{else}}PyDict_CheckExact(mapping){{endif}}) { + Py_ssize_t s = PyDict_Size(mapping); if (s == -1) { return NULL; } - if (s == (PyTuple_GET_SIZE(const_temps) + PyTuple_GET_SIZE(var_temps))) { + if (s == nKeys) { return PyDict_New(); } } {{endif}} {{if tag=="ExactDict"}} - dict_out = PyDict_Copy(map); + dict_out = PyDict_Copy(mapping); {{else}} dict_out = PyDict_New(); {{endif}} @@ -749,19 +690,16 @@ static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObjec return NULL; } {{if tag!="ExactDict"}} - if (PyDict_Update(dict_out, map)) { + if (PyDict_Update(dict_out, mapping)) { Py_DECREF(dict_out); return NULL; } {{endif}} - for (i=0; i<2; ++i) { - PyObject *keys = tuples[i]; - for (j=0; j<PyTuple_GET_SIZE(keys); ++j) { - if (PyDict_DelItem(dict_out, PyTuple_GET_ITEM(keys, j))) { - Py_DECREF(dict_out); - return NULL; - } + for (i=0; i<nKeys; ++i) { + if (PyDict_DelItem(dict_out, keys[i])) { + Py_DECREF(dict_out); + return NULL; } } return dict_out; @@ -769,30 +707,81 @@ static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObjec ////////////////////////////// ClassPositionalPatterns.proto //////////////////////// -#include <stdarg.h> - #if CYTHON_REFNANNY #define __Pyx_MatchCase_ClassPositional(...) 
__Pyx__MatchCase_ClassPositional(__pyx_refnanny, __VA_ARGS__) #else #define __Pyx_MatchCase_ClassPositional(...) __Pyx__MatchCase_ClassPositional(NULL, __VA_ARGS__) #endif -static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *keysnames_tuple, int match_self, int num_args, ...); /* proto */ +static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *fixed_names[], Py_ssize_t n_fixed, int match_self, PyObject **subjects[], Py_ssize_t n_subjects); /* proto */ /////////////////////////////// ClassPositionalPatterns ////////////////////////////// -//@requires: DuplicateKeyCheck + +static int __Pyx_MatchCase_ClassCheckDuplicateAttrs(const char *tp_name, PyObject *fixed_names[], Py_ssize_t n_fixed, PyObject *match_args, Py_ssize_t num_args) { + // a lot of the basic logic of this is shared with __Pyx_MatchCase_CheckMappingDuplicateKeys + // but they take different input types so it isn't easy to actually share the code. + + // Inputs are tuples, and typically fairly small. It may be more efficient to + // loop over the tuple than create a set. + + PyObject *attrs_set; + PyObject *attr = NULL; + Py_ssize_t n; + int contains; + + attrs_set = PySet_New(NULL); + if (!attrs_set) return -1; + + num_args = PyTuple_GET_SIZE(match_args) < num_args ? 
PyTuple_GET_SIZE(match_args) : num_args; + for (n=0; n < num_args; ++n) { + attr = PyTuple_GET_ITEM(match_args, n); + contains = PySet_Contains(attrs_set, attr); + if (contains < 0) { + goto bad; + } else if (contains == 1) { + goto raise_error; + } else { + if (PySet_Add(attrs_set, attr)) { + goto bad; + } + } + } + for (n=0; n < n_fixed; ++n) { + attr = fixed_names[n]; + contains = PySet_Contains(attrs_set, attr); + if (contains < 0) { + goto bad; + } else if (contains == 1) { + goto raise_error; + } + } + Py_DECREF(attrs_set); + return 0; + + raise_error: + #if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute %R", + tp_name, attr); + #else + // DW has no interest in working around the lack of %R in Python 2.7 + PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute", + tp_name); + #endif + bad: + Py_DECREF(attrs_set); + return -1; +} // Adapted from ceval.c "match_class" in CPython // // The argument match_self can equal 1 for "known to be true" // 0 for "known to be false" // -1 for "unknown", runtime test - -static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *keysnames_tuple, int match_self, int num_args, ...) +// nargs is >= 0 otherwise this function will be skipped +static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *fixed_names[], Py_ssize_t n_fixed, int match_self, PyObject **subjects[], Py_ssize_t n_subjects) { - PyObject *match_args, *dup_key; + PyObject *match_args; Py_ssize_t allowed, i; int result; - va_list subjects; match_args = PyObject_GetAttrString((PyObject*)type, "__match_args__"); if (!match_args) { @@ -805,19 +794,17 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj _Py_TPFLAGS_MATCH_SELF); #else // probably an earlier version of Python. 
Go off the known list in the specification - match_self = (PyType_IsSubtype(type, &PyByteArray_Type) || - PyType_IsSubtype(type, &PyBytes_Type) || - PyType_IsSubtype(type, &PyDict_Type) || + match_self = ((PyType_GetFlags(type) & + // long should capture bool too + (Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS | Py_TPFLAGS_TUPLE_SUBCLASS | + Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS | Py_TPFLAGS_DICT_SUBCLASS + #if PY_MAJOR_VERSION < 3 + | Py_TPFLAGS_IN_SUBCLASS + #endif + )) || + PyType_IsSubtype(type, &PyByteArray_Type) || PyType_IsSubtype(type, &PyFloat_Type) || PyType_IsSubtype(type, &PyFrozenSet_Type) || - PyType_IsSubtype(type, &PyLong_Type) || // This should capture bool too - #if PY_MAJOR_VERSION < 3 - PyType_IsSubtype(type, &PyInt_Type) || - #endif - PyType_IsSubtype(type, &PyList_Type) || - PyType_IsSubtype(type, &PySet_Type) || - PyType_IsSubtype(type, &PyUnicode_Type) || - PyType_IsSubtype(type, &PyTuple_Type) ); #endif } @@ -838,18 +825,17 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj allowed = match_self ? 1 : (match_args ? PyTuple_GET_SIZE(match_args) : 0); - if (allowed < num_args) { + if (allowed < n_subjects) { const char *plural = (allowed == 1) ? "" : "s"; PyErr_Format(PyExc_TypeError, "%s() accepts %d positional sub-pattern%s (%d given)", type->tp_name, - allowed, plural, num_args); + allowed, plural, n_subjects); Py_XDECREF(match_args); return -1; } - va_start(subjects, num_args); if (match_self) { - PyObject **self_subject = va_arg(subjects, PyObject**); + PyObject **self_subject = subjects[0]; if (self_subject) { // Easy. Copy the subject itself, and move on to kwargs. __Pyx_XDECREF_SET(*self_subject, subject); @@ -858,20 +844,13 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj result = 1; goto end_match_self; } - // next stage is to check for duplicate keys. 
Reuse code from mapping - dup_key = __Pyx_MatchCase_CheckDuplicateKeys(keysnames_tuple, match_args, num_args); - if (dup_key) { - PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute %R", - type->tp_name, dup_key); - Py_DECREF(dup_key); - result = -1; - goto end; - } else if (PyErr_Occurred()) { + // next stage is to check for duplicate attributes. + if (__Pyx_MatchCase_ClassCheckDuplicateAttrs(type->tp_name, fixed_names, n_fixed, match_args, n_subjects)) { result = -1; goto end; } - for (i = 0; i < num_args; i++) { + for (i = 0; i < n_subjects; i++) { PyObject *attr; PyObject **subject_i; PyObject *name = PyTuple_GET_ITEM(match_args, i); @@ -889,7 +868,7 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj result = 0; goto end; } - subject_i = va_arg(subjects, PyObject**); + subject_i = subjects[i]; if (subject_i) { __Pyx_XDECREF_SET(*subject_i, attr); __Pyx_GOTREF(attr); @@ -902,7 +881,6 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj end: Py_DECREF(match_args); end_match_self: // because match_args isn't set - va_end(subjects); return result; } @@ -913,6 +891,13 @@ static PyTypeObject* __Pyx_MatchCase_IsType(PyObject* type); /* proto */ //////////////////////// MatchClassIsType ///////////////////////////// static PyTypeObject* __Pyx_MatchCase_IsType(PyObject* type) { + #if PY_MAJOR_VERSION < 3 + if (PyClass_Check(type)) { + // I don't really think it's worth the effort getting this to work! 
+ PyErr_Format(PyExc_TypeError, "called match pattern must be a new-style class."); + return NULL; + } + #endif if (!PyType_Check(type)) { PyErr_Format(PyExc_TypeError, "called match pattern must be a type"); return NULL; diff --git a/Tools/ci-run.sh b/Tools/ci-run.sh index 905a9d1e3..ffde4cbe1 100644 --- a/Tools/ci-run.sh +++ b/Tools/ci-run.sh @@ -83,6 +83,8 @@ else python -m pip install -r test-requirements.txt || exit 1 if [[ $PYTHON_VERSION != "pypy"* && $PYTHON_VERSION != "3."[1]* ]]; then python -m pip install -r test-requirements-cpython.txt || exit 1 + elif [[ $PYTHON_VERSION == "pypy-2.7" ]]; then + python -m pip install -r test-requirements-pypy27.txt || exit 1 fi fi fi diff --git a/test-requirements-pypy27.txt b/test-requirements-pypy27.txt index 9f9505240..6d4f83bca 100644 --- a/test-requirements-pypy27.txt +++ b/test-requirements-pypy27.txt @@ -1,2 +1,3 @@ -r test-requirements.txt +enum34==1.1.10 mock==3.0.5 diff --git a/tests/run/extra_patma.pyx b/tests/run/extra_patma.pyx index 3bded3426..76357f36f 100644 --- a/tests/run/extra_patma.pyx +++ b/tests/run/extra_patma.pyx @@ -5,6 +5,41 @@ cimport cython import array +import sys + +__doc__ = "" + + +cdef bint is_null(int* x): + return False # disabled - currently just a parser test + match x: + case NULL: + return True + case _: + return False + + +def test_is_null(): + """ + >>> test_is_null() + """ + cdef int some_int = 1 + return # disabled - currently just a parser test + assert is_null(&some_int) == False + assert is_null(NULL) == True + + +if sys.version_info[0] > 2: + __doc__ += """ + array.array doesn't have the buffer protocol in Py2 and + this doesn't really feel worth working around to test + >>> print(test_memoryview(array.array('i', [0, 1, 2]))) + a 1 + >>> print(test_memoryview(array.array('i', []))) + b + >>> print(test_memoryview(array.array('i', [5]))) + c [5] + """ # goes via .shape instead @cython.test_fail_if_path_exists("//CallNode//NameNode[@name = 'len']") @@ -12,14 +47,8 @@ import 
array @cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") def test_memoryview(int[:] x): """ - >>> print(test_memoryview(array.array('i', [0, 1, 2]))) - a 1 - >>> print(test_memoryview(array.array('i', []))) - b >>> print(test_memoryview(None)) no! - >>> print(test_memoryview(array.array('i', [5]))) - c [5] """ match x: case [0, y, 2]: @@ -45,6 +74,7 @@ def test_list_to_sequence(list x): case _: return False + @cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") @cython.test_fail_if_path_exists("//CmpNode") # There's nothing to compare - it always succeeds! def test_list_not_None_to_sequence(list x not None): @@ -89,7 +119,7 @@ def class_attr_lookup(x): assert cython.typeof(y) == "double", cython.typeof(y) return y -class PyClass: +class PyClass(object): pass @cython.test_assert_path_exists("//PythonCapiFunctionNode[@cname='__Pyx_TypeCheck']") @@ -110,6 +140,7 @@ def class_typecheck_exists(x): case _: return False + @cython.test_fail_if_path_exists("//NameNode[@name='isinstance']") @cython.test_fail_if_path_exists("//PythonCapiFunctionNode[@cname='__Pyx_TypeCheck']") def class_typecheck_doesnt_exist(C x): diff --git a/tests/run/extra_patma_py.py b/tests/run/extra_patma_py.py index e4b2aef84..a5046b997 100644 --- a/tests/run/extra_patma_py.py +++ b/tests/run/extra_patma_py.py @@ -4,6 +4,9 @@ from __future__ import print_function import array +import sys + +__doc__ = "" def test_type_inference(x): """ @@ -63,10 +66,15 @@ def test_duplicate_keys(key1, key2): >>> test_duplicate_keys("a", "b") True - >>> test_duplicate_keys("a", "a") - Traceback (most recent call last): - ... - ValueError: mapping pattern checks duplicate key ('a') + + Slightly awkward doctest to work around Py2 incompatibility + >>> try: + ... test_duplicate_keys("a", "a") + ... except ValueError as e: + ... if sys.version_info[0] > 2: + ... 
assert e.args[0] == "mapping pattern checks duplicate key ('a')", e.args[0] + ... else: + ... assert e.args[0] == "mapping pattern checks duplicate key" """ class Keys: KEY_1 = key1 @@ -79,7 +87,7 @@ def test_duplicate_keys(key1, key2): return False -class PyClass: +class PyClass(object): pass @@ -99,3 +107,20 @@ class PrivateAttrLookupOuter: match x: case PyClass(__something=y): return y + + +if sys.version_info[0] < 3: + class OldStyleClass: + pass + + def test_oldstyle_class_failure(x): + match x: + case OldStyleClass(): + return True + + __doc__ += """ + >>> test_oldstyle_class_failure(1) + Traceback (most recent call last): + ... + TypeError: called match pattern must be a new-style class. + """ diff --git a/tests/run/test_patma.py b/tests/run/test_patma.py index 9344c8169..e51ba0dbd 100644 --- a/tests/run/test_patma.py +++ b/tests/run/test_patma.py @@ -1,43 +1,21 @@ -### COPIED FROM CPython 3.9 +### COPIED FROM CPython 3.12 alpha (July 2022) ### Original part after ############ # cython: language_level=3 # new code import cython -from Cython.Compiler.Main import compile as cython_compile, CompileError -from Cython.Build.Inline import cython_inline -import contextlib -from tempfile import NamedTemporaryFile - -@contextlib.contextmanager -def hidden_stderr(): - try: - from StringIO import StringIO - except ImportError: - from io import StringIO - - old_stderr = sys.stderr - try: - sys.stderr = StringIO() - yield - finally: - sys.stderr = old_stderr - -def _compile(code): - with NamedTemporaryFile(suffix='.py') as f: - f.write(code.encode('utf8')) - f.flush() - - with hidden_stderr(): - result = cython_compile(f.name, language_level=3) - return result +from Cython.TestUtils import py_parse_code + if cython.compiled: def compile(code, name, what): assert what == 'exec' - result = _compile(code) - if not result.c_file: - raise SyntaxError('unexpected EOF') # compile is only used for testing errors + py_parse_code(code) + + +def disable(func): + pass + 
############## SLIGHTLY MODIFIED ORIGINAL CODE import array @@ -68,9 +46,47 @@ else: return self.x == other.x and self.y == other.y # TestCompiler removed - it's very CPython-specific -# TestTracing also removed - doesn't seem like a core test +# TestTracing also mainly removed - doesn't seem like a core test +# except for one test that seems misplaced in CPython (which is below) + +class TestTracing(unittest.TestCase): + if sys.version_info < (3, 4): + class SubTestClass(object): + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + return + def __call__(self, *args): + return self + subTest = SubTestClass() + + def test_parser_deeply_nested_patterns(self): + # Deeply nested patterns can cause exponential backtracking when parsing. + # See CPython gh-93671 for more information. + # + # DW: Cython note - this doesn't break the parser but may cause a + # RecursionError later in the code-generation. I don't believe that's + # easily avoidable with the way Cython visitors currently work + + levels = 100 + + patterns = [ + "A" + "(" * levels + ")" * levels, + "{1:" * levels + "1" + "}" * levels, + "[" * levels + "1" + "]" * levels, + ] -# FIXME - return all the "return"s added to cause code to be dropped + for pattern in patterns: + with self.subTest(pattern): + code = inspect.cleandoc(""" + match None: + case {}: + pass + """.format(pattern)) + compile(code, "<string>", "exec") + + +# FIXME - remove all the "return"s added to cause code to be dropped ############## ORIGINAL PART FROM CPYTHON @@ -2706,6 +2722,21 @@ class TestPatma(unittest.TestCase): self.assertEqual(y, 'bar') + def test_patma_249(self): + return # disabled + class C: + __attr = "eggs" # mangled to _C__attr + _Outer__attr = "bacon" + class Outer: + def f(self, x): + match x: + # looks up __attr, not _C__attr or _Outer__attr + case C(__attr=y): + return y + c = C() + setattr(c, "__attr", "spam") # setattr is needed because we're in a class scope + 
self.assertEqual(Outer().f(c), "spam") + class TestSyntaxErrors(unittest.TestCase): @@ -2728,6 +2759,7 @@ class TestSyntaxErrors(unittest.TestCase): """) + @disable # validation will be added when class patterns are added def test_attribute_name_repeated_in_class_pattern(self): self.assert_syntax_error(""" match ...: @@ -2826,6 +2858,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # will be implemented as part of sequence patterns def test_multiple_starred_names_in_sequence_pattern_0(self): self.assert_syntax_error(""" match ...: @@ -2833,6 +2866,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # will be implemented as part of sequence patterns def test_multiple_starred_names_in_sequence_pattern_1(self): self.assert_syntax_error(""" match ...: @@ -2967,6 +3001,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # validation will be added when class patterns are added def test_mapping_pattern_duplicate_key(self): self.assert_syntax_error(""" match ...: @@ -2974,6 +3009,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # validation will be added when class patterns are added def test_mapping_pattern_duplicate_key_edge_case0(self): self.assert_syntax_error(""" match ...: @@ -2981,6 +3017,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # validation will be added when class patterns are added def test_mapping_pattern_duplicate_key_edge_case1(self): self.assert_syntax_error(""" match ...: @@ -2988,6 +3025,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # validation will be added when class patterns are added def test_mapping_pattern_duplicate_key_edge_case2(self): self.assert_syntax_error(""" match ...: @@ -2995,6 +3033,7 @@ class TestSyntaxErrors(unittest.TestCase): pass """) + @disable # validation will be added when class patterns are added def test_mapping_pattern_duplicate_key_edge_case3(self): self.assert_syntax_error(""" match ...: |