diff options
Diffstat (limited to 'Cython/Compiler/ParseTreeTransforms.py')
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.py | 272 |
1 files changed, 238 insertions, 34 deletions
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 52a355e7f..981e4b174 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -6,10 +6,12 @@ import cython cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object, Options=object, UtilNodes=object, LetNode=object, LetRefNode=object, TreeFragment=object, EncodedString=object, - error=object, warning=object, copy=object, _unicode=object) + error=object, warning=object, copy=object, hashlib=object, sys=object, + _unicode=object) import copy import hashlib +import sys from . import PyrexTypes from . import Naming @@ -191,6 +193,7 @@ class PostParse(ScopeTrackingTransform): self.specialattribute_handlers = { '__cythonbufferdefaults__' : self.handle_bufferdefaults } + self.in_pattern_node = False def visit_LambdaNode(self, node): # unpack a lambda expression into the corresponding DefNode @@ -397,6 +400,18 @@ class PostParse(ScopeTrackingTransform): self.visitchildren(node) return node + def visit_PatternNode(self, node): + in_pattern_node, self.in_pattern_node = self.in_pattern_node, True + self.visitchildren(node) + self.in_pattern_node = in_pattern_node + return node + + def visit_JoinedStrNode(self, node): + if self.in_pattern_node: + error(node.pos, "f-strings are not accepted for pattern matching") + self.visitchildren(node) + return node + class _AssignmentExpressionTargetNameFinder(TreeVisitor): def __init__(self): super(_AssignmentExpressionTargetNameFinder, self).__init__() @@ -844,6 +859,14 @@ class InterpretCompilerDirectives(CythonTransform): } special_methods.update(unop_method_nodes) + valid_cython_submodules = { + 'cimports', + 'dataclasses', + 'operator', + 'parallel', + 'view', + } + valid_parallel_directives = { "parallel", "prange", @@ -872,6 +895,34 @@ class InterpretCompilerDirectives(CythonTransform): error(pos, "Invalid directive: '%s'." % (directive,)) return True + def _check_valid_cython_module(self, pos, module_name): + if not module_name.startswith("cython."): + return + if module_name.split('.', 2)[1] in self.valid_cython_submodules: + return + + extra = "" + # This is very rarely used, so don't waste space on static tuples. + hints = [ + line.split() for line in """\ + imp cimports + cimp cimports + para parallel + parra parallel + dataclass dataclasses + """.splitlines()[:-1] + ] + for wrong, correct in hints: + if module_name.startswith("cython." + wrong): + extra = "Did you mean 'cython.%s' ?" % correct + break + + error(pos, "'%s' is not a valid cython.* module%s%s" % ( + module_name, + ". " if extra else "", + extra, + )) + # Set up processing and handle the cython: comments. def visit_ModuleNode(self, node): for key in sorted(node.directive_comments): @@ -942,6 +993,9 @@ class InterpretCompilerDirectives(CythonTransform): elif module_name.startswith(u"cython."): if module_name.startswith(u"cython.parallel."): error(node.pos, node.module_name + " is not a module") + else: + self._check_valid_cython_module(node.pos, module_name) + if module_name == u"cython.parallel": if node.as_name and node.as_name != u"cython": self.parallel_directives[node.as_name] = module_name @@ -968,10 +1022,10 @@ class InterpretCompilerDirectives(CythonTransform): node.pos, module_name, node.relative_level, node.imported_names) elif not node.relative_level and ( module_name == u"cython" or module_name.startswith(u"cython.")): + self._check_valid_cython_module(node.pos, module_name) submodule = (module_name + u".")[7:] newimp = [] - - for pos, name, as_name, kind in node.imported_names: + for pos, name, as_name in node.imported_names: full_name = submodule + name qualified_name = u"cython." + full_name if self.is_parallel_directive(qualified_name, node.pos): @@ -980,15 +1034,12 @@ class InterpretCompilerDirectives(CythonTransform): self.parallel_directives[as_name or name] = qualified_name elif self.is_cython_directive(full_name): self.directive_names[as_name or name] = full_name - if kind is not None: - self.context.nonfatal_error(PostParseError(pos, - "Compiler directive imports must be plain imports")) elif full_name in ['dataclasses', 'typing']: self.directive_names[as_name or name] = full_name # unlike many directives, still treat it as a regular module - newimp.append((pos, name, as_name, kind)) + newimp.append((pos, name, as_name)) else: - newimp.append((pos, name, as_name, kind)) + newimp.append((pos, name, as_name)) if not newimp: return None @@ -1003,10 +1054,11 @@ class InterpretCompilerDirectives(CythonTransform): imported_names = [] for name, name_node in node.items: imported_names.append( - (name_node.pos, name, None if name == name_node.name else name_node.name, None)) + (name_node.pos, name, None if name == name_node.name else name_node.name)) return self._create_cimport_from_import( node.pos, module_name, import_node.level, imported_names) elif module_name == u"cython" or module_name.startswith(u"cython."): + self._check_valid_cython_module(import_node.module_name.pos, module_name) submodule = (module_name + u".")[7:] newimp = [] for name, name_node in node.items: @@ -1041,14 +1093,13 @@ class InterpretCompilerDirectives(CythonTransform): module_name=dotted_name, as_name=as_name, is_absolute=level == 0) - for pos, dotted_name, as_name, _ in imported_names + for pos, dotted_name, as_name in imported_names ] def visit_SingleAssignmentNode(self, node): if isinstance(node.rhs, ExprNodes.ImportNode): module_name = node.rhs.module_name.value - is_special_module = (module_name + u".").startswith((u"cython.parallel.", u"cython.cimports.")) - if module_name != u"cython" and not is_special_module: + if module_name != u"cython" and not module_name.startswith("cython."): return node node = Nodes.CImportStatNode(node.pos, module_name=module_name, as_name=node.lhs.name) @@ -1197,7 +1248,7 @@ class InterpretCompilerDirectives(CythonTransform): return (optname, directivetype(optname, str(args[0].value))) elif directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS: # signal to pass things on without processing - return (optname, (args, kwds.as_python_dict())) + return (optname, (args, kwds.as_python_dict() if kwds else {})) else: assert False @@ -1290,8 +1341,7 @@ class InterpretCompilerDirectives(CythonTransform): name, value = directive if self.directives.get(name, object()) != value: directives.append(directive) - if (directive[0] == 'staticmethod' or - (directive[0] == 'dataclasses.dataclass' and scope_name == 'class')): + if directive[0] == 'staticmethod': both.append(dec) # Adapt scope type based on decorators that change it. if directive[0] == 'cclass' and scope_name == 'class': @@ -1301,10 +1351,11 @@ class InterpretCompilerDirectives(CythonTransform): if realdecs and (scope_name == 'cclass' or isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))): for realdec in realdecs: + dec_pos = realdec.pos realdec = realdec.decorator if ((realdec.is_name and realdec.name == "dataclass") or (realdec.is_attribute and realdec.attribute == "dataclass")): - error(realdec.pos, + error(dec_pos, "Use '@cython.dataclasses.dataclass' on cdef classes to create a dataclass") # Note - arbitrary C function decorators are caught later in DecoratorTransform raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.") @@ -1602,6 +1653,128 @@ class WithTransform(VisitorTransform, SkipDeclarations): visit_Node = VisitorTransform.recurse_to_children +class _GeneratorExpressionArgumentsMarker(TreeVisitor, SkipDeclarations): + # called from "MarkClosureVisitor" + def __init__(self, gen_expr): + super(_GeneratorExpressionArgumentsMarker, self).__init__() + self.gen_expr = gen_expr + + def visit_ExprNode(self, node): + if not node.is_literal: + # Don't bother tagging literal nodes + assert (not node.generator_arg_tag) # nobody has tagged this first + node.generator_arg_tag = self.gen_expr + self.visitchildren(node) + + def visit_Node(self, node): + # We're only interested in the expressions that make up the iterator sequence, + # so don't go beyond ExprNodes (e.g. into ForFromStatNode). + return + + def visit_GeneratorExpressionNode(self, node): + node.generator_arg_tag = self.gen_expr + # don't visit children, can't handle overlapping tags + # (and assume generator expressions don't end up optimized out in a way + # that would require overlapping tags) + + +class _HandleGeneratorArguments(VisitorTransform, SkipDeclarations): + # used from within CreateClosureClasses + + def __call__(self, node): + from . import Visitor + assert isinstance(node, ExprNodes.GeneratorExpressionNode) + self.gen_node = node + + self.args = list(node.def_node.args) + self.call_parameters = list(node.call_parameters) + self.tag_count = 0 + self.substitutions = {} + + self.visitchildren(node) + + for k, v in self.substitutions.items(): + # doing another search for replacements here (at the end) allows us to sweep up + # CloneNodes too (which are often generated by the optimizer) + # (it could arguably be done more efficiently with a single traversal though) + Visitor.recursively_replace_node(node, k, v) + + node.def_node.args = self.args + node.call_parameters = self.call_parameters + return node + + def visit_GeneratorExpressionNode(self, node): + # a generator can also be substituted itself, so handle that case + new_node = self._handle_ExprNode(node, do_visit_children=False) + # However do not traverse into it. A new _HandleGeneratorArguments visitor will be used + # elsewhere to do that. + return node + + def _handle_ExprNode(self, node, do_visit_children): + if (node.generator_arg_tag is not None and self.gen_node is not None and + self.gen_node == node.generator_arg_tag): + pos = node.pos + # The reason for using ".x" as the name is that this is how CPython + # tracks internal variables in loops (e.g. + # { locals() for v in range(10) } + # will produce "v" and ".0"). We don't replicate this behaviour completely + # but use it as a starting point + name_source = self.tag_count + self.tag_count += 1 + name = EncodedString(".{0}".format(name_source)) + def_node = self.gen_node.def_node + if not def_node.local_scope.lookup_here(name): + from . import Symtab + cname = EncodedString(Naming.genexpr_arg_prefix + Symtab.punycodify_name(str(name_source))) + name_decl = Nodes.CNameDeclaratorNode(pos=pos, name=name) + type = node.type + if type.is_reference and not type.is_fake_reference: + # It isn't obvious whether the right thing to do would be to capture by reference or by + # value (C++ itself doesn't know either for lambda functions and forces a choice). + # However, capture by reference involves converting to FakeReference which would require + # re-analysing AttributeNodes. Therefore I've picked capture-by-value out of convenience + # TODO - could probably be optimized by making the arg a reference but the closure not + # (see https://github.com/cython/cython/issues/2468) + type = type.ref_base_type + + name_decl.type = type + new_arg = Nodes.CArgDeclNode(pos=pos, declarator=name_decl, + base_type=None, default=None, annotation=None) + new_arg.name = name_decl.name + new_arg.type = type + + self.args.append(new_arg) + node.generator_arg_tag = None # avoid the possibility of this being caught again + self.call_parameters.append(node) + new_arg.entry = def_node.declare_argument(def_node.local_scope, new_arg) + new_arg.entry.cname = cname + new_arg.entry.in_closure = True + + if do_visit_children: + # now visit the Nodes's children (but remove self.gen_node to not to further + # argument substitution) + gen_node, self.gen_node = self.gen_node, None + self.visitchildren(node) + self.gen_node = gen_node + + # replace the node inside the generator with a looked-up name + # (initialized_check can safely be False because the source variable will be checked + # before it is captured if the check is required) + name_node = ExprNodes.NameNode(pos, name=name, initialized_check=False) + name_node.entry = self.gen_node.def_node.gbody.local_scope.lookup(name_node.name) + name_node.type = name_node.entry.type + self.substitutions[node] = name_node + return name_node + if do_visit_children: + self.visitchildren(node) + return node + + def visit_ExprNode(self, node): + return self._handle_ExprNode(node, True) + + visit_Node = VisitorTransform.recurse_to_children + + class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): """ Transforms method decorators in cdef classes into nested calls or properties. @@ -2057,22 +2230,10 @@ if VALUE is not None: if not e.type.is_pyobject: e.type.create_to_py_utility_code(env) e.type.create_from_py_utility_code(env) - all_members_names = sorted([e.name for e in all_members]) - - # Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons. - # SHA-256 should be ok for years to come, but early Cython 3.0 alpha releases used SHA-1, - # which may not be. - checksum_algos = [hashlib.sha256, hashlib.sha1] - try: - checksum_algos.append(hashlib.md5) - except AttributeError: - pass - member_names_string = ' '.join(all_members_names).encode('utf-8') - checksums = [ - '0x' + mkchecksum(member_names_string).hexdigest()[:7] - for mkchecksum in checksum_algos - ] + all_members_names = [e.name for e in all_members] + checksums = _calculate_pickle_checksums(all_members_names) + unpickle_func_name = '__pyx_unpickle_%s' % node.punycode_class_name # TODO(robertwb): Move the state into the third argument @@ -2315,11 +2476,17 @@ if VALUE is not None: assmt.analyse_declarations(env) return assmt + def visit_func_outer_attrs(self, node): + # any names in the outer attrs should not be looked up in the function "seen_vars_stack" + stack = self.seen_vars_stack.pop() + super(AnalyseDeclarationsTransform, self).visit_func_outer_attrs(node) + self.seen_vars_stack.append(stack) + def visit_ScopedExprNode(self, node): env = self.current_env() node.analyse_declarations(env) # the node may or may not have a local scope - if node.has_local_scope: + if node.expr_scope: self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) self.enter_scope(node, node.expr_scope) node.analyse_scoped_declarations(node.expr_scope) @@ -2327,6 +2494,7 @@ if VALUE is not None: self.exit_scope() self.seen_vars_stack.pop() else: + node.analyse_scoped_declarations(env) self.visitchildren(node) return node @@ -2483,6 +2651,24 @@ if VALUE is not None: return node +def _calculate_pickle_checksums(member_names): + # Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons. + # SHA-256 should be ok for years to come, but early Cython 3.0 alpha releases used SHA-1, + # which may not be. + member_names_string = ' '.join(member_names).encode('utf-8') + hash_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {} + checksums = [] + for algo_name in ['sha256', 'sha1', 'md5']: + try: + mkchecksum = getattr(hashlib, algo_name) + checksum = mkchecksum(member_names_string, **hash_kwargs).hexdigest() + except (AttributeError, ValueError): + # The algorithm (i.e. MD5) might not be there at all, or might be blocked at runtime. + continue + checksums.append('0x' + checksum[:7]) + return checksums + + class CalculateQualifiedNamesTransform(EnvTransform): """ Calculate and store the '__qualname__' and the global @@ -2874,8 +3060,7 @@ class RemoveUnreachableCode(CythonTransform): if not self.current_directives['remove_unreachable']: return node self.visitchildren(node) - for idx, stat in enumerate(node.stats): - idx += 1 + for idx, stat in enumerate(node.stats, 1): if stat.is_terminator: if idx < len(node.stats): if self.current_directives['warn.unreachable']: @@ -2974,6 +3159,8 @@ class YieldNodeCollector(TreeVisitor): class MarkClosureVisitor(CythonTransform): + # In addition to marking closures this is also responsible to finding parts of the + # generator iterable and marking them def visit_ModuleNode(self, node): self.needs_closure = False @@ -3044,6 +3231,19 @@ class MarkClosureVisitor(CythonTransform): self.needs_closure = True return node + def visit_GeneratorExpressionNode(self, node): + node = self.visit_LambdaNode(node) + if not isinstance(node.loop, Nodes._ForInStatNode): + # Possibly should handle ForFromStatNode + # but for now do nothing + return node + itseq = node.loop.iterator.sequence + # literals do not need replacing with an argument + if itseq.is_literal: + return node + _GeneratorExpressionArgumentsMarker(node).visit(itseq) + return node + class CreateClosureClasses(CythonTransform): # Output closure classes in module scope for all functions @@ -3188,6 +3388,10 @@ class CreateClosureClasses(CythonTransform): self.visitchildren(node) return node + def visit_GeneratorExpressionNode(self, node): + node = _HandleGeneratorArguments()(node) + return self.visit_LambdaNode(node) + class InjectGilHandling(VisitorTransform, SkipDeclarations): """ |