author     da-woods <dw-git@d-woods.co.uk>  2022-12-08 21:37:49 +0000
committer  da-woods <dw-git@d-woods.co.uk>  2022-12-08 21:37:49 +0000
commit     f0fab87c193ba3acd17010cc07183583b50987b6 (patch)
tree       b8fc83f31fa6385e040cd3635130225983eb266f
parent     4612175c2fdaa62a184be97a0b4c5501718a6ee3 (diff)
parent     79969ec1d213a6d24ce9e76fdee6d9be9dc8422b (diff)
download   cython-f0fab87c193ba3acd17010cc07183583b50987b6.tar.gz
Merge branch 'match-or' into patma-preview
-rw-r--r--  Cython/Compiler/MatchCaseNodes.py          373
-rw-r--r--  Cython/Compiler/Nodes.py                     4
-rw-r--r--  Cython/Compiler/ParseTreeTransforms.pxd      1
-rw-r--r--  Cython/Compiler/ParseTreeTransforms.py      13
-rw-r--r--  Cython/Compiler/Parsing.pxd                  2
-rw-r--r--  Cython/Compiler/Parsing.py                 271
-rw-r--r--  Cython/TestUtils.py                         24
-rw-r--r--  Cython/Utility/MatchCase.c                 475
-rw-r--r--  Tools/ci-run.sh                              2
-rw-r--r--  test-requirements-pypy27.txt                 1
-rw-r--r--  tests/run/extra_patma.pyx                   45
-rw-r--r--  tests/run/extra_patma_py.py                 35
-rw-r--r--  tests/run/test_patma.py                    105
13 files changed, 811 insertions, 540 deletions
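
For orientation: the branch being merged extends Cython's PEP 634 ("structural pattern matching") support, in particular or-patterns and the mapping/class pattern helpers touched below. A minimal illustrative example of the syntax involved (written for this summary, not taken from the test suite):

    def describe(value):
        # or-pattern, mapping pattern with a ** capture, sequence pattern, wildcard
        match value:
            case 0 | None:
                return "nothing"
            case {"x": x, **rest}:
                return "mapping with x=%r and %d extra keys" % (x, len(rest))
            case [first, *others]:
                return "sequence starting with %r" % first
            case _:
                return "something else"
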
diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py
index 833520642..a7e246bcb 100644
--- a/Cython/Compiler/MatchCaseNodes.py
+++ b/Cython/Compiler/MatchCaseNodes.py
@@ -1,9 +1,8 @@
# Nodes for structural pattern matching.
#
-# In a separate file because they're unlikely to be useful
-# for much else
+# In a separate file because they're unlikely to be useful for much else.
-from .Nodes import Node, StatNode
+from .Nodes import Node, StatNode, ErrorNode
from .Errors import error, local_errors, report_error
from . import Nodes, ExprNodes, PyrexTypes, Builtin
from .Code import UtilityCode, TempitaUtilityCode
@@ -28,7 +27,11 @@ class MatchNode(StatNode):
def validate_irrefutable(self):
found_irrefutable_case = None
- for c in self.cases:
+ for case in self.cases:
+ if isinstance(case, ErrorNode):
+ # This validation happens before error nodes have been
+ # transformed into actual errors, so we need to ignore them
+ continue
if found_irrefutable_case:
error(
found_irrefutable_case.pos,
@@ -38,9 +41,9 @@ class MatchNode(StatNode):
),
)
break
- if c.is_irrefutable():
- found_irrefutable_case = c
- c.validate_irrefutable()
+ if case.is_irrefutable():
+ found_irrefutable_case = case
+ case.validate_irrefutable()
def refactor_cases(self):
# An early transform - changes cases that can be represented as
@@ -158,6 +161,8 @@ class MatchCaseNode(Node):
child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"]
def is_irrefutable(self):
+ if isinstance(self.pattern, ErrorNode):
+ return True # value doesn't really matter
return self.pattern.is_irrefutable() and not self.guard
def is_simple_value_comparison(self):
@@ -166,9 +171,13 @@ class MatchCaseNode(Node):
return self.pattern.is_simple_value_comparison()
def validate_targets(self):
+ if isinstance(self.pattern, ErrorNode):
+ return
self.pattern.get_targets()
def validate_irrefutable(self):
+ if isinstance(self.pattern, ErrorNode):
+ return
self.pattern.validate_irrefutable()
def is_sequence_or_mapping(self):
@@ -263,7 +272,7 @@ class SubstitutedMatchCaseNode(MatchCaseBaseNode):
class PatternNode(Node):
"""
- DW decided that PatternNode shouldn't be an expression because
+ PatternNode is not an expression because
    it does several things (evaluating a boolean expression,
assignment of targets), and they need to be done at different
times.
@@ -305,9 +314,9 @@ class PatternNode(Node):
child_attrs = ["as_targets"]
def __init__(self, pos, **kwds):
+ if "as_targets" not in kwds:
+ kwds["as_targets"] = []
super(PatternNode, self).__init__(pos, **kwds)
- if not hasattr(self, "as_targets"):
- self.as_targets = []
def is_irrefutable(self):
return False
@@ -322,14 +331,13 @@ class PatternNode(Node):
def get_targets(self):
targets = self.get_main_pattern_targets()
- for t in self.as_targets:
- self.add_target_to_targets(targets, t.name)
+ for target in self.as_targets:
+ self.add_target_to_targets(targets, target.name)
return targets
def update_targets_with_targets(self, targets, other_targets):
- intersection = targets.intersection(other_targets)
- for i in intersection:
- error(self.pos, "multiple assignments to name '%s' in pattern" % i)
+ for name in targets.intersection(other_targets):
+ error(self.pos, "multiple assignments to name '%s' in pattern" % name)
targets.update(other_targets)
def add_target_to_targets(self, targets, target):
@@ -362,7 +370,7 @@ class PatternNode(Node):
def validate_irrefutable(self):
for attr in self.child_attrs:
child = getattr(self, attr)
- if isinstance(child, PatternNode):
+ if child is not None and isinstance(child, PatternNode):
child.validate_irrefutable()
def analyse_pattern_expressions(self, env, sequence_mapping_temp):
@@ -516,9 +524,9 @@ class OrPatternNode(PatternNode):
child_attrs = PatternNode.child_attrs + ["alternatives"]
def get_first_irrefutable(self):
- for a in self.alternatives:
- if a.is_irrefutable():
- return a
+ for alternative in self.alternatives:
+ if alternative.is_irrefutable():
+ return alternative
return None
def is_irrefutable(self):
@@ -537,17 +545,17 @@ class OrPatternNode(PatternNode):
def get_main_pattern_targets(self):
child_targets = None
- for ch in self.alternatives:
- ch_targets = ch.get_targets()
- if child_targets is not None and child_targets != ch_targets:
+ for alternative in self.alternatives:
+ alternative_targets = alternative.get_targets()
+ if child_targets is not None and child_targets != alternative_targets:
error(self.pos, "alternative patterns bind different names")
- child_targets = ch_targets
+ child_targets = alternative_targets
return child_targets
def validate_irrefutable(self):
super(OrPatternNode, self).validate_irrefutable()
found_irrefutable_case = None
- for a in self.alternatives:
+ for alternative in self.alternatives:
if found_irrefutable_case:
error(
found_irrefutable_case.pos,
@@ -557,9 +565,9 @@ class OrPatternNode(PatternNode):
),
)
break
- if a.is_irrefutable():
- found_irrefutable_case = a
- a.validate_irrefutable()
+ if alternative.is_irrefutable():
+ found_irrefutable_case = alternative
+ alternative.validate_irrefutable()
def is_simple_value_comparison(self):
return all(
@@ -739,10 +747,10 @@ class MatchSequencePatternNode(PatternNode):
def get_main_pattern_targets(self):
targets = set()
star_count = 0
- for p in self.patterns:
- if p.is_match_and_assign_pattern and p.is_star:
+ for pattern in self.patterns:
+ if pattern.is_match_and_assign_pattern and pattern.is_star:
star_count += 1
- self.update_targets_with_targets(targets, p.get_targets())
+ self.update_targets_with_targets(targets, pattern.get_targets())
if star_count > 1:
error(self.pos, "multiple starred names in sequence pattern")
return targets
@@ -1060,30 +1068,32 @@ class MatchMappingPatternNode(PatternNode):
],
exception_value="-1",
)
+ # lie about the types of keys for simplicity
Pyx_mapping_check_duplicates_type = PyrexTypes.CFuncType(
PyrexTypes.c_int_type,
[
- PyrexTypes.CFuncTypeArg("fixed_keys", PyrexTypes.py_object_type, None),
- PyrexTypes.CFuncTypeArg("var_keys", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None),
],
exception_value="-1",
)
+ # lie about the types of keys and subjects for simplicity
Pyx_mapping_extract_subjects_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type,
[
- PyrexTypes.CFuncTypeArg("map", PyrexTypes.py_object_type, None),
- PyrexTypes.CFuncTypeArg("fixed_keys", PyrexTypes.py_object_type, None),
- PyrexTypes.CFuncTypeArg("var_keys", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("mapping", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("subjects", PyrexTypes.c_void_ptr_ptr_type, None),
],
exception_value="-1",
- has_varargs=True,
)
Pyx_mapping_doublestar_type = PyrexTypes.CFuncType(
Builtin.dict_type,
[
- PyrexTypes.CFuncTypeArg("map", PyrexTypes.py_object_type, None),
- PyrexTypes.CFuncTypeArg("fixed_keys", PyrexTypes.py_object_type, None),
- PyrexTypes.CFuncTypeArg("var_keys", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("mapping", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None),
],
)
@@ -1092,8 +1102,8 @@ class MatchMappingPatternNode(PatternNode):
def get_main_pattern_targets(self):
targets = set()
- for p in self.value_patterns:
- self.update_targets_with_targets(targets, p.get_targets())
+ for pattern in self.value_patterns:
+ self.update_targets_with_targets(targets, pattern.get_targets())
if self.double_star_capture_target:
self.add_target_to_targets(targets, self.double_star_capture_target.name)
return targets
@@ -1203,8 +1213,10 @@ class MatchMappingPatternNode(PatternNode):
self.pos, arg=subject_node, fallback=call, check=self.is_dict_type_check
)
- def make_duplicate_keys_check(self, static_keys_tuple, var_keys_tuple):
+ def make_duplicate_keys_check(self, n_fixed_keys):
utility_code = UtilityCode.load_cached("MappingKeyCheck", "MatchCase.c")
+ if n_fixed_keys == len(self.keys):
+ return None # nothing to check
return Nodes.ExprStatNode(
self.pos,
@@ -1213,11 +1225,15 @@ class MatchMappingPatternNode(PatternNode):
"__Pyx_MatchCase_CheckMappingDuplicateKeys",
self.Pyx_mapping_check_duplicates_type,
utility_code=utility_code,
- args=[static_keys_tuple.clone_node(), var_keys_tuple],
+ args=[
+ MappingOrClassComparisonNode.make_keys_node(self.pos),
+ ExprNodes.IntNode(self.pos, value=str(n_fixed_keys)),
+ ExprNodes.IntNode(self.pos, value=str(len(self.keys)))
+ ],
),
)
- def check_all_keys(self, subject_node, const_keys_tuple, var_keys_tuple):
+ def check_all_keys(self, subject_node):
# It's debatable here whether to go for individual unpacking or a function.
# Current implementation is a function that's loosely copied from CPython.
# For small numbers of keys it might be better to generate the code instead.
@@ -1243,24 +1259,23 @@ class MatchMappingPatternNode(PatternNode):
util_code = UtilityCode.load_cached("ExtractGeneric", "MatchCase.c")
func_name = "__Pyx_MatchCase_Mapping_Extract"
- subject_derefs = [
- ExprNodes.NullNode(self.pos)
- if t is None
- else AddressOfPyObjectNode(self.pos, obj=t)
- for t in self.subject_temps
- ]
return ExprNodes.PythonCapiCallNode(
self.pos,
func_name,
self.Pyx_mapping_extract_subjects_type,
utility_code=util_code,
- args=[subject_node, const_keys_tuple.clone_node(), var_keys_tuple]
- + subject_derefs,
+ args=[
+ subject_node,
+ MappingOrClassComparisonNode.make_keys_node(self.pos),
+ ExprNodes.IntNode(
+ self.pos,
+ value=str(len(self.keys))
+ ),
+ MappingOrClassComparisonNode.make_subjects_node(self.pos),
+ ],
)
- def make_double_star_capture(
- self, subject_node, const_tuple, var_tuple, test_result
- ):
+ def make_double_star_capture(self, subject_node, test_result):
# test_result being the variable that holds "case check passed until now"
is_dict = self.is_dict_type_check(subject_node.type)
if is_dict:
@@ -1277,7 +1292,11 @@ class MatchMappingPatternNode(PatternNode):
"__Pyx_MatchCase_DoubleStarCapture" + tag,
self.Pyx_mapping_doublestar_type,
utility_code=utility_code,
- args=[subject_node, const_tuple, var_tuple],
+ args=[
+ subject_node,
+ MappingOrClassComparisonNode.make_keys_node(self.pos),
+ ExprNodes.IntNode(self.pos, value=str(len(self.keys)))
+ ],
)
assignment = Nodes.SingleAssignmentNode(
self.double_star_capture_target.pos, lhs=self.double_star_temp, rhs=func
@@ -1294,22 +1313,17 @@ class MatchMappingPatternNode(PatternNode):
def get_comparison_node(self, subject_node, sequence_mapping_temp=None):
from . import UtilNodes
- const_keys = []
var_keys = []
+ n_literal_keys = 0
for k in self.keys:
- if not k.arg.is_literal:
- k = UtilNodes.ResultRefNode(k, is_temp=False)
+ if not k.is_literal:
var_keys.append(k)
else:
- const_keys.append(k.arg.clone_node())
- const_keys_tuple = ExprNodes.TupleNode(self.pos, args=const_keys)
- var_keys_tuple = ExprNodes.TupleNode(self.pos, args=var_keys)
- if var_keys:
- var_keys_tuple = UtilNodes.ResultRefNode(var_keys_tuple, is_temp=True)
+ n_literal_keys += 1
all_tests = []
all_tests.append(self.make_mapping_check(subject_node, sequence_mapping_temp))
- all_tests.append(self.check_all_keys(subject_node, const_keys_tuple, var_keys_tuple))
+ all_tests.append(self.check_all_keys(subject_node))
if any(isinstance(test, ExprNodes.BoolNode) and not test.value for test in all_tests):
# identify automatic-failure
@@ -1324,10 +1338,10 @@ class MatchMappingPatternNode(PatternNode):
all_tests = generate_binop_tree_from_list(self.pos, "and", all_tests)
test_result = UtilNodes.ResultRefNode(pos=self.pos, type=PyrexTypes.c_bint_type)
+ duplicate_check = self.make_duplicate_keys_check(n_literal_keys)
body = Nodes.StatListNode(
self.pos,
- stats=[
- self.make_duplicate_keys_check(const_keys_tuple, var_keys_tuple),
+ stats=([duplicate_check] if duplicate_check else []) + [
Nodes.SingleAssignmentNode(self.pos, lhs=test_result, rhs=all_tests),
],
)
@@ -1335,21 +1349,21 @@ class MatchMappingPatternNode(PatternNode):
assert self.double_star_temp
body.stats.append(
# make_double_star_capture wraps itself in an if
- self.make_double_star_capture(
- subject_node, const_keys_tuple, var_keys_tuple, test_result
- )
+ self.make_double_star_capture(subject_node, test_result)
)
- if var_keys or self.double_star_capture_target:
+ if duplicate_check or self.double_star_capture_target:
body = UtilNodes.TempResultFromStatNode(test_result, body)
- if var_keys:
- body = UtilNodes.EvalWithTempExprNode(var_keys_tuple, body)
- for k in var_keys:
- if isinstance(k, UtilNodes.ResultRefNode):
- body = UtilNodes.EvalWithTempExprNode(k, body)
- return LazyCoerceToBool(body.pos, arg=body)
else:
- return LazyCoerceToBool(all_tests.pos, arg=all_tests)
+ body = all_tests
+ if self.keys or self.double_star_capture_target:
+ body = MappingOrClassComparisonNode(
+ body.pos,
+ arg=LazyCoerceToBool(body.pos, arg=body),
+ keys_array=self.keys,
+ subjects_array=self.subject_temps
+ )
+ return LazyCoerceToBool(body.pos, arg=body)
def analyse_pattern_expressions(self, env, sequence_mapping_temp):
def to_temp_or_literal(node):
@@ -1359,7 +1373,7 @@ class MatchMappingPatternNode(PatternNode):
return node.coerce_to_temp(env)
self.keys = [
- ExprNodes.ProxyNode(to_temp_or_literal(k.analyse_expressions(env)))
+ to_temp_or_literal(k.analyse_expressions(env))
for k in self.keys
]
@@ -1410,16 +1424,19 @@ class ClassPatternNode(PatternNode):
keyword_pattern_names = []
keyword_pattern_patterns = []
+ # as with the mapping functions, lie a little about some of the types for
+ # ease of declaration
Pyx_positional_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type,
[
PyrexTypes.CFuncTypeArg("subject", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None),
- PyrexTypes.CFuncTypeArg("keysnames_tuple", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("fixed_names", PyrexTypes.c_void_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("n_fixed", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("match_self", PyrexTypes.c_int_type, None),
- PyrexTypes.CFuncTypeArg("num_args", PyrexTypes.c_int_type, None),
+ PyrexTypes.CFuncTypeArg("subjects", PyrexTypes.c_void_ptr_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("n_subjects", PyrexTypes.c_int_type, None),
],
- has_varargs=True,
exception_value="-1",
)
@@ -1476,11 +1493,8 @@ class ClassPatternNode(PatternNode):
def get_main_pattern_targets(self):
targets = set()
- for p in self.keyword_pattern_patterns:
- self.update_targets_with_targets(targets, p.get_targets())
-
- for p in self.positional_patterns:
- self.update_targets_with_targets(targets, p.get_targets())
+ for pattern in self.positional_patterns + self.keyword_pattern_patterns:
+ self.update_targets_with_targets(targets, pattern.get_targets())
return targets
def generate_main_pattern_assignment_list(self, subject_node, env):
@@ -1589,13 +1603,10 @@ class ClassPatternNode(PatternNode):
def make_positional_args_call(self, subject_node, class_node):
assert self.positional_patterns
util_code = UtilityCode.load_cached("ClassPositionalPatterns", "MatchCase.c")
- keynames = ExprNodes.TupleNode(
- self.pos,
- args=[
- ExprNodes.StringNode(n.pos, value=n.name)
- for n in self.keyword_pattern_names
- ],
- )
+ keynames = [
+ ExprNodes.StringNode(n.pos, value=n.name)
+ for n in self.keyword_pattern_names
+ ]
# -1 is "unknown"
match_self = (
-1
@@ -1628,21 +1639,28 @@ class ClassPatternNode(PatternNode):
match_self = 0 # I think... Relies on knowing the bases
match_self = ExprNodes.IntNode(self.pos, value=str(match_self))
- len_ = ExprNodes.IntNode(self.pos, value=str(len(self.positional_patterns)))
- subject_derefs = [
- ExprNodes.NullNode(self.pos)
- if t is None
- else AddressOfPyObjectNode(self.pos, obj=t)
- for t in self.positional_subject_temps
- ]
- return ExprNodes.PythonCapiCallNode(
+ n_subjects = ExprNodes.IntNode(self.pos, value=str(len(self.positional_patterns)))
+ return MappingOrClassComparisonNode(
self.pos,
- "__Pyx_MatchCase_ClassPositional",
- self.Pyx_positional_type,
- utility_code=util_code,
- args=[subject_node, class_node, keynames, match_self, len_]
- + subject_derefs,
+ arg=ExprNodes.PythonCapiCallNode(
+ self.pos,
+ "__Pyx_MatchCase_ClassPositional",
+ self.Pyx_positional_type,
+ utility_code=util_code,
+ args=[
+ subject_node,
+ class_node,
+ MappingOrClassComparisonNode.make_keys_node(self.pos),
+ ExprNodes.IntNode(self.pos, value=str(len(keynames))),
+ match_self,
+ MappingOrClassComparisonNode.make_subjects_node(self.pos),
+ n_subjects,
+ ]
+ ),
+ subjects_array=self.positional_subject_temps,
+ keys_array=keynames,
)
def make_subpattern_checks(self):
patterns = self.keyword_pattern_patterns + self.positional_patterns
@@ -2057,27 +2075,6 @@ class CompilerDirectivesExprNode(ExprNodes.ProxyNode):
self.arg.annotate(code)
-class AddressOfPyObjectNode(ExprNodes.ExprNode):
- """
- obj - some temp node
- """
-
- type = PyrexTypes.c_void_ptr_ptr_type
- is_temp = False
- subexprs = ["obj"]
-
- def analyse_types(self, env):
- self.obj = self.obj.analyse_types(env)
- assert self.obj.type.is_pyobject, repr(self.obj.type)
- return self
-
- def generate_result_code(self, code):
- self.obj.generate_result_code(code)
-
- def calculate_result_code(self):
- return "&%s" % self.obj.result()
-
-
class LazyCoerceToPyObject(ExprNodes.ExprNode):
"""
Just calls "self.arg.coerce_to_pyobject" when it's analysed,
@@ -2127,4 +2124,122 @@ def generate_binop_tree_from_list(pos, operator, list_of_tests):
operator=operator,
operand1=operand1,
operand2=operand2
-    )
\ No newline at end of file
+ )
+
+
+class MappingOrClassComparisonNode(ExprNodes.ExprNode):
+ """
+    Combined with MappingOrClassComparisonNodeInner, this is responsible
+    for setting up the arrays of subjects and keys that are used in
+    the function calls that handle these types of patterns.
+
+    Note that self.keys_array is owned by this node but used by
+    MappingOrClassComparisonNodeInner - that's mainly to ensure that
+    it gets evaluated in the correct order.
+ """
+ subexprs = ["keys_array", "inner"]
+
+ keys_array_cname = "__pyx_match_mapping_keys"
+ subjects_array_cname = "__pyx_match_mapping_subjects"
+
+ @property
+ def type(self):
+ return self.inner.type
+
+ @classmethod
+ def make_keys_node(cls, pos):
+ return ExprNodes.RawCNameExprNode(
+ pos,
+ type=PyrexTypes.c_void_ptr_type,
+ cname=cls.keys_array_cname
+ )
+
+ @classmethod
+ def make_subjects_node(cls, pos):
+ return ExprNodes.RawCNameExprNode(
+ pos,
+ type=PyrexTypes.c_void_ptr_ptr_type,
+ cname=cls.subjects_array_cname
+ )
+
+ def __init__(self, pos, arg, subjects_array, **kwds):
+ super(MappingOrClassComparisonNode, self).__init__(pos, **kwds)
+ self.inner = MappingOrClassComparisonNodeInner(
+ pos,
+ arg=arg,
+            keys_array=self.keys_array,
+            subjects_array=subjects_array
+ )
+
+ def analyse_types(self, env):
+ self.inner = self.inner.analyse_types(env)
+ self.keys_array = [
+ key.analyse_types(env).coerce_to_simple(env) for key in self.keys_array
+ ]
+ return self
+
+ def generate_result_code(self, code):
+ pass
+
+ def calculate_result_code(self):
+ return self.inner.calculate_result_code()
+
+
+class MappingOrClassComparisonNodeInner(ExprNodes.ExprNode):
+ """
+    Sets up the arrays of subjects and keys.
+
+    Created by the constructor of MappingOrClassComparisonNode
+    (no need to create it directly).
+
+    Has attributes:
+    * arg - the main comparison node
+    * keys_array - list of ExprNodes representing keys
+    * subjects_array - list of ExprNodes representing subjects
+ """
+ subexprs = ['arg']
+
+ @property
+ def type(self):
+ return self.arg.type
+
+ def analyse_types(self, env):
+ self.arg = self.arg.analyse_types(env)
+ for n in range(len(self.keys_array)):
+ key = self.keys_array[n].analyse_types(env)
+ key = key.coerce_to_pyobject(env)
+ self.keys_array[n] = key
+ assert self.arg.type is PyrexTypes.c_bint_type
+ return self
+
+ def generate_evaluation_code(self, code):
+ code.putln("{")
+ keys_str = ", ".join(k.result() for k in self.keys_array)
+ if not keys_str:
+ # GCC gets worried about overflow if we pass
+ # a genuinely empty array
+ keys_str = "NULL"
+ code.putln("PyObject *%s[] = {%s};" % (
+ MappingOrClassComparisonNode.keys_array_cname,
+ keys_str,
+ ))
+ subjects_str = ", ".join(
+ "&"+subject.result() if subject is not None else "NULL" for subject in self.subjects_array
+ )
+ if not subjects_str:
+ # GCC gets worried about overflow if we pass
+ # a genuinely empty array
+ subjects_str = "NULL"
+ code.putln("PyObject **%s[] = {%s};" % (
+ MappingOrClassComparisonNode.subjects_array_cname,
+ subjects_str
+ ))
+ super(MappingOrClassComparisonNodeInner, self).generate_evaluation_code(code)
+
+ code.putln("}")
+
+ def generate_result_code(self, code):
+ pass
+
+ def calculate_result_code(self):
+        return self.arg.result()
\ No newline at end of file
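
A rough sketch of what MappingOrClassComparisonNodeInner.generate_evaluation_code emits for the helper calls above (illustrative Python that prints approximate C; the temp names __pyx_t_1, __pyx_t_2 and __pyx_v_x are hypothetical, and the real node writes through code.putln):

    # Hypothetical inputs: two evaluated key temps; one named subject plus one
    # key whose value is not captured (None becomes NULL in the generated array).
    keys = ["__pyx_t_1", "__pyx_t_2"]
    subjects = ["__pyx_v_x", None]
    print("{")
    print("PyObject *__pyx_match_mapping_keys[] = {%s};" % (", ".join(keys) or "NULL"))
    print("PyObject **__pyx_match_mapping_subjects[] = {%s};" % (
        ", ".join("&" + s if s is not None else "NULL" for s in subjects) or "NULL"))
    # ... the wrapped comparison/extraction call is emitted here ...
    print("}")
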
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index 49c99b2c3..006d1023c 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -10156,13 +10156,13 @@ class CnameDecoratorNode(StatNode):
class ErrorNode(Node):
"""
- Node type for things that we want to get throught the parser
+ Node type for things that we want to get through the parser
(especially for things that are being scanned in "tentative_scan"
    blocks), but should immediately raise an error afterwards.
what str
"""
- pass
+ child_attrs = []
#------------------------------------------------------------------------------------
diff --git a/Cython/Compiler/ParseTreeTransforms.pxd b/Cython/Compiler/ParseTreeTransforms.pxd
index efbb14f70..2778be4ef 100644
--- a/Cython/Compiler/ParseTreeTransforms.pxd
+++ b/Cython/Compiler/ParseTreeTransforms.pxd
@@ -18,6 +18,7 @@ cdef class PostParse(ScopeTrackingTransform):
cdef dict specialattribute_handlers
cdef size_t lambda_counter
cdef size_t genexpr_counter
+ cdef bint in_pattern_node
cdef _visit_assignment_node(self, node, list expr_list)
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
index 414bae04d..301d93335 100644
--- a/Cython/Compiler/ParseTreeTransforms.py
+++ b/Cython/Compiler/ParseTreeTransforms.py
@@ -193,6 +193,7 @@ class PostParse(ScopeTrackingTransform):
self.specialattribute_handlers = {
'__cythonbufferdefaults__' : self.handle_bufferdefaults
}
+ self.in_pattern_node = False
def visit_LambdaNode(self, node):
# unpack a lambda expression into the corresponding DefNode
@@ -399,6 +400,18 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node)
return node
+ def visit_PatternNode(self, node):
+ in_pattern_node, self.in_pattern_node = self.in_pattern_node, True
+ self.visitchildren(node)
+ self.in_pattern_node = in_pattern_node
+ return node
+
+ def visit_JoinedStrNode(self, node):
+ if self.in_pattern_node:
+ error(node.pos, "f-strings are not accepted for pattern matching")
+ self.visitchildren(node)
+ return node
+
class _AssignmentExpressionTargetNameFinder(TreeVisitor):
def __init__(self):
super(_AssignmentExpressionTargetNameFinder, self).__init__()
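
The visit_PatternNode/visit_JoinedStrNode pair above rejects f-strings inside patterns at post-parse time (the parser itself now lets them through; see the Parsing.py change below). A hypothetical snippet of the kind it is meant to catch, not taken from the test suite:

    match greeting:
        case f"hello {name}":   # reported: "f-strings are not accepted for pattern matching"
            pass
        case "hello":           # fine: a plain string literal pattern
            pass
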
diff --git a/Cython/Compiler/Parsing.pxd b/Cython/Compiler/Parsing.pxd
index 72a855fd4..fc3e2749f 100644
--- a/Cython/Compiler/Parsing.pxd
+++ b/Cython/Compiler/Parsing.pxd
@@ -62,6 +62,8 @@ cdef expect_ellipsis(PyrexScanner s)
cdef make_slice_nodes(pos, subscripts)
cpdef make_slice_node(pos, start, stop = *, step = *)
cdef p_atom(PyrexScanner s)
+cdef p_atom_string(PyrexScanner s)
+cdef p_atom_ident_constants(PyrexScanner s, bint bools_are_pybool = *)
@cython.locals(value=unicode)
cdef p_int_literal(PyrexScanner s)
cdef p_name(PyrexScanner s, name)
diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py
index 02c746351..75eb194c6 100644
--- a/Cython/Compiler/Parsing.py
+++ b/Cython/Compiler/Parsing.py
@@ -4,7 +4,6 @@
#
from __future__ import absolute_import
-from ast import Expression
# This should be done automatically
import cython
@@ -20,7 +19,7 @@ cython.declare(Nodes=object, ExprNodes=object, EncodedString=object,
from io import StringIO
import re
import sys
-from unicodedata import lookup as lookup_unicodechar, category as unicode_category, name
+from unicodedata import lookup as lookup_unicodechar, category as unicode_category
from functools import partial, reduce
from .Scanning import PyrexScanner, FileSourceDescriptor, tentatively_scan
@@ -719,36 +718,59 @@ def p_atom(s):
s.next()
return ExprNodes.ImagNode(pos, value = value)
elif sy == 'BEGIN_STRING':
- kind, bytes_value, unicode_value = p_cat_string_literal(s)
- if kind == 'c':
- return ExprNodes.CharNode(pos, value = bytes_value)
- elif kind == 'u':
- return ExprNodes.UnicodeNode(pos, value = unicode_value, bytes_value = bytes_value)
- elif kind == 'b':
- return ExprNodes.BytesNode(pos, value = bytes_value)
- elif kind == 'f':
- return ExprNodes.JoinedStrNode(pos, values = unicode_value)
- elif kind == '':
- return ExprNodes.StringNode(pos, value = bytes_value, unicode_value = unicode_value)
- else:
- s.error("invalid string kind '%s'" % kind)
+ return p_atom_string(s)
elif sy == 'IDENT':
- name = s.systring
- if name == "None":
- result = ExprNodes.NoneNode(pos)
- elif name == "True":
- result = ExprNodes.BoolNode(pos, value=True)
- elif name == "False":
- result = ExprNodes.BoolNode(pos, value=False)
- elif name == "NULL" and not s.in_python_file:
- result = ExprNodes.NullNode(pos)
- else:
- result = p_name(s, name)
- s.next()
+ result = p_atom_ident_constants(s)
+ if result is None:
+ result = p_name(s, s.systring)
+ s.next()
return result
else:
s.error("Expected an identifier or literal")
+
+def p_atom_string(s):
+ pos = s.position()
+ kind, bytes_value, unicode_value = p_cat_string_literal(s)
+ if kind == 'c':
+ return ExprNodes.CharNode(pos, value=bytes_value)
+ elif kind == 'u':
+ return ExprNodes.UnicodeNode(pos, value=unicode_value, bytes_value=bytes_value)
+ elif kind == 'b':
+ return ExprNodes.BytesNode(pos, value=bytes_value)
+ elif kind == 'f':
+ return ExprNodes.JoinedStrNode(pos, values=unicode_value)
+ elif kind == '':
+ return ExprNodes.StringNode(pos, value=bytes_value, unicode_value=unicode_value)
+ else:
+ s.error("invalid string kind '%s'" % kind)
+
+
+def p_atom_ident_constants(s, bools_are_pybool=False):
+ """
+    Returns None if it isn't one of the special-cased named constants.
+    Only calls s.next() if it successfully matches one of them.
+ """
+ pos = s.position()
+ name = s.systring
+ result = None
+ if bools_are_pybool:
+ extra_kwds = {'type': Builtin.bool_type}
+ else:
+ extra_kwds = {}
+ if name == "None":
+ result = ExprNodes.NoneNode(pos)
+ elif name == "True":
+ result = ExprNodes.BoolNode(pos, value=True, **extra_kwds)
+ elif name == "False":
+ result = ExprNodes.BoolNode(pos, value=False, **extra_kwds)
+ elif name == "NULL" and not s.in_python_file:
+ result = ExprNodes.NullNode(pos)
+ if result:
+ s.next()
+ return result
+
+
def p_int_literal(s):
pos = s.position()
value = s.systring
@@ -4031,12 +4053,13 @@ def p_cpp_class_attribute(s, ctx):
node.decorators = decorators
return node
+
def p_match_statement(s, ctx):
assert s.sy == "IDENT" and s.systring == "match"
pos = s.position()
with tentatively_scan(s) as errors:
s.next()
- subject = p_test(s)
+ subject = p_namedexpr_test(s)
subjects = None
if s.sy == ",":
subjects = [subject]
@@ -4050,6 +4073,7 @@ def p_match_statement(s, ctx):
s.expect(":")
if errors:
return None
+
    # at this stage we're committed to it being a match block so continue
# outside "with tentatively_scan"
# (I think this deviates from the PEG parser slightly, and it'd
@@ -4060,10 +4084,11 @@ def p_match_statement(s, ctx):
while s.sy != "DEDENT":
cases.append(p_case_block(s, ctx))
s.expect_dedent()
- return MatchCaseNodes.MatchNode(pos, subject = subject, cases = cases)
+ return MatchCaseNodes.MatchNode(pos, subject=subject, cases=cases)
+
def p_case_block(s, ctx):
- if not (s.sy=="IDENT" and s.systring == "case"):
+ if not (s.sy == "IDENT" and s.systring == "case"):
s.error("Expected 'case'")
s.next()
pos = s.position()
@@ -4076,8 +4101,10 @@ def p_case_block(s, ctx):
return MatchCaseNodes.MatchCaseNode(pos, pattern=pattern, body=body, guard=guard)
+
def p_patterns(s):
- # note - in slight contrast to the name, returns a single pattern
+ # note - in slight contrast to the name (which comes from the Python grammar),
+ # returns a single pattern
patterns = []
seq = False
pos = s.position()
@@ -4089,9 +4116,9 @@ def p_patterns(s):
break # all is good provided we have at least 1 pattern
else:
e = errors[0]
- s.error(e.args[1], pos = e.args[0])
+ s.error(e.args[1], pos=e.args[0])
patterns.append(pattern)
-
+
if s.sy == ",":
seq = True
s.next()
@@ -4099,11 +4126,13 @@ def p_patterns(s):
break # common reasons to break
else:
break
+
if seq:
- return MatchCaseNodes.MatchSequencePatternNode(pos, patterns = patterns)
+ return MatchCaseNodes.MatchSequencePatternNode(pos, patterns=patterns)
else:
return patterns[0]
+
def p_maybe_star_pattern(s):
# For match case. Either star_pattern or pattern
if s.sy == "*":
@@ -4115,12 +4144,13 @@ def p_maybe_star_pattern(s):
else:
s.next()
pattern = MatchCaseNodes.MatchAndAssignPatternNode(
- s.position(), target = target, is_star = True
+ s.position(), target=target, is_star=True
)
return pattern
else:
- p = p_pattern(s)
- return p
+ pattern = p_pattern(s)
+ return pattern
+
def p_pattern(s):
# try "as_pattern" then "or_pattern"
@@ -4133,13 +4163,15 @@ def p_pattern(s):
s.next()
else:
break
+
if len(patterns) > 1:
pattern = MatchCaseNodes.OrPatternNode(
pos,
- alternatives = patterns
+ alternatives=patterns
)
else:
pattern = patterns[0]
+
if s.sy == 'IDENT' and s.systring == 'as':
s.next()
with tentatively_scan(s) as errors:
@@ -4147,17 +4179,18 @@ def p_pattern(s):
if errors and s.sy == "_":
s.next()
# make this a specific error
- return Nodes.ErrorNode(errors[0].args[0], what = errors[0].args[1])
+ return Nodes.ErrorNode(errors[0].args[0], what=errors[0].args[1])
elif errors:
with tentatively_scan(s):
expr = p_test(s)
- return Nodes.ErrorNode(expr.pos, what = "Invalid pattern target")
+ return Nodes.ErrorNode(expr.pos, what="Invalid pattern target")
s.error(errors[0])
return pattern
def p_closed_pattern(s):
"""
+ The PEG parser specifies it as
| literal_pattern
| capture_pattern
| wildcard_pattern
@@ -4166,42 +4199,49 @@ def p_closed_pattern(s):
| sequence_pattern
| mapping_pattern
| class_pattern
+
+    For the sake of avoiding too much backtracking, we know:
+    * starts with "{" is a mapping_pattern
+    * starts with "[" is a sequence_pattern
+    * starts with "(" is a group_pattern or sequence_pattern
+    * the wildcard pattern is just identifier=='_'
+    The rest are then tried in order with backtracking.
"""
+ if s.sy == 'IDENT' and s.systring == '_':
+ pos = s.position()
+ s.next()
+ return MatchCaseNodes.MatchAndAssignPatternNode(pos)
+ elif s.sy == '{':
+ return p_mapping_pattern(s)
+ elif s.sy == '[':
+ return p_sequence_pattern(s)
+ elif s.sy == '(':
+ with tentatively_scan(s) as errors:
+ result = p_group_pattern(s)
+ if not errors:
+ return result
+ return p_sequence_pattern(s)
+
with tentatively_scan(s) as errors:
result = p_literal_pattern(s)
- if not errors:
- return result
+ if not errors:
+ return result
with tentatively_scan(s) as errors:
result = p_capture_pattern(s)
- if not errors:
- return result
- with tentatively_scan(s) as errors:
- result = p_wildcard_pattern(s)
- if not errors:
- return result
+ if not errors:
+ return result
with tentatively_scan(s) as errors:
result = p_value_pattern(s)
- if not errors:
- return result
- with tentatively_scan(s) as errors:
- result = p_group_pattern(s)
- if not errors:
- return result
- with tentatively_scan(s) as errors:
- result = p_sequence_pattern(s)
- if not errors:
- return result
- with tentatively_scan(s) as errors:
- result = p_mapping_pattern(s)
- if not errors:
- return result
+ if not errors:
+ return result
return p_class_pattern(s)
+
def p_literal_pattern(s):
# a lot of duplication in this function with "p_atom"
next_must_be_a_number = False
sign = ''
- if s.sy in ['+', '-']:
+ if s.sy == '-':
sign = s.sy
sign_pos = s.position()
s.next()
@@ -4216,9 +4256,11 @@ def p_literal_pattern(s):
elif sy == 'FLOAT':
value = s.systring
s.next()
- res = ExprNodes.FloatNode(pos, value = value)
+ res = ExprNodes.FloatNode(pos, value=value)
+
if res and sign == "-":
res = ExprNodes.UnaryMinusNode(sign_pos, operand=res)
+
if res and s.sy in ['+', '-']:
sign = s.sy
s.next()
@@ -4231,61 +4273,43 @@ def p_literal_pattern(s):
res = ExprNodes.binop_node(
add_pos,
sign,
- operand1 = res,
- operand2 = ExprNodes.ImagNode(s.position(), value = value)
+ operand1=res,
+ operand2=ExprNodes.ImagNode(s.position(), value=value)
)
if not res and sy == 'IMAG':
value = s.systring[:-1]
s.next()
- res = ExprNodes.ImagNode(pos, value = sign+value)
+ res = ExprNodes.ImagNode(pos, value=sign+value)
if sign == "-":
res = ExprNodes.UnaryMinusNode(sign_pos, operand=res)
if res:
- return MatchCaseNodes.MatchValuePatternNode(pos, value = res)
+ return MatchCaseNodes.MatchValuePatternNode(pos, value=res)
+ if next_must_be_a_number:
+ s.error("Expected a number")
if sy == 'BEGIN_STRING':
- if next_must_be_a_number:
- s.error("Expected a number")
- kind, bytes_value, unicode_value = p_cat_string_literal(s)
- if kind == 'c':
- res = ExprNodes.CharNode(pos, value = bytes_value)
- elif kind == 'u':
- res = ExprNodes.UnicodeNode(pos, value = unicode_value, bytes_value = bytes_value)
- elif kind == 'b':
- res = ExprNodes.BytesNode(pos, value = bytes_value)
- elif kind == 'f':
- res = Nodes.ErrorNode(pos, what = "f-strings are not accepted for pattern matching")
- elif kind == '':
- res = ExprNodes.StringNode(pos, value = bytes_value, unicode_value = unicode_value)
- else:
- s.error("invalid string kind '%s'" % kind)
- return MatchCaseNodes.MatchValuePatternNode(pos, value = res)
+ res = p_atom_string(s)
+ # f-strings not being accepted is validated in PostParse
+ return MatchCaseNodes.MatchValuePatternNode(pos, value=res)
elif sy == 'IDENT':
- name = s.systring
- result = None
- if name == "None":
- result = ExprNodes.NoneNode(pos)
- elif name == "True":
- result = ExprNodes.BoolNode(pos, value=True, type=Builtin.bool_type)
- elif name == "False":
- result = ExprNodes.BoolNode(pos, value=False, type=Builtin.bool_type)
- elif name == "NULL" and not s.in_python_file:
- # Included Null as an exactly matched constant here
- result = ExprNodes.NullNode(pos)
+ # Note that p_atom_ident_constants includes NULL.
+ # This is a deliberate Cython addition to the pattern matching specification
+ result = p_atom_ident_constants(s, bools_are_pybool=True)
if result:
- s.next()
- return MatchCaseNodes.MatchValuePatternNode(pos, value = result, is_is_check = True)
+ return MatchCaseNodes.MatchValuePatternNode(pos, value=result, is_is_check=True)
s.error("Failed to match literal")
+
def p_capture_pattern(s):
return MatchCaseNodes.MatchAndAssignPatternNode(
s.position(),
- target = p_pattern_capture_target(s)
+ target=p_pattern_capture_target(s)
)
+
def p_value_pattern(s):
if s.sy != "IDENT":
s.error("Expected identifier")
@@ -4298,10 +4322,11 @@ def p_value_pattern(s):
attr_pos = s.position()
s.next()
attr = p_ident(s)
- res = ExprNodes.AttributeNode(attr_pos, obj = res, attribute=attr)
+ res = ExprNodes.AttributeNode(attr_pos, obj=res, attribute=attr)
if s.sy in ['(', '=']:
s.error("Unexpected symbol '%s'" % s.sy)
- return MatchCaseNodes.MatchValuePatternNode(pos, value = res)
+ return MatchCaseNodes.MatchValuePatternNode(pos, value=res)
+
def p_group_pattern(s):
s.expect("(")
@@ -4309,12 +4334,6 @@ def p_group_pattern(s):
s.expect(")")
return pattern
-def p_wildcard_pattern(s):
- if s.sy != "IDENT" or s.systring != "_":
- s.error("Expected '_'")
- pos = s.position()
- s.next()
- return MatchCaseNodes.MatchAndAssignPatternNode(pos)
def p_sequence_pattern(s):
opener = s.sy
@@ -4334,28 +4353,32 @@ def p_sequence_pattern(s):
if s.sy == closer:
break
else:
- if opener == ')' and len(patterns)==1:
+ if opener == ')' and len(patterns) == 1:
s.error("tuple-like pattern of length 1 must finish with ','")
break
s.expect(closer)
return MatchCaseNodes.MatchSequencePatternNode(pos, patterns=patterns)
else:
- s.error("Expected '[' or '('")
+ s.error("Expected '[' or '('")
+
def p_mapping_pattern(s):
pos = s.position()
s.expect('{')
if s.sy == '}':
+ # trivial empty mapping
s.next()
return MatchCaseNodes.MatchMappingPatternNode(pos)
+
double_star_capture_target = None
items_patterns = []
double_star_set_twice = None
pattern_after_double_star = None
+ star_star_arg_pos = None
while True:
+ if double_star_capture_target and not star_star_arg_pos:
+ star_star_arg_pos = s.position()
if s.sy == '**':
- if double_star_capture_target:
- double_star_set_twice = s.position()
s.next()
double_star_capture_target = p_pattern_capture_target(s)
else:
@@ -4384,6 +4407,11 @@ def p_mapping_pattern(s):
return Nodes.ErrorNode(double_star_set_twice, what = "Double star capture set twice")
if pattern_after_double_star:
return Nodes.ErrorNode(pattern_after_double_star, what = "pattern follows ** capture")
+ if star_star_arg_pos is not None:
+ return Nodes.ErrorNode(
+ star_star_arg_pos,
+ what = "** pattern must be the final part of a mapping pattern."
+ )
return MatchCaseNodes.MatchMappingPatternNode(
pos,
keys = [kv[0] for kv in items_patterns],
@@ -4391,8 +4419,9 @@ def p_mapping_pattern(s):
double_star_capture_target = double_star_capture_target
)
+
def p_class_pattern(s):
- # name_or_attr
+ # start by parsing the class as name_or_attr
pos = s.position()
res = p_name(s, s.systring)
s.next()
@@ -4400,12 +4429,16 @@ def p_class_pattern(s):
attr_pos = s.position()
s.next()
attr = p_ident(s)
- res = ExprNodes.AttributeNode(attr_pos, obj = res, attribute=attr)
+ res = ExprNodes.AttributeNode(attr_pos, obj=res, attribute=attr)
class_ = res
+
s.expect("(")
if s.sy == ")":
+ # trivial case with no arguments matched
s.next()
return MatchCaseNodes.ClassPatternNode(pos, class_=class_)
+
+ # parse the arguments
positional_patterns = []
keyword_patterns = []
keyword_patterns_error = None
@@ -4418,17 +4451,17 @@ def p_class_pattern(s):
else:
with tentatively_scan(s) as errors:
keyword_patterns.append(p_keyword_pattern(s))
- if s.sy == ",":
- s.next()
- if s.sy == ")":
- break
- else:
+ if s.sy != ",":
break
+ s.next()
+ if s.sy == ")":
+ break # Allow trailing comma.
s.expect(")")
+
if keyword_patterns_error is not None:
return Nodes.ErrorNode(
keyword_patterns_error,
- what = "Positional patterns follow keyword patterns"
+ what="Positional patterns follow keyword patterns"
)
return MatchCaseNodes.ClassPatternNode(
pos, class_ = class_,
@@ -4437,6 +4470,7 @@ def p_class_pattern(s):
keyword_pattern_patterns = [kv[1] for kv in keyword_patterns],
)
+
def p_keyword_pattern(s):
if s.sy != "IDENT":
s.error("Expected identifier")
@@ -4446,6 +4480,7 @@ def p_keyword_pattern(s):
value = p_pattern(s)
return arg, value
+
def p_pattern_capture_target(s):
# any name but '_', and with some constraints on what follows
if s.sy != 'IDENT':
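
The reworked p_mapping_pattern above reports a misplaced ** capture with a dedicated ErrorNode message instead of a generic parse failure. A hypothetical pattern of the kind that should now get the clearer report (not taken from the test suite):

    match config:
        case {"host": host, **extra, "port": port}:
            # rejected: a ** capture must be the final part of a mapping pattern
            connect(host, port, extra)
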
diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py
index 8bcd26b6f..45a8e6f59 100644
--- a/Cython/TestUtils.py
+++ b/Cython/TestUtils.py
@@ -12,9 +12,10 @@ from functools import partial
from .Compiler import Errors
from .CodeWriter import CodeWriter
-from .Compiler.TreeFragment import TreeFragment, strip_common_indent
+from .Compiler.TreeFragment import TreeFragment, strip_common_indent, StringParseContext
from .Compiler.Visitor import TreeVisitor, VisitorTransform
from .Compiler import TreePath
+from .Compiler.ParseTreeTransforms import PostParse
class NodeTypeWriter(TreeVisitor):
@@ -357,3 +358,24 @@ def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None
while other_time is None or other_time >= os.path.getmtime(file_path):
write_file(file_path, content, dedent=dedent, encoding=encoding)
+
+
+def py_parse_code(code):
+ """
+ Compiles code far enough to get errors from the parser and post-parse stage.
+
+ Is useful for checking for syntax errors, however it doesn't generate runable
+ code.
+ """
+ context = StringParseContext("test")
+ # all the errors we care about are in the parsing or postparse stage
+ try:
+ with Errors.local_errors() as errors:
+ result = TreeFragment(code, pipeline=[PostParse(context)])
+ result = result.substitute()
+ if errors:
+ raise errors[0] # compile error, which should get caught
+ else:
+ return result
+ except Errors.CompileError as e:
+ raise SyntaxError(e.message_only)
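
A hypothetical use of the new helper in a test, illustrative only and not taken from the changed test files:

    from Cython.TestUtils import py_parse_code

    good = "match x:\n    case 1 | 2:\n        pass\n"
    bad = "match x:\n    case 1 |:\n        pass\n"

    py_parse_code(good)      # parses and post-parses cleanly
    try:
        py_parse_code(bad)   # parse errors are re-raised as SyntaxError
    except SyntaxError:
        pass
    else:
        raise AssertionError("expected a SyntaxError")
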
diff --git a/Cython/Utility/MatchCase.c b/Cython/Utility/MatchCase.c
index fea08e0c1..55c8b99b4 100644
--- a/Cython/Utility/MatchCase.c
+++ b/Cython/Utility/MatchCase.c
@@ -1,28 +1,14 @@
///////////////////////////// ABCCheck //////////////////////////////
#if PY_VERSION_HEX < 0x030A0000
-static int __Pyx_MatchCase_IsExactSequence(PyObject *o) {
+static CYTHON_INLINE int __Pyx_MatchCase_IsExactSequence(PyObject *o) {
// is one of the small list of builtin types known to be a sequence
- if (PyList_CheckExact(o) || PyTuple_CheckExact(o)) {
+ if (PyList_CheckExact(o) || PyTuple_CheckExact(o) ||
+            Py_TYPE(o) == &PyRange_Type || Py_TYPE(o) == &PyMemoryView_Type) {
        // Use an exact type match for these checks. In the event of inheritance we need to make sure
// that it isn't a mapping too
return 1;
}
- if (PyRange_Check(o) || PyMemoryView_Check(o)) {
- // Exact check isn't possible so do exact check in another way
- PyObject *mro = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__mro__");
- if (mro) {
- Py_ssize_t len = PyObject_Length(mro);
- Py_DECREF(mro);
- if (len < 0) {
- PyErr_Clear(); // doesn't really matter, just proceed with other checks
- } else if (len == 2) {
- return 1; // the type and "object" and no other bases
- }
- } else {
- PyErr_Clear(); // doesn't really matter, just proceed with other checks
- }
- }
return 0;
}
@@ -34,10 +20,13 @@ static CYTHON_INLINE int __Pyx_MatchCase_IsExactMapping(PyObject *o) {
}
static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) {
- if (PyUnicode_Check(o) || PyBytes_Check(o) || PyByteArray_Check(o)) {
+    if ((PyType_GetFlags(Py_TYPE(o)) & (Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS)) ||
+            PyByteArray_Check(o)) {
return 1; // these types are deliberately excluded from the sequence test
// even though they look like sequences for most other purposes.
- // They're therefore "inexact" checks
+ // Leave them as inexact checks since they do pass
+ // "isinstance(o, collections.abc.Sequence)" so it's very hard to
+ // reason about their subclasses
}
if (o == Py_None || PyLong_CheckExact(o) || PyFloat_CheckExact(o)) {
return 1;
@@ -73,6 +62,16 @@ static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) {
#define __PYX_SEQUENCE_MAPPING_ERROR (1U<<4) // only used by the ABCCheck function
#endif
+static int __Pyx_MatchCase_InitAndIsInstanceAbc(PyObject *o, PyObject *abc_module,
+ PyObject **abc_type, PyObject *name) {
+    assert(!*abc_type);
+    *abc_type = PyObject_GetAttr(abc_module, name);
+    if (!*abc_type) {
+        return -1;
+    }
+    return PyObject_IsInstance(o, *abc_type);
+}
+
// the result is defined using the specification for sequence_mapping_temp
// (detailed in "is_sequence")
static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, int definitely_not_sequence, int definitely_not_mapping) {
@@ -101,12 +100,7 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in
result = __PYX_DEFINITELY_SEQUENCE_FLAG;
goto end;
}
- sequence_type = PyObject_GetAttr(abc_module, PYIDENT("Sequence"));
- if (!sequence_type) {
- result = __PYX_SEQUENCE_MAPPING_ERROR;
- goto end;
- }
- sequence_result = PyObject_IsInstance(o, sequence_type);
+ sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence"));
if (sequence_result < 0) {
result = __PYX_SEQUENCE_MAPPING_ERROR;
goto end;
@@ -114,41 +108,32 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in
result |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG;
goto end;
}
- // else wait to see what mapping is
+ // else wait to see what mapping is
}
if (!definitely_not_mapping) {
- mapping_type = PyObject_GetAttr(abc_module, PYIDENT("Mapping"));
- if (!mapping_type) {
+ mapping_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &mapping_type, PYIDENT("Mapping"));
+ if (mapping_result < 0) {
+ result = __PYX_SEQUENCE_MAPPING_ERROR;
goto end;
- }
- mapping_result = PyObject_IsInstance(o, mapping_type);
- }
- if (mapping_result < 0) {
- result = __PYX_SEQUENCE_MAPPING_ERROR;
- goto end;
- } else if (mapping_result == 0) {
- result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG;
- if (sequence_first) {
- assert(sequence_result);
- result |= __PYX_DEFINITELY_SEQUENCE_FLAG;
- }
- goto end;
- } else /* mapping_result == 1 */ {
- if (sequence_first && !sequence_result) {
- result |= __PYX_DEFINITELY_MAPPING_FLAG;
+ } else if (mapping_result == 0) {
+ result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG;
+ if (sequence_first) {
+ assert(sequence_result);
+ result |= __PYX_DEFINITELY_SEQUENCE_FLAG;
+ }
goto end;
+ } else /* mapping_result == 1 */ {
+ if (sequence_first && !sequence_result) {
+ result |= __PYX_DEFINITELY_MAPPING_FLAG;
+ goto end;
+ }
}
}
if (!sequence_first) {
// here we know mapping_result is true because we'd have returned otherwise
assert(mapping_result);
if (!definitely_not_sequence) {
- sequence_type = PyObject_GetAttr(abc_module, PYIDENT("Sequence"));
- if (!sequence_type) {
- result = __PYX_SEQUENCE_MAPPING_ERROR;
- goto end;
- }
- sequence_result = PyObject_IsInstance(o, sequence_type);
+ sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence"));
}
if (sequence_result < 0) {
result = __PYX_SEQUENCE_MAPPING_ERROR;
@@ -167,7 +152,7 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in
if (!mro) {
PyErr_Clear();
goto end;
- }
+ }
if (!PyTuple_Check(mro)) {
Py_DECREF(mro);
goto end;
@@ -322,7 +307,7 @@ static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_
PyObject *list;
ssizeargfunc slot;
PyTypeObject *type = Py_TYPE(x);
-
+
list = PyList_New(total);
if (!list) {
return NULL;
@@ -454,17 +439,15 @@ static int __Pyx_MatchCase_IsMapping(PyObject *o, unsigned int *sequence_mapping
#endif
}
-//////////////////////// DuplicateKeyCheck.proto ///////////////////////
+//////////////////////// MappingKeyCheck.proto /////////////////////////
-// Returns an new reference to any duplicate key.
-// NULL can indicate no duplicate keys or an error (so use PyErr_Occurred)
-static PyObject* __Pyx_MatchCase_CheckDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys, Py_ssize_t n_var_keys); /*proto*/
+static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *keys[], Py_ssize_t nFixedKeys, Py_ssize_t nKeys);
-//////////////////////// DuplicateKeyCheck /////////////////////////////
+//////////////////////// MappingKeyCheck ///////////////////////////////
-static PyObject* __Pyx_MatchCase_CheckDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys, Py_ssize_t n_var_keys) {
- // Inputs are tuples, and typically fairly small. It may be more efficient to
- // loop over the tuple than create a set.
+static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *keys[], Py_ssize_t nFixedKeys, Py_ssize_t nKeys) {
+ // Inputs are arrays, and typically fairly small. It may be more efficient to
+ // loop over the array than create a set.
// The CPython implementation (match_keys in ceval.c) does this concurrently with
// taking the keys out of the dictionary. I'm choosing to do it separately since the
@@ -472,71 +455,55 @@ static PyObject* __Pyx_MatchCase_CheckDuplicateKeys(PyObject *fixed_keys, PyObje
// this step completely.
PyObject *var_keys_set;
- PyObject *key = NULL;
+ PyObject *key;
Py_ssize_t n;
int contains;
var_keys_set = PySet_New(NULL);
- if (!var_keys_set) return NULL;
+ if (!var_keys_set) return -1;
- n_var_keys = (n_var_keys < 0) ? PyTuple_GET_SIZE(var_keys) : n_var_keys;
- for (n=0; n < n_var_keys; ++n) {
- key = PyTuple_GET_ITEM(var_keys, n);
+ for (n=nFixedKeys; n < nKeys; ++n) {
+ key = keys[n];
contains = PySet_Contains(var_keys_set, key);
if (contains < 0) {
- key = NULL;
- goto end;
+ goto bad;
} else if (contains == 1) {
- Py_INCREF(key);
- goto end;
+ goto raise_error;
} else {
if (PySet_Add(var_keys_set, key)) {
- key = NULL;
- goto end;
+ goto bad;
}
}
}
- for (n=0; n < PyTuple_GET_SIZE(fixed_keys); ++n) {
- key = PyTuple_GET_ITEM(fixed_keys, n);
+ for (n=0; n < nFixedKeys; ++n) {
+ key = keys[n];
contains = PySet_Contains(var_keys_set, key);
if (contains < 0) {
- key = NULL;
- goto end;
+ goto bad;
} else if (contains == 1) {
- Py_INCREF(key);
- goto end;
+ goto raise_error;
}
}
- key = NULL;
- end:
Py_DECREF(var_keys_set);
- return key;
-}
-
-//////////////////////// MappingKeyCheck.proto /////////////////////////
-
-static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys); /* proto */
-
-//////////////////////// MappingKeyCheck ///////////////////////////////
-//@requires: DuplicateKeyCheck
-
-static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *fixed_keys, PyObject *var_keys) {
- PyObject *key = __Pyx_MatchCase_CheckDuplicateKeys(fixed_keys, var_keys, -1);
- if (key) {
- PyErr_Format(PyExc_ValueError, "mapping pattern checks duplicate key (%R)", key);
- Py_DECREF(key);
- return -1;
- } else if (PyErr_Occurred()) {
- return -1;
- } else {
- return 0;
- }
+ return 0;
+
+ raise_error:
+ #if PY_MAJOR_VERSION > 2
+ PyErr_Format(PyExc_ValueError,
+ "mapping pattern checks duplicate key (%R)", key);
+ #else
+    // PyErr_Format's "%R" is not available in Python 2,
+    // so just provide less information there.
+ PyErr_SetString(PyExc_ValueError,
+ "mapping pattern checks duplicate key");
+ #endif
+ bad:
+ Py_DECREF(var_keys_set);
+ return -1;
}
/////////////////////////// ExtractExactDict.proto ////////////////
-#include <stdarg.h>
-
// the variadic arguments are a list of PyObject** to subjects to be filled. They may be NULL
// in which case they're ignored.
//
@@ -544,48 +511,31 @@ static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *fixed_keys, PyObj
#if CYTHON_REFNANNY
#define __Pyx_MatchCase_Mapping_ExtractDict(...) __Pyx__MatchCase_Mapping_ExtractDict(__pyx_refnanny, __VA_ARGS__)
-#define __Pyx_MatchCase_Mapping_ExtractDictV(...) __Pyx__MatchCase_Mapping_ExtractDictV(__pyx_refnanny, __VA_ARGS__)
#else
#define __Pyx_MatchCase_Mapping_ExtractDict(...) __Pyx__MatchCase_Mapping_ExtractDict(NULL, __VA_ARGS__)
-#define __Pyx_MatchCase_Mapping_ExtractDictV(...) __Pyx__MatchCase_Mapping_ExtractDictV(NULL, __VA_ARGS__)
#endif
-static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, ...); /* proto */
-static int __Pyx__MatchCase_Mapping_ExtractDictV(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, va_list subjects); /* proto */
+static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */
/////////////////////////// ExtractExactDict ////////////////
-static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, ...) {
- int result;
- va_list subjects;
-
- va_start(subjects, var_keys);
- result = __Pyx_MatchCase_Mapping_ExtractDictV(dict, fixed_keys, var_keys, subjects);
- va_end(subjects);
- return result;
-}
+static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) {
+ Py_ssize_t i;
-static int __Pyx__MatchCase_Mapping_ExtractDictV(void *__pyx_refnanny, PyObject *dict, PyObject *fixed_keys, PyObject *var_keys, va_list subjects) {
- PyObject *keys[] = {fixed_keys, var_keys};
- Py_ssize_t i, j;
-
- for (i=0; i<2; ++i) {
- PyObject *tuple = keys[i];
- for (j=0; j<PyTuple_GET_SIZE(tuple); ++j) {
- PyObject *key = PyTuple_GET_ITEM(tuple, j);
- PyObject **subject = va_arg(subjects, PyObject**);
- if (!subject) {
- int contains = PyDict_Contains(dict, key);
- if (contains <= 0) {
- return -1; // any subjects that were already set will be cleaned up externally
- }
- } else {
- PyObject *value = __Pyx_PyDict_GetItemStrWithError(dict, key);
- if (!value) {
- return (PyErr_Occurred()) ? -1 : 0; // any subjects that were already set will be cleaned up externally
- }
- __Pyx_XDECREF_SET(*subject, value);
- __Pyx_INCREF(*subject); // capture this incref with refnanny!
+ for (i=0; i<nKeys; ++i) {
+ PyObject *key = keys[i];
+ PyObject **subject = subjects[i];
+ if (!subject) {
+ int contains = PyDict_Contains(dict, key);
+ if (contains <= 0) {
+ return -1; // any subjects that were already set will be cleaned up externally
+ }
+ } else {
+ PyObject *value = __Pyx_PyDict_GetItemStrWithError(dict, key);
+ if (!value) {
+ return (PyErr_Occurred()) ? -1 : 0; // any subjects that were already set will be cleaned up externally
}
+ __Pyx_XDECREF_SET(*subject, value);
+ __Pyx_INCREF(*subject); // capture this incref with refnanny!
}
}
return 1; // success
@@ -598,74 +548,74 @@ static int __Pyx__MatchCase_Mapping_ExtractDictV(void *__pyx_refnanny, PyObject
//
// This is a specialized version for the rarer case when the type isn't an exact dict.
-#include <stdarg.h>
-
#if CYTHON_REFNANNY
#define __Pyx_MatchCase_Mapping_ExtractNonDict(...) __Pyx__MatchCase_Mapping_ExtractNonDict(__pyx_refnanny, __VA_ARGS__)
-#define __Pyx_MatchCase_Mapping_ExtractNonDictV(...) __Pyx__MatchCase_Mapping_ExtractNonDictV(__pyx_refnanny, __VA_ARGS__)
#else
#define __Pyx_MatchCase_Mapping_ExtractNonDict(...) __Pyx__MatchCase_Mapping_ExtractNonDict(NULL, __VA_ARGS__)
-#define __Pyx_MatchCase_Mapping_ExtractNonDictV(...) __Pyx__MatchCase_Mapping_ExtractNonDictV(NULL, __VA_ARGS__)
#endif
-static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *fixed_keys, PyObject *var_keys, ...); /* proto */
-static int __Pyx__MatchCase_Mapping_ExtractNonDictV(void *__pyx_refnanny, PyObject *mapping, PyObject *fixed_keys, PyObject *var_keys, va_list subjects); /* proto */
+static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */
///////////////////////// ExtractNonDict //////////////////////////////////////
//@requires: ObjectHandling.c::PyObjectCall2Args
// largely adapted from match_keys in CPython ceval.c
-static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, ...) {
- int result;
- va_list subjects;
-
- va_start(subjects, var_keys);
- result = __Pyx_MatchCase_Mapping_ExtractNonDictV(map, fixed_keys, var_keys, subjects);
- va_end(subjects);
- return result;
-}
-
-static int __Pyx__MatchCase_Mapping_ExtractNonDictV(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, va_list subjects) {
+static int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) {
PyObject *dummy=NULL, *get=NULL;
- PyObject *keys[] = {fixed_keys, var_keys};
- Py_ssize_t i, j;
+ Py_ssize_t i;
int result = 0;
+#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL
+ PyObject *get_method = NULL, *get_self = NULL;
+#endif
dummy = PyObject_CallObject((PyObject *)&PyBaseObject_Type, NULL);
if (!dummy) {
return -1;
}
- get = PyObject_GetAttrString(map, "get");
+ get = PyObject_GetAttrString(mapping, "get");
if (!get) {
result = -1;
goto end;
}
+#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL
+ if (likely(PyMethod_Check(get))) {
+ // both of these are borrowed
+ get_method = PyMethod_GET_FUNCTION(get);
+ get_self = PyMethod_GET_SELF(get);
+ }
+#endif
- for (i=0; i<2; ++i) {
- PyObject *tuple = keys[i];
- for (j=0; j<PyTuple_GET_SIZE(tuple); ++j) {
- PyObject **subject;
- PyObject *value = NULL;
- PyObject *key = PyTuple_GET_ITEM(tuple, j);
-
- // TODO - there's an optimization here (although it deviates from the strict definition of pattern matching).
- // If we don't need the values then we can call PyObject_Contains instead of "get". If we don't need *any*
- // of the values then we can skip initialization "get" and "dummy"
+ for (i=0; i<nKeys; ++i) {
+ PyObject **subject;
+ PyObject *value = NULL;
+ PyObject *key = keys[i];
+
+        // TODO - there's an optimization available here (although it deviates from the strict definition of pattern
+        // matching): if we don't need the values then we can use a containment check (e.g. PySequence_Contains)
+        // instead of "get", and if we don't need *any* of the values then we can skip initializing "get" and "dummy".
+#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL
+ if (likely(get_method)) {
+ PyObject *args[] = { get_self, key, dummy };
+ value = _PyObject_Vectorcall(get_method, args, 3, NULL);
+ }
+ else
+#endif
+ {
value = __Pyx_PyObject_Call2Args(get, key, dummy);
- if (!value) {
- result = -1;
- goto end;
- } else if (value == dummy) {
- Py_DECREF(value);
- goto end; // failed
+ }
+ if (!value) {
+ result = -1;
+ goto end;
+ } else if (value == dummy) {
+ Py_DECREF(value);
+ goto end; // failed
+ } else {
+ subject = subjects[i];
+ if (subject) {
+ __Pyx_XDECREF_SET(*subject, value);
+ __Pyx_GOTREF(*subject);
} else {
- subject = va_arg(subjects, PyObject**);
- if (subject) {
- __Pyx_XDECREF_SET(*subject, value);
- __Pyx_GOTREF(*subject);
- } else {
- Py_DECREF(value);
- }
+ Py_DECREF(value);
}
}
}
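(For orientation, a rough caller-side sketch of the new array-based extraction convention for a pattern like `case {"a": x, "b": _}:`. The names subject, key_a, key_b and var_x are hypothetical; the real call sequence is emitted by MatchCaseNodes.py rather than written by hand.)

    /* Hypothetical sketch, not the actual generated code. */
    PyObject *keys[] = { key_a, key_b };        /* the pattern's key objects            */
    PyObject **subjects[] = { &var_x, NULL };   /* NULL marks a value we don't capture  */
    int matched = __Pyx_MatchCase_Mapping_Extract(subject, keys, 2, subjects);
    if (matched == -1) goto error;              /* exception already set                */
    if (matched == 1) { /* var_x now owns subject["a"]; run the case body */ }
    /* matched == 0: not all keys were present, so fall through to the next case */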
@@ -679,36 +629,28 @@ static int __Pyx__MatchCase_Mapping_ExtractNonDictV(void *__pyx_refnanny, PyObje
///////////////////////// ExtractGeneric.proto ////////////////////////////////
-#include <stdarg.h>
-
#if CYTHON_REFNANNY
#define __Pyx_MatchCase_Mapping_Extract(...) __Pyx__MatchCase_Mapping_Extract(__pyx_refnanny, __VA_ARGS__)
#else
#define __Pyx_MatchCase_Mapping_Extract(...) __Pyx__MatchCase_Mapping_Extract(NULL, __VA_ARGS__)
#endif
-static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, ...); /* proto */
+static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */
////////////////////// ExtractGeneric //////////////////////////////////////
//@requires: ExtractExactDict
//@requires: ExtractNonDict
-static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *map, PyObject *fixed_keys, PyObject *var_keys, ...) {
- va_list subjects;
- int result;
-
- va_start(subjects, var_keys);
- if (PyDict_CheckExact(map)) {
- result = __Pyx_MatchCase_Mapping_ExtractDictV(map, fixed_keys, var_keys, subjects);
+static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) {
+ if (PyDict_CheckExact(mapping)) {
+ return __Pyx_MatchCase_Mapping_ExtractDict(mapping, keys, nKeys, subjects);
} else {
- result = __Pyx_MatchCase_Mapping_ExtractNonDictV(map, fixed_keys, var_keys, subjects);
+ return __Pyx_MatchCase_Mapping_ExtractNonDict(mapping, keys, nKeys, subjects);
}
- va_end(subjects);
- return result;
}
///////////////////////////// DoubleStarCapture.proto //////////////////////
-static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObject *const_temps, PyObject *var_temps); /* proto */
+static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys); /* proto */
//////////////////////////// DoubleStarCapture //////////////////////////////
@@ -717,31 +659,30 @@ static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObjec
// https://github.com/python/cpython/blob/145bf269df3530176f6ebeab1324890ef7070bf8/Python/ceval.c#L3977
// (now removed in favour of building the same thing from a combination of opcodes)
// The differences are:
-// 1. We loop over separate tuples for constant and runtime keys
-// 2. We add a shortcut for when there will be no left over keys (because I'm guess it's pretty common)
+// 1. We use an array of keys rather than a tuple of keys
+// 2. We add a shortcut for when there will be no left over keys (because I guess it's pretty common)
//
// Tempita variable 'tag' can be "NonDict", "ExactDict" or empty
-static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObject *const_temps, PyObject *var_temps) {
+static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys) {
PyObject *dict_out;
- PyObject *tuples[] = { const_temps, var_temps };
- Py_ssize_t i, j;
+ Py_ssize_t i;
{{if tag != "NonDict"}}
// shortcut for when there are no left-over keys
- if ({{if tag=="ExactDict"}}(1){{else}}PyDict_CheckExact(map){{endif}}) {
- Py_ssize_t s = PyDict_Size(map);
+ if ({{if tag=="ExactDict"}}(1){{else}}PyDict_CheckExact(mapping){{endif}}) {
+ Py_ssize_t s = PyDict_Size(mapping);
if (s == -1) {
return NULL;
}
- if (s == (PyTuple_GET_SIZE(const_temps) + PyTuple_GET_SIZE(var_temps))) {
+ if (s == nKeys) {
return PyDict_New();
}
}
{{endif}}
{{if tag=="ExactDict"}}
- dict_out = PyDict_Copy(map);
+ dict_out = PyDict_Copy(mapping);
{{else}}
dict_out = PyDict_New();
{{endif}}
@@ -749,19 +690,16 @@ static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObjec
return NULL;
}
{{if tag!="ExactDict"}}
- if (PyDict_Update(dict_out, map)) {
+ if (PyDict_Update(dict_out, mapping)) {
Py_DECREF(dict_out);
return NULL;
}
{{endif}}
- for (i=0; i<2; ++i) {
- PyObject *keys = tuples[i];
- for (j=0; j<PyTuple_GET_SIZE(keys); ++j) {
- if (PyDict_DelItem(dict_out, PyTuple_GET_ITEM(keys, j))) {
- Py_DECREF(dict_out);
- return NULL;
- }
+ for (i=0; i<nKeys; ++i) {
+ if (PyDict_DelItem(dict_out, keys[i])) {
+ Py_DECREF(dict_out);
+ return NULL;
}
}
return dict_out;
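(As a rough illustration, with hypothetical names and assuming the "ExactDict" Tempita variant, a `case {"a": x, **rest}:` pattern could capture the remainder roughly as follows once "a" has been extracted.)

    /* Hypothetical sketch of the **rest capture. */
    PyObject *matched_keys[] = { key_a };
    PyObject *rest = __Pyx_MatchCase_DoubleStarCaptureExactDict(subject, matched_keys, 1);
    if (!rest) goto error;   /* copying the dict or deleting a key failed */
    /* rest is a new dict: a copy of the subject with the matched keys removed */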
@@ -769,30 +707,81 @@ static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *map, PyObjec
////////////////////////////// ClassPositionalPatterns.proto ////////////////////////
-#include <stdarg.h>
-
#if CYTHON_REFNANNY
#define __Pyx_MatchCase_ClassPositional(...) __Pyx__MatchCase_ClassPositional(__pyx_refnanny, __VA_ARGS__)
#else
#define __Pyx_MatchCase_ClassPositional(...) __Pyx__MatchCase_ClassPositional(NULL, __VA_ARGS__)
#endif
-static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *keysnames_tuple, int match_self, int num_args, ...); /* proto */
+static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *fixed_names[], Py_ssize_t n_fixed, int match_self, PyObject **subjects[], Py_ssize_t n_subjects); /* proto */
/////////////////////////////// ClassPositionalPatterns //////////////////////////////
-//@requires: DuplicateKeyCheck
+
+static int __Pyx_MatchCase_ClassCheckDuplicateAttrs(const char *tp_name, PyObject *fixed_names[], Py_ssize_t n_fixed, PyObject *match_args, Py_ssize_t num_args) {
+    // A lot of the basic logic here is shared with __Pyx_MatchCase_CheckMappingDuplicateKeys,
+    // but the two take different input types, so it isn't easy to actually share the code.
+
+    // The inputs are typically fairly small, so it may well be more efficient to
+    // loop over them directly than to build a set.
+
+ PyObject *attrs_set;
+ PyObject *attr = NULL;
+ Py_ssize_t n;
+ int contains;
+
+ attrs_set = PySet_New(NULL);
+ if (!attrs_set) return -1;
+
+ num_args = PyTuple_GET_SIZE(match_args) < num_args ? PyTuple_GET_SIZE(match_args) : num_args;
+ for (n=0; n < num_args; ++n) {
+ attr = PyTuple_GET_ITEM(match_args, n);
+ contains = PySet_Contains(attrs_set, attr);
+ if (contains < 0) {
+ goto bad;
+ } else if (contains == 1) {
+ goto raise_error;
+ } else {
+ if (PySet_Add(attrs_set, attr)) {
+ goto bad;
+ }
+ }
+ }
+ for (n=0; n < n_fixed; ++n) {
+ attr = fixed_names[n];
+ contains = PySet_Contains(attrs_set, attr);
+ if (contains < 0) {
+ goto bad;
+ } else if (contains == 1) {
+ goto raise_error;
+ }
+ }
+ Py_DECREF(attrs_set);
+ return 0;
+
+ raise_error:
+ #if PY_MAJOR_VERSION > 2
+ PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute %R",
+ tp_name, attr);
+ #else
+    // %R is not available in Python 2.7, so fall back to a less specific message
+ PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute",
+ tp_name);
+ #endif
+ bad:
+ Py_DECREF(attrs_set);
+ return -1;
+}
// Adapted from ceval.c "match_class" in CPython
//
// The argument match_self can equal 1 for "known to be true"
// 0 for "known to be false"
// -1 for "unknown", runtime test
-
-static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *keysnames_tuple, int match_self, int num_args, ...)
+// n_subjects is >= 0; otherwise this function will be skipped
+static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *fixed_names[], Py_ssize_t n_fixed, int match_self, PyObject **subjects[], Py_ssize_t n_subjects)
{
- PyObject *match_args, *dup_key;
+ PyObject *match_args;
Py_ssize_t allowed, i;
int result;
- va_list subjects;
match_args = PyObject_GetAttrString((PyObject*)type, "__match_args__");
if (!match_args) {
@@ -805,19 +794,17 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj
_Py_TPFLAGS_MATCH_SELF);
#else
// probably an earlier version of Python. Go off the known list in the specification
- match_self = (PyType_IsSubtype(type, &PyByteArray_Type) ||
- PyType_IsSubtype(type, &PyBytes_Type) ||
- PyType_IsSubtype(type, &PyDict_Type) ||
+ match_self = ((PyType_GetFlags(type) &
+ // long should capture bool too
+ (Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS | Py_TPFLAGS_TUPLE_SUBCLASS |
+ Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS | Py_TPFLAGS_DICT_SUBCLASS
+ #if PY_MAJOR_VERSION < 3
+ | Py_TPFLAGS_IN_SUBCLASS
+ #endif
+ )) ||
+ PyType_IsSubtype(type, &PyByteArray_Type) ||
PyType_IsSubtype(type, &PyFloat_Type) ||
PyType_IsSubtype(type, &PyFrozenSet_Type) ||
- PyType_IsSubtype(type, &PyLong_Type) || // This should capture bool too
- #if PY_MAJOR_VERSION < 3
- PyType_IsSubtype(type, &PyInt_Type) ||
- #endif
- PyType_IsSubtype(type, &PyList_Type) ||
- PyType_IsSubtype(type, &PySet_Type) ||
- PyType_IsSubtype(type, &PyUnicode_Type) ||
- PyType_IsSubtype(type, &PyTuple_Type)
);
#endif
}
@@ -838,18 +825,17 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj
allowed = match_self ?
1 : (match_args ? PyTuple_GET_SIZE(match_args) : 0);
- if (allowed < num_args) {
+ if (allowed < n_subjects) {
const char *plural = (allowed == 1) ? "" : "s";
PyErr_Format(PyExc_TypeError,
"%s() accepts %d positional sub-pattern%s (%d given)",
type->tp_name,
- allowed, plural, num_args);
+ allowed, plural, n_subjects);
Py_XDECREF(match_args);
return -1;
}
- va_start(subjects, num_args);
if (match_self) {
- PyObject **self_subject = va_arg(subjects, PyObject**);
+ PyObject **self_subject = subjects[0];
if (self_subject) {
// Easy. Copy the subject itself, and move on to kwargs.
__Pyx_XDECREF_SET(*self_subject, subject);
@@ -858,20 +844,13 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj
result = 1;
goto end_match_self;
}
- // next stage is to check for duplicate keys. Reuse code from mapping
- dup_key = __Pyx_MatchCase_CheckDuplicateKeys(keysnames_tuple, match_args, num_args);
- if (dup_key) {
- PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute %R",
- type->tp_name, dup_key);
- Py_DECREF(dup_key);
- result = -1;
- goto end;
- } else if (PyErr_Occurred()) {
+ // next stage is to check for duplicate attributes.
+ if (__Pyx_MatchCase_ClassCheckDuplicateAttrs(type->tp_name, fixed_names, n_fixed, match_args, n_subjects)) {
result = -1;
goto end;
}
- for (i = 0; i < num_args; i++) {
+ for (i = 0; i < n_subjects; i++) {
PyObject *attr;
PyObject **subject_i;
PyObject *name = PyTuple_GET_ITEM(match_args, i);
@@ -889,7 +868,7 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj
result = 0;
goto end;
}
- subject_i = va_arg(subjects, PyObject**);
+ subject_i = subjects[i];
if (subject_i) {
__Pyx_XDECREF_SET(*subject_i, attr);
__Pyx_GOTREF(attr);
@@ -902,7 +881,6 @@ static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subj
end:
Py_DECREF(match_args);
end_match_self: // because match_args isn't set
- va_end(subjects);
return result;
}
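(By way of example, and with hypothetical names, a class pattern such as `case Point2d(x1, y1, name=n):` might drive the positional part of the match as sketched below; the keyword attribute names are passed in only so that duplicates against __match_args__ can be detected.)

    /* Hypothetical sketch of the positional part of a class pattern. */
    PyObject *fixed_names[] = { name_str };            /* keyword attribute names from the pattern */
    PyObject **subjects[] = { &var_x1, &var_y1 };      /* one slot per positional sub-pattern      */
    int res = __Pyx_MatchCase_ClassPositional(
        subject, point2d_type,
        fixed_names, 1,
        -1,            /* match_self unknown at compile time, decided at runtime */
        subjects, 2);
    if (res == -1) goto error;       /* e.g. too many sub-patterns or a duplicate attribute */
    if (res == 1)  { /* var_x1/var_y1 now hold the extracted attributes */ }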
@@ -913,6 +891,13 @@ static PyTypeObject* __Pyx_MatchCase_IsType(PyObject* type); /* proto */
//////////////////////// MatchClassIsType /////////////////////////////
static PyTypeObject* __Pyx_MatchCase_IsType(PyObject* type) {
+ #if PY_MAJOR_VERSION < 3
+ if (PyClass_Check(type)) {
+        // Supporting old-style classes doesn't seem worth the effort.
+ PyErr_Format(PyExc_TypeError, "called match pattern must be a new-style class.");
+ return NULL;
+ }
+ #endif
if (!PyType_Check(type)) {
PyErr_Format(PyExc_TypeError, "called match pattern must be a type");
return NULL;
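(A hedged sketch of how __Pyx_MatchCase_IsType might be used ahead of the positional/keyword extraction; cls_obj is a hypothetical name for the object named in the class pattern.)

    /* Hypothetical sketch: reject non-types (and, on Py2, old-style classes) up front. */
    PyTypeObject *tp = __Pyx_MatchCase_IsType(cls_obj);
    if (!tp) goto error;   /* TypeError already set */
    /* tp can then be passed on to __Pyx_MatchCase_ClassPositional() etc. */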
diff --git a/Tools/ci-run.sh b/Tools/ci-run.sh
index 905a9d1e3..ffde4cbe1 100644
--- a/Tools/ci-run.sh
+++ b/Tools/ci-run.sh
@@ -83,6 +83,8 @@ else
python -m pip install -r test-requirements.txt || exit 1
if [[ $PYTHON_VERSION != "pypy"* && $PYTHON_VERSION != "3."[1]* ]]; then
python -m pip install -r test-requirements-cpython.txt || exit 1
+ elif [[ $PYTHON_VERSION == "pypy-2.7" ]]; then
+ python -m pip install -r test-requirements-pypy27.txt || exit 1
fi
fi
fi
diff --git a/test-requirements-pypy27.txt b/test-requirements-pypy27.txt
index 9f9505240..6d4f83bca 100644
--- a/test-requirements-pypy27.txt
+++ b/test-requirements-pypy27.txt
@@ -1,2 +1,3 @@
-r test-requirements.txt
+enum34==1.1.10
mock==3.0.5
diff --git a/tests/run/extra_patma.pyx b/tests/run/extra_patma.pyx
index 3bded3426..76357f36f 100644
--- a/tests/run/extra_patma.pyx
+++ b/tests/run/extra_patma.pyx
@@ -5,6 +5,41 @@
cimport cython
import array
+import sys
+
+__doc__ = ""
+
+
+cdef bint is_null(int* x):
+ return False # disabled - currently just a parser test
+ match x:
+ case NULL:
+ return True
+ case _:
+ return False
+
+
+def test_is_null():
+ """
+ >>> test_is_null()
+ """
+ cdef int some_int = 1
+ return # disabled - currently just a parser test
+ assert is_null(&some_int) == False
+ assert is_null(NULL) == True
+
+
+if sys.version_info[0] > 2:
+ __doc__ += """
+    array.array doesn't support the buffer protocol in Py2, and
+    working around that just for this test isn't worthwhile
+ >>> print(test_memoryview(array.array('i', [0, 1, 2])))
+ a 1
+ >>> print(test_memoryview(array.array('i', [])))
+ b
+ >>> print(test_memoryview(array.array('i', [5])))
+ c [5]
+ """
# goes via .shape instead
@cython.test_fail_if_path_exists("//CallNode//NameNode[@name = 'len']")
@@ -12,14 +47,8 @@ import array
@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']")
def test_memoryview(int[:] x):
"""
- >>> print(test_memoryview(array.array('i', [0, 1, 2])))
- a 1
- >>> print(test_memoryview(array.array('i', [])))
- b
>>> print(test_memoryview(None))
no!
- >>> print(test_memoryview(array.array('i', [5])))
- c [5]
"""
match x:
case [0, y, 2]:
@@ -45,6 +74,7 @@ def test_list_to_sequence(list x):
case _:
return False
+
@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']")
@cython.test_fail_if_path_exists("//CmpNode") # There's nothing to compare - it always succeeds!
def test_list_not_None_to_sequence(list x not None):
@@ -89,7 +119,7 @@ def class_attr_lookup(x):
assert cython.typeof(y) == "double", cython.typeof(y)
return y
-class PyClass:
+class PyClass(object):
pass
@cython.test_assert_path_exists("//PythonCapiFunctionNode[@cname='__Pyx_TypeCheck']")
@@ -110,6 +140,7 @@ def class_typecheck_exists(x):
case _:
return False
+
@cython.test_fail_if_path_exists("//NameNode[@name='isinstance']")
@cython.test_fail_if_path_exists("//PythonCapiFunctionNode[@cname='__Pyx_TypeCheck']")
def class_typecheck_doesnt_exist(C x):
diff --git a/tests/run/extra_patma_py.py b/tests/run/extra_patma_py.py
index e4b2aef84..a5046b997 100644
--- a/tests/run/extra_patma_py.py
+++ b/tests/run/extra_patma_py.py
@@ -4,6 +4,9 @@
from __future__ import print_function
import array
+import sys
+
+__doc__ = ""
def test_type_inference(x):
"""
@@ -63,10 +66,15 @@ def test_duplicate_keys(key1, key2):
>>> test_duplicate_keys("a", "b")
True
- >>> test_duplicate_keys("a", "a")
- Traceback (most recent call last):
- ...
- ValueError: mapping pattern checks duplicate key ('a')
+
+ Slightly awkward doctest to work around Py2 incompatibility
+ >>> try:
+ ... test_duplicate_keys("a", "a")
+ ... except ValueError as e:
+ ... if sys.version_info[0] > 2:
+ ... assert e.args[0] == "mapping pattern checks duplicate key ('a')", e.args[0]
+ ... else:
+ ... assert e.args[0] == "mapping pattern checks duplicate key"
"""
class Keys:
KEY_1 = key1
@@ -79,7 +87,7 @@ def test_duplicate_keys(key1, key2):
return False
-class PyClass:
+class PyClass(object):
pass
@@ -99,3 +107,20 @@ class PrivateAttrLookupOuter:
match x:
case PyClass(__something=y):
return y
+
+
+if sys.version_info[0] < 3:
+ class OldStyleClass:
+ pass
+
+ def test_oldstyle_class_failure(x):
+ match x:
+ case OldStyleClass():
+ return True
+
+ __doc__ += """
+ >>> test_oldstyle_class_failure(1)
+ Traceback (most recent call last):
+ ...
+ TypeError: called match pattern must be a new-style class.
+ """
diff --git a/tests/run/test_patma.py b/tests/run/test_patma.py
index 9344c8169..e51ba0dbd 100644
--- a/tests/run/test_patma.py
+++ b/tests/run/test_patma.py
@@ -1,43 +1,21 @@
-### COPIED FROM CPython 3.9
+### COPIED FROM CPython 3.12 alpha (July 2022)
### Original part after ############
# cython: language_level=3
# new code
import cython
-from Cython.Compiler.Main import compile as cython_compile, CompileError
-from Cython.Build.Inline import cython_inline
-import contextlib
-from tempfile import NamedTemporaryFile
-
-@contextlib.contextmanager
-def hidden_stderr():
- try:
- from StringIO import StringIO
- except ImportError:
- from io import StringIO
-
- old_stderr = sys.stderr
- try:
- sys.stderr = StringIO()
- yield
- finally:
- sys.stderr = old_stderr
-
-def _compile(code):
- with NamedTemporaryFile(suffix='.py') as f:
- f.write(code.encode('utf8'))
- f.flush()
-
- with hidden_stderr():
- result = cython_compile(f.name, language_level=3)
- return result
+from Cython.TestUtils import py_parse_code
+
if cython.compiled:
def compile(code, name, what):
assert what == 'exec'
- result = _compile(code)
- if not result.c_file:
- raise SyntaxError('unexpected EOF') # compile is only used for testing errors
+ py_parse_code(code)
+
+
+def disable(func):
+ pass
+
############## SLIGHTLY MODIFIED ORIGINAL CODE
import array
@@ -68,9 +46,47 @@ else:
return self.x == other.x and self.y == other.y
# TestCompiler removed - it's very CPython-specific
-# TestTracing also removed - doesn't seem like a core test
+# TestTracing also mainly removed - doesn't seem like a core test
+# except for one test that seems misplaced in CPython (which is below)
+
+class TestTracing(unittest.TestCase):
+ if sys.version_info < (3, 4):
+ class SubTestClass(object):
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ return
+ def __call__(self, *args):
+ return self
+ subTest = SubTestClass()
+
+ def test_parser_deeply_nested_patterns(self):
+ # Deeply nested patterns can cause exponential backtracking when parsing.
+ # See CPython gh-93671 for more information.
+ #
+    # Cython note: this doesn't break the parser, but it may cause a
+    # RecursionError later during code generation. That isn't easily
+    # avoidable with the way Cython visitors currently work.
+
+ levels = 100
+
+ patterns = [
+ "A" + "(" * levels + ")" * levels,
+ "{1:" * levels + "1" + "}" * levels,
+ "[" * levels + "1" + "]" * levels,
+ ]
-# FIXME - return all the "return"s added to cause code to be dropped
+ for pattern in patterns:
+ with self.subTest(pattern):
+ code = inspect.cleandoc("""
+ match None:
+ case {}:
+ pass
+ """.format(pattern))
+ compile(code, "<string>", "exec")
+
+
+# FIXME - remove all the "return"s added to cause code to be dropped
############## ORIGINAL PART FROM CPYTHON
@@ -2706,6 +2722,21 @@ class TestPatma(unittest.TestCase):
self.assertEqual(y, 'bar')
+ def test_patma_249(self):
+ return # disabled
+ class C:
+ __attr = "eggs" # mangled to _C__attr
+ _Outer__attr = "bacon"
+ class Outer:
+ def f(self, x):
+ match x:
+ # looks up __attr, not _C__attr or _Outer__attr
+ case C(__attr=y):
+ return y
+ c = C()
+ setattr(c, "__attr", "spam") # setattr is needed because we're in a class scope
+ self.assertEqual(Outer().f(c), "spam")
+
class TestSyntaxErrors(unittest.TestCase):
@@ -2728,6 +2759,7 @@ class TestSyntaxErrors(unittest.TestCase):
""")
+ @disable # validation will be added when class patterns are added
def test_attribute_name_repeated_in_class_pattern(self):
self.assert_syntax_error("""
match ...:
@@ -2826,6 +2858,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # will be implemented as part of sequence patterns
def test_multiple_starred_names_in_sequence_pattern_0(self):
self.assert_syntax_error("""
match ...:
@@ -2833,6 +2866,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # will be implemented as part of sequence patterns
def test_multiple_starred_names_in_sequence_pattern_1(self):
self.assert_syntax_error("""
match ...:
@@ -2967,6 +3001,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # validation will be added when class patterns are added
def test_mapping_pattern_duplicate_key(self):
self.assert_syntax_error("""
match ...:
@@ -2974,6 +3009,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # validation will be added when class patterns are added
def test_mapping_pattern_duplicate_key_edge_case0(self):
self.assert_syntax_error("""
match ...:
@@ -2981,6 +3017,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # validation will be added when class patterns are added
def test_mapping_pattern_duplicate_key_edge_case1(self):
self.assert_syntax_error("""
match ...:
@@ -2988,6 +3025,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # validation will be added when class patterns are added
def test_mapping_pattern_duplicate_key_edge_case2(self):
self.assert_syntax_error("""
match ...:
@@ -2995,6 +3033,7 @@ class TestSyntaxErrors(unittest.TestCase):
pass
""")
+ @disable # validation will be added when class patterns are added
def test_mapping_pattern_duplicate_key_edge_case3(self):
self.assert_syntax_error("""
match ...: