summaryrefslogtreecommitdiff
path: root/Cython/Compiler/ParseTreeTransforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/Compiler/ParseTreeTransforms.py')
-rw-r--r--Cython/Compiler/ParseTreeTransforms.py272
1 files changed, 238 insertions, 34 deletions
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
index 52a355e7f..981e4b174 100644
--- a/Cython/Compiler/ParseTreeTransforms.py
+++ b/Cython/Compiler/ParseTreeTransforms.py
@@ -6,10 +6,12 @@ import cython
cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
Options=object, UtilNodes=object, LetNode=object,
LetRefNode=object, TreeFragment=object, EncodedString=object,
- error=object, warning=object, copy=object, _unicode=object)
+ error=object, warning=object, copy=object, hashlib=object, sys=object,
+ _unicode=object)
import copy
import hashlib
+import sys
from . import PyrexTypes
from . import Naming
@@ -191,6 +193,7 @@ class PostParse(ScopeTrackingTransform):
self.specialattribute_handlers = {
'__cythonbufferdefaults__' : self.handle_bufferdefaults
}
+ self.in_pattern_node = False
def visit_LambdaNode(self, node):
# unpack a lambda expression into the corresponding DefNode
@@ -397,6 +400,18 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node)
return node
+ def visit_PatternNode(self, node):
+ in_pattern_node, self.in_pattern_node = self.in_pattern_node, True
+ self.visitchildren(node)
+ self.in_pattern_node = in_pattern_node
+ return node
+
+ def visit_JoinedStrNode(self, node):
+ if self.in_pattern_node:
+ error(node.pos, "f-strings are not accepted for pattern matching")
+ self.visitchildren(node)
+ return node
+
class _AssignmentExpressionTargetNameFinder(TreeVisitor):
def __init__(self):
super(_AssignmentExpressionTargetNameFinder, self).__init__()
@@ -844,6 +859,14 @@ class InterpretCompilerDirectives(CythonTransform):
}
special_methods.update(unop_method_nodes)
+ valid_cython_submodules = {
+ 'cimports',
+ 'dataclasses',
+ 'operator',
+ 'parallel',
+ 'view',
+ }
+
valid_parallel_directives = {
"parallel",
"prange",
@@ -872,6 +895,34 @@ class InterpretCompilerDirectives(CythonTransform):
error(pos, "Invalid directive: '%s'." % (directive,))
return True
+ def _check_valid_cython_module(self, pos, module_name):
+ if not module_name.startswith("cython."):
+ return
+ if module_name.split('.', 2)[1] in self.valid_cython_submodules:
+ return
+
+ extra = ""
+ # This is very rarely used, so don't waste space on static tuples.
+ hints = [
+ line.split() for line in """\
+ imp cimports
+ cimp cimports
+ para parallel
+ parra parallel
+ dataclass dataclasses
+ """.splitlines()[:-1]
+ ]
+ for wrong, correct in hints:
+ if module_name.startswith("cython." + wrong):
+ extra = "Did you mean 'cython.%s' ?" % correct
+ break
+
+ error(pos, "'%s' is not a valid cython.* module%s%s" % (
+ module_name,
+ ". " if extra else "",
+ extra,
+ ))
+
# Set up processing and handle the cython: comments.
def visit_ModuleNode(self, node):
for key in sorted(node.directive_comments):
@@ -942,6 +993,9 @@ class InterpretCompilerDirectives(CythonTransform):
elif module_name.startswith(u"cython."):
if module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module")
+ else:
+ self._check_valid_cython_module(node.pos, module_name)
+
if module_name == u"cython.parallel":
if node.as_name and node.as_name != u"cython":
self.parallel_directives[node.as_name] = module_name
@@ -968,10 +1022,10 @@ class InterpretCompilerDirectives(CythonTransform):
node.pos, module_name, node.relative_level, node.imported_names)
elif not node.relative_level and (
module_name == u"cython" or module_name.startswith(u"cython.")):
+ self._check_valid_cython_module(node.pos, module_name)
submodule = (module_name + u".")[7:]
newimp = []
-
- for pos, name, as_name, kind in node.imported_names:
+ for pos, name, as_name in node.imported_names:
full_name = submodule + name
qualified_name = u"cython." + full_name
if self.is_parallel_directive(qualified_name, node.pos):
@@ -980,15 +1034,12 @@ class InterpretCompilerDirectives(CythonTransform):
self.parallel_directives[as_name or name] = qualified_name
elif self.is_cython_directive(full_name):
self.directive_names[as_name or name] = full_name
- if kind is not None:
- self.context.nonfatal_error(PostParseError(pos,
- "Compiler directive imports must be plain imports"))
elif full_name in ['dataclasses', 'typing']:
self.directive_names[as_name or name] = full_name
# unlike many directives, still treat it as a regular module
- newimp.append((pos, name, as_name, kind))
+ newimp.append((pos, name, as_name))
else:
- newimp.append((pos, name, as_name, kind))
+ newimp.append((pos, name, as_name))
if not newimp:
return None
@@ -1003,10 +1054,11 @@ class InterpretCompilerDirectives(CythonTransform):
imported_names = []
for name, name_node in node.items:
imported_names.append(
- (name_node.pos, name, None if name == name_node.name else name_node.name, None))
+ (name_node.pos, name, None if name == name_node.name else name_node.name))
return self._create_cimport_from_import(
node.pos, module_name, import_node.level, imported_names)
elif module_name == u"cython" or module_name.startswith(u"cython."):
+ self._check_valid_cython_module(import_node.module_name.pos, module_name)
submodule = (module_name + u".")[7:]
newimp = []
for name, name_node in node.items:
@@ -1041,14 +1093,13 @@ class InterpretCompilerDirectives(CythonTransform):
module_name=dotted_name,
as_name=as_name,
is_absolute=level == 0)
- for pos, dotted_name, as_name, _ in imported_names
+ for pos, dotted_name, as_name in imported_names
]
def visit_SingleAssignmentNode(self, node):
if isinstance(node.rhs, ExprNodes.ImportNode):
module_name = node.rhs.module_name.value
- is_special_module = (module_name + u".").startswith((u"cython.parallel.", u"cython.cimports."))
- if module_name != u"cython" and not is_special_module:
+ if module_name != u"cython" and not module_name.startswith("cython."):
return node
node = Nodes.CImportStatNode(node.pos, module_name=module_name, as_name=node.lhs.name)
@@ -1197,7 +1248,7 @@ class InterpretCompilerDirectives(CythonTransform):
return (optname, directivetype(optname, str(args[0].value)))
elif directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS:
# signal to pass things on without processing
- return (optname, (args, kwds.as_python_dict()))
+ return (optname, (args, kwds.as_python_dict() if kwds else {}))
else:
assert False
@@ -1290,8 +1341,7 @@ class InterpretCompilerDirectives(CythonTransform):
name, value = directive
if self.directives.get(name, object()) != value:
directives.append(directive)
- if (directive[0] == 'staticmethod' or
- (directive[0] == 'dataclasses.dataclass' and scope_name == 'class')):
+ if directive[0] == 'staticmethod':
both.append(dec)
# Adapt scope type based on decorators that change it.
if directive[0] == 'cclass' and scope_name == 'class':
@@ -1301,10 +1351,11 @@ class InterpretCompilerDirectives(CythonTransform):
if realdecs and (scope_name == 'cclass' or
isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))):
for realdec in realdecs:
+ dec_pos = realdec.pos
realdec = realdec.decorator
if ((realdec.is_name and realdec.name == "dataclass") or
(realdec.is_attribute and realdec.attribute == "dataclass")):
- error(realdec.pos,
+ error(dec_pos,
"Use '@cython.dataclasses.dataclass' on cdef classes to create a dataclass")
# Note - arbitrary C function decorators are caught later in DecoratorTransform
raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
@@ -1602,6 +1653,128 @@ class WithTransform(VisitorTransform, SkipDeclarations):
visit_Node = VisitorTransform.recurse_to_children
+class _GeneratorExpressionArgumentsMarker(TreeVisitor, SkipDeclarations):
+ # called from "MarkClosureVisitor"
+ def __init__(self, gen_expr):
+ super(_GeneratorExpressionArgumentsMarker, self).__init__()
+ self.gen_expr = gen_expr
+
+ def visit_ExprNode(self, node):
+ if not node.is_literal:
+ # Don't bother tagging literal nodes
+ assert (not node.generator_arg_tag) # nobody has tagged this first
+ node.generator_arg_tag = self.gen_expr
+ self.visitchildren(node)
+
+ def visit_Node(self, node):
+ # We're only interested in the expressions that make up the iterator sequence,
+ # so don't go beyond ExprNodes (e.g. into ForFromStatNode).
+ return
+
+ def visit_GeneratorExpressionNode(self, node):
+ node.generator_arg_tag = self.gen_expr
+ # don't visit children, can't handle overlapping tags
+ # (and assume generator expressions don't end up optimized out in a way
+ # that would require overlapping tags)
+
+
+class _HandleGeneratorArguments(VisitorTransform, SkipDeclarations):
+ # used from within CreateClosureClasses
+
+ def __call__(self, node):
+ from . import Visitor
+ assert isinstance(node, ExprNodes.GeneratorExpressionNode)
+ self.gen_node = node
+
+ self.args = list(node.def_node.args)
+ self.call_parameters = list(node.call_parameters)
+ self.tag_count = 0
+ self.substitutions = {}
+
+ self.visitchildren(node)
+
+ for k, v in self.substitutions.items():
+ # doing another search for replacements here (at the end) allows us to sweep up
+ # CloneNodes too (which are often generated by the optimizer)
+ # (it could arguably be done more efficiently with a single traversal though)
+ Visitor.recursively_replace_node(node, k, v)
+
+ node.def_node.args = self.args
+ node.call_parameters = self.call_parameters
+ return node
+
+ def visit_GeneratorExpressionNode(self, node):
+ # a generator can also be substituted itself, so handle that case
+ new_node = self._handle_ExprNode(node, do_visit_children=False)
+ # However do not traverse into it. A new _HandleGeneratorArguments visitor will be used
+ # elsewhere to do that.
+ return node
+
+ def _handle_ExprNode(self, node, do_visit_children):
+ if (node.generator_arg_tag is not None and self.gen_node is not None and
+ self.gen_node == node.generator_arg_tag):
+ pos = node.pos
+ # The reason for using ".x" as the name is that this is how CPython
+ # tracks internal variables in loops (e.g.
+ # { locals() for v in range(10) }
+ # will produce "v" and ".0"). We don't replicate this behaviour completely
+ # but use it as a starting point
+ name_source = self.tag_count
+ self.tag_count += 1
+ name = EncodedString(".{0}".format(name_source))
+ def_node = self.gen_node.def_node
+ if not def_node.local_scope.lookup_here(name):
+ from . import Symtab
+ cname = EncodedString(Naming.genexpr_arg_prefix + Symtab.punycodify_name(str(name_source)))
+ name_decl = Nodes.CNameDeclaratorNode(pos=pos, name=name)
+ type = node.type
+ if type.is_reference and not type.is_fake_reference:
+ # It isn't obvious whether the right thing to do would be to capture by reference or by
+ # value (C++ itself doesn't know either for lambda functions and forces a choice).
+ # However, capture by reference involves converting to FakeReference which would require
+ # re-analysing AttributeNodes. Therefore I've picked capture-by-value out of convenience
+ # TODO - could probably be optimized by making the arg a reference but the closure not
+ # (see https://github.com/cython/cython/issues/2468)
+ type = type.ref_base_type
+
+ name_decl.type = type
+ new_arg = Nodes.CArgDeclNode(pos=pos, declarator=name_decl,
+ base_type=None, default=None, annotation=None)
+ new_arg.name = name_decl.name
+ new_arg.type = type
+
+ self.args.append(new_arg)
+ node.generator_arg_tag = None # avoid the possibility of this being caught again
+ self.call_parameters.append(node)
+ new_arg.entry = def_node.declare_argument(def_node.local_scope, new_arg)
+ new_arg.entry.cname = cname
+ new_arg.entry.in_closure = True
+
+ if do_visit_children:
+ # now visit the Nodes's children (but remove self.gen_node to not to further
+ # argument substitution)
+ gen_node, self.gen_node = self.gen_node, None
+ self.visitchildren(node)
+ self.gen_node = gen_node
+
+ # replace the node inside the generator with a looked-up name
+ # (initialized_check can safely be False because the source variable will be checked
+ # before it is captured if the check is required)
+ name_node = ExprNodes.NameNode(pos, name=name, initialized_check=False)
+ name_node.entry = self.gen_node.def_node.gbody.local_scope.lookup(name_node.name)
+ name_node.type = name_node.entry.type
+ self.substitutions[node] = name_node
+ return name_node
+ if do_visit_children:
+ self.visitchildren(node)
+ return node
+
+ def visit_ExprNode(self, node):
+ return self._handle_ExprNode(node, True)
+
+ visit_Node = VisitorTransform.recurse_to_children
+
+
class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
"""
Transforms method decorators in cdef classes into nested calls or properties.
@@ -2057,22 +2230,10 @@ if VALUE is not None:
if not e.type.is_pyobject:
e.type.create_to_py_utility_code(env)
e.type.create_from_py_utility_code(env)
- all_members_names = sorted([e.name for e in all_members])
-
- # Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons.
- # SHA-256 should be ok for years to come, but early Cython 3.0 alpha releases used SHA-1,
- # which may not be.
- checksum_algos = [hashlib.sha256, hashlib.sha1]
- try:
- checksum_algos.append(hashlib.md5)
- except AttributeError:
- pass
- member_names_string = ' '.join(all_members_names).encode('utf-8')
- checksums = [
- '0x' + mkchecksum(member_names_string).hexdigest()[:7]
- for mkchecksum in checksum_algos
- ]
+ all_members_names = [e.name for e in all_members]
+ checksums = _calculate_pickle_checksums(all_members_names)
+
unpickle_func_name = '__pyx_unpickle_%s' % node.punycode_class_name
# TODO(robertwb): Move the state into the third argument
@@ -2315,11 +2476,17 @@ if VALUE is not None:
assmt.analyse_declarations(env)
return assmt
+ def visit_func_outer_attrs(self, node):
+ # any names in the outer attrs should not be looked up in the function "seen_vars_stack"
+ stack = self.seen_vars_stack.pop()
+ super(AnalyseDeclarationsTransform, self).visit_func_outer_attrs(node)
+ self.seen_vars_stack.append(stack)
+
def visit_ScopedExprNode(self, node):
env = self.current_env()
node.analyse_declarations(env)
# the node may or may not have a local scope
- if node.has_local_scope:
+ if node.expr_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
self.enter_scope(node, node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope)
@@ -2327,6 +2494,7 @@ if VALUE is not None:
self.exit_scope()
self.seen_vars_stack.pop()
else:
+
node.analyse_scoped_declarations(env)
self.visitchildren(node)
return node
@@ -2483,6 +2651,24 @@ if VALUE is not None:
return node
+def _calculate_pickle_checksums(member_names):
+ # Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons.
+ # SHA-256 should be ok for years to come, but early Cython 3.0 alpha releases used SHA-1,
+ # which may not be.
+ member_names_string = ' '.join(member_names).encode('utf-8')
+ hash_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
+ checksums = []
+ for algo_name in ['sha256', 'sha1', 'md5']:
+ try:
+ mkchecksum = getattr(hashlib, algo_name)
+ checksum = mkchecksum(member_names_string, **hash_kwargs).hexdigest()
+ except (AttributeError, ValueError):
+ # The algorithm (i.e. MD5) might not be there at all, or might be blocked at runtime.
+ continue
+ checksums.append('0x' + checksum[:7])
+ return checksums
+
+
class CalculateQualifiedNamesTransform(EnvTransform):
"""
Calculate and store the '__qualname__' and the global
@@ -2874,8 +3060,7 @@ class RemoveUnreachableCode(CythonTransform):
if not self.current_directives['remove_unreachable']:
return node
self.visitchildren(node)
- for idx, stat in enumerate(node.stats):
- idx += 1
+ for idx, stat in enumerate(node.stats, 1):
if stat.is_terminator:
if idx < len(node.stats):
if self.current_directives['warn.unreachable']:
@@ -2974,6 +3159,8 @@ class YieldNodeCollector(TreeVisitor):
class MarkClosureVisitor(CythonTransform):
+ # In addition to marking closures this is also responsible to finding parts of the
+ # generator iterable and marking them
def visit_ModuleNode(self, node):
self.needs_closure = False
@@ -3044,6 +3231,19 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True
return node
+ def visit_GeneratorExpressionNode(self, node):
+ node = self.visit_LambdaNode(node)
+ if not isinstance(node.loop, Nodes._ForInStatNode):
+ # Possibly should handle ForFromStatNode
+ # but for now do nothing
+ return node
+ itseq = node.loop.iterator.sequence
+ # literals do not need replacing with an argument
+ if itseq.is_literal:
+ return node
+ _GeneratorExpressionArgumentsMarker(node).visit(itseq)
+ return node
+
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
@@ -3188,6 +3388,10 @@ class CreateClosureClasses(CythonTransform):
self.visitchildren(node)
return node
+ def visit_GeneratorExpressionNode(self, node):
+ node = _HandleGeneratorArguments()(node)
+ return self.visit_LambdaNode(node)
+
class InjectGilHandling(VisitorTransform, SkipDeclarations):
"""