diff options
Diffstat (limited to 'Cython/Compiler/Dataclass.py')
-rw-r--r-- | Cython/Compiler/Dataclass.py | 328 |
1 files changed, 193 insertions, 135 deletions
diff --git a/Cython/Compiler/Dataclass.py b/Cython/Compiler/Dataclass.py index 48c1888d6..7cbbab954 100644 --- a/Cython/Compiler/Dataclass.py +++ b/Cython/Compiler/Dataclass.py @@ -81,6 +81,59 @@ class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations): return node +class TemplateCode(object): + _placeholder_count = 0 + + def __init__(self): + self.code_lines = [] + self.placeholders = {} + self.extra_stats = [] + + def insertion_point(self): + return len(self.code_lines) + + def insert_code_line(self, insertion_point, code_line): + self.code_lines.insert(insertion_point, code_line) + + def reset(self, insertion_point=0): + del self.code_lines[insertion_point:] + + def add_code_line(self, code_line): + self.code_lines.append(code_line) + + def add_code_lines(self, code_lines): + self.code_lines.extend(code_lines) + + def new_placeholder(self, field_names, value): + name = self._new_placeholder_name(field_names) + self.placeholders[name] = value + return name + + def add_extra_statements(self, statements): + self.extra_stats.extend(statements) + + def _new_placeholder_name(self, field_names): + while True: + name = "INIT_PLACEHOLDER_%d" % self._placeholder_count + if (name not in self.placeholders + and name not in field_names): + # make sure name isn't already used and doesn't + # conflict with a variable name (which is unlikely but possible) + break + self._placeholder_count += 1 + return name + + def generate_tree(self, level='c_class'): + stat_list_node = TreeFragment( + "\n".join(self.code_lines), + level=level, + pipeline=[NormalizeTree(None)], + ).substitute(self.placeholders) + + stat_list_node.stats += self.extra_stats + return stat_list_node + + class _MISSING_TYPE(object): pass MISSING = _MISSING_TYPE() @@ -147,33 +200,39 @@ def process_class_get_fields(node): transform(node) default_value_assignments = transform.removed_assignments - if node.base_type and node.base_type.dataclass_fields: - fields = node.base_type.dataclass_fields.copy() - else: - fields = OrderedDict() + base_type = node.base_type + fields = OrderedDict() + while base_type: + if base_type.is_external or not base_type.scope.implemented: + warning(node.pos, "Cannot reliably handle Cython dataclasses with base types " + "in external modules since it is not possible to tell what fields they have", 2) + if base_type.dataclass_fields: + fields = base_type.dataclass_fields.copy() + break + base_type = base_type.base_type for entry in var_entries: name = entry.name - is_initvar = (entry.type.python_type_constructor_name == "dataclasses.InitVar") + is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar") # TODO - classvars aren't included in "var_entries" so are missed here # and thus this code is never triggered - is_classvar = (entry.type.python_type_constructor_name == "typing.ClassVar") - if is_initvar or is_classvar: - entry.type = entry.type.resolve() # no longer need the special type + is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar") if name in default_value_assignments: assignment = default_value_assignments[name] if (isinstance(assignment, ExprNodes.CallNode) and assignment.function.as_cython_attribute() == "dataclasses.field"): # I believe most of this is well-enforced when it's treated as a directive # but it doesn't hurt to make sure - if (not isinstance(assignment, ExprNodes.GeneralCallNode) - or not isinstance(assignment.positional_args, ExprNodes.TupleNode) - or assignment.positional_args.args - or not isinstance(assignment.keyword_args, ExprNodes.DictNode)): + valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode) + and isinstance(assignment.positional_args, ExprNodes.TupleNode) + and not assignment.positional_args.args + and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode))) + valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args) + if not (valid_general_call or valid_simple_call): error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist " "of compile-time keyword arguments") continue - keyword_args = assignment.keyword_args.as_python_dict() + keyword_args = assignment.keyword_args.as_python_dict() if valid_general_call and assignment.keyword_args else {} if 'default' in keyword_args and 'default_factory' in keyword_args: error(assignment.pos, "cannot specify both default and default_factory") continue @@ -218,7 +277,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform): if not isinstance(v, ExprNodes.BoolNode): error(node.pos, "Arguments passed to cython.dataclasses.dataclass must be True or False") - kwargs[k] = v + kwargs[k] = v.value # remove everything that does not belong into _DataclassParams() kw_only = kwargs.pop("kw_only") @@ -251,23 +310,14 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform): stats = Nodes.StatListNode(node.pos, stats=[dataclass_params_assignment] + dataclass_fields_stats) - code_lines = [] - placeholders = {} - extra_stats = [] - for cl, ph, es in [ generate_init_code(kwargs['init'], node, fields, kw_only), - generate_repr_code(kwargs['repr'], node, fields), - generate_eq_code(kwargs['eq'], node, fields), - generate_order_code(kwargs['order'], node, fields), - generate_hash_code(kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields) ]: - code_lines.append(cl) - placeholders.update(ph) - extra_stats.extend(extra_stats) - - code_lines = "\n".join(code_lines) - code_tree = TreeFragment(code_lines, level='c_class', pipeline=[NormalizeTree(node.scope)] - ).substitute(placeholders) - - stats.stats += (code_tree.stats + extra_stats) + code = TemplateCode() + generate_init_code(code, kwargs['init'], node, fields, kw_only) + generate_repr_code(code, kwargs['repr'], node, fields) + generate_eq_code(code, kwargs['eq'], node, fields) + generate_order_code(code, kwargs['order'], node, fields) + generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields) + + stats.stats += code.generate_tree().stats # turn off annotation typing, so all arguments to __init__ are accepted as # generic objects and thus can accept _HAS_DEFAULT_FACTORY. @@ -285,14 +335,8 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform): node.body.stats.append(comp_directives) -def generate_init_code(init, node, fields, kw_only): +def generate_init_code(code, init, node, fields, kw_only): """ - All of these "generate_*_code" functions return a tuple of: - - code string - - placeholder dict (often empty) - - stat list (often empty) - which can then be combined later and processed once. - Notes on CPython generated "__init__": * Implemented in `_init_fn`. * The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as @@ -304,9 +348,15 @@ def generate_init_code(init, node, fields, kw_only): * seen_default and the associated error message are copied directly from Python * Call to user-defined __post_init__ function (if it exists) is copied from CPython. + + Cython behaviour deviates a little here (to be decided if this is right...) + Because the class variable from the assignment does not exist Cython fields will + return None (or whatever their type default is) if not initialized while Python + dataclasses will fall back to looking up the class variable. """ if not init or node.scope.lookup_here("__init__"): - return "", {}, [] + return + # selfname behaviour copied from the cpython module selfname = "__dataclass_self__" if "self" in fields else "self" args = [selfname] @@ -314,8 +364,7 @@ def generate_init_code(init, node, fields, kw_only): if kw_only: args.append("*") - placeholders = {} - placeholder_count = [0] + function_start_point = code.insertion_point() # create a temp to get _HAS_DEFAULT_FACTORY dataclass_module = make_dataclasses_module_callnode(node.pos) @@ -325,26 +374,10 @@ def generate_init_code(init, node, fields, kw_only): attribute=EncodedString("_HAS_DEFAULT_FACTORY") ) - def get_placeholder_name(): - while True: - name = "INIT_PLACEHOLDER_%d" % placeholder_count[0] - if (name not in placeholders - and name not in fields): - # make sure name isn't already used and doesn't - # conflict with a variable name (which is unlikely but possible) - break - placeholder_count[0] += 1 - return name - - default_factory_placeholder = get_placeholder_name() - placeholders[default_factory_placeholder] = has_default_factory - - function_body_code_lines = [] + default_factory_placeholder = code.new_placeholder(fields, has_default_factory) seen_default = False for name, field in fields.items(): - if not field.init.value: - continue entry = node.scope.lookup(name) if entry.annotation: annotation = u": %s" % entry.annotation.string.value @@ -356,50 +389,53 @@ def generate_init_code(init, node, fields, kw_only): if field.default_factory is not MISSING: ph_name = default_factory_placeholder else: - ph_name = get_placeholder_name() - placeholders[ph_name] = field.default # should be a node + ph_name = code.new_placeholder(fields, field.default) # 'default' should be a node assignment = u" = %s" % ph_name - elif seen_default and not kw_only: + elif seen_default and not kw_only and field.init.value: error(entry.pos, ("non-default argument '%s' follows default argument " "in dataclass __init__") % name) - return "", {}, [] + code.reset(function_start_point) + return - args.append(u"%s%s%s" % (name, annotation, assignment)) + if field.init.value: + args.append(u"%s%s%s" % (name, annotation, assignment)) if field.is_initvar: continue elif field.default_factory is MISSING: if field.init.value: - function_body_code_lines.append(u" %s.%s = %s" % (selfname, name, name)) + code.add_code_line(u" %s.%s = %s" % (selfname, name, name)) + elif assignment: + # not an argument to the function, but is still initialized + code.add_code_line(u" %s.%s%s" % (selfname, name, assignment)) else: - ph_name = get_placeholder_name() - placeholders[ph_name] = field.default_factory + ph_name = code.new_placeholder(fields, field.default_factory) if field.init.value: # close to: # def __init__(self, name=_PLACEHOLDER_VALUE): # self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name - function_body_code_lines.append(u" %s.%s = %s() if %s is %s else %s" % ( + code.add_code_line(u" %s.%s = %s() if %s is %s else %s" % ( selfname, name, ph_name, name, default_factory_placeholder, name)) else: # still need to use the default factory to initialize - function_body_code_lines.append(u" %s.%s = %s()" - % (selfname, name, ph_name)) - - args = u", ".join(args) - func_def = u"def __init__(%s):" % args - - code_lines = [func_def] + (function_body_code_lines or ["pass"]) + code.add_code_line(u" %s.%s = %s()" % ( + selfname, name, ph_name)) if node.scope.lookup("__post_init__"): post_init_vars = ", ".join(name for name, field in fields.items() if field.is_initvar) - code_lines.append(" %s.__post_init__(%s)" % (selfname, post_init_vars)) - return u"\n".join(code_lines), placeholders, [] + code.add_code_line(" %s.__post_init__(%s)" % (selfname, post_init_vars)) + if function_start_point == code.insertion_point(): + code.add_code_line(" pass") -def generate_repr_code(repr, node, fields): + args = u", ".join(args) + code.insert_code_line(function_start_point, u"def __init__(%s):" % args) + + +def generate_repr_code(code, repr, node, fields): """ - The CPython implementation is just: + The core of the CPython implementation is just: ['return self.__class__.__qualname__ + f"(' + ', '.join([f"{f.name}={{self.{f.name}!r}}" for f in fields]) + @@ -407,38 +443,65 @@ def generate_repr_code(repr, node, fields): The only notable difference here is self.__class__.__qualname__ -> type(self).__name__ which is because Cython currently supports Python 2. + + However, it also has some guards for recursive repr invokations. In the standard + library implementation they're done with a wrapper decorator that captures a set + (with the set keyed by id and thread). Here we create a set as a thread local + variable and key only by id. """ if not repr or node.scope.lookup("__repr__"): - return "", {}, [] - code_lines = ["def __repr__(self):"] + return + + # The recursive guard is likely a little costly, so skip it if possible. + # is_gc_simple defines where it can contain recursive objects + needs_recursive_guard = False + for name in fields.keys(): + entry = node.scope.lookup(name) + type_ = entry.type + if type_.is_memoryviewslice: + type_ = type_.dtype + if not type_.is_pyobject: + continue # no GC + if not type_.is_gc_simple: + needs_recursive_guard = True + break + + if needs_recursive_guard: + code.add_code_line("__pyx_recursive_repr_guard = __import__('threading').local()") + code.add_code_line("__pyx_recursive_repr_guard.running = set()") + code.add_code_line("def __repr__(self):") + if needs_recursive_guard: + code.add_code_line(" key = id(self)") + code.add_code_line(" guard_set = self.__pyx_recursive_repr_guard.running") + code.add_code_line(" if key in guard_set: return '...'") + code.add_code_line(" guard_set.add(key)") + code.add_code_line(" try:") strs = [u"%s={self.%s!r}" % (name, name) for name, field in fields.items() if field.repr.value and not field.is_initvar] format_string = u", ".join(strs) - code_lines.append(u' name = getattr(type(self), "__qualname__", type(self).__name__)') - code_lines.append(u" return f'{name}(%s)'" % format_string) - code_lines = u"\n".join(code_lines) - return code_lines, {}, [] + code.add_code_line(u' name = getattr(type(self), "__qualname__", type(self).__name__)') + code.add_code_line(u" return f'{name}(%s)'" % format_string) + if needs_recursive_guard: + code.add_code_line(" finally:") + code.add_code_line(" guard_set.remove(key)") -def generate_cmp_code(op, funcname, node, fields): +def generate_cmp_code(code, op, funcname, node, fields): if node.scope.lookup_here(funcname): - return "", {}, [] + return names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)] - if not names: - return "", {}, [] # no comparable types - - code_lines = [ + code.add_code_lines([ "def %s(self, other):" % funcname, + " if not isinstance(other, %s):" % node.class_name, + " return NotImplemented", + # " cdef %s other_cast" % node.class_name, - " if isinstance(other, %s):" % node.class_name, - " other_cast = <%s>other" % node.class_name, - " else:", - " return NotImplemented" - ] + " other_cast = <%s>other" % node.class_name, + ]) # The Python implementation of dataclasses.py does a tuple comparison # (roughly): @@ -456,42 +519,32 @@ def generate_cmp_code(op, funcname, node, fields): name, op, name)) if checks: - code_lines.append(" return " + " and ".join(checks)) + code.add_code_line(" return " + " and ".join(checks)) else: if "=" in op: - code_lines.append(" return True") # "() == ()" is True + code.add_code_line(" return True") # "() == ()" is True else: - code_lines.append(" return False") + code.add_code_line(" return False") - code_lines = u"\n".join(code_lines) - return code_lines, {}, [] - - -def generate_eq_code(eq, node, fields): +def generate_eq_code(code, eq, node, fields): if not eq: - return code_lines, {}, [] - return generate_cmp_code("==", "__eq__", node, fields) + return + generate_cmp_code(code, "==", "__eq__", node, fields) -def generate_order_code(order, node, fields): +def generate_order_code(code, order, node, fields): if not order: - return "", {}, [] - code_lines = [] - placeholders = {} - stats = [] + return + for op, name in [("<", "__lt__"), ("<=", "__le__"), (">", "__gt__"), (">=", "__ge__")]: - res = generate_cmp_code(op, name, node, fields) - code_lines.append(res[0]) - placeholders.update(res[1]) - stats.extend(res[2]) - return "\n".join(code_lines), placeholders, stats + generate_cmp_code(code, op, name, node, fields) -def generate_hash_code(unsafe_hash, eq, frozen, node, fields): +def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields): """ Copied from CPython implementation - the intention is to follow this as far as is possible: @@ -536,35 +589,37 @@ def generate_hash_code(unsafe_hash, eq, frozen, node, fields): if unsafe_hash: # error message taken from CPython dataclasses module error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name) - return "", {}, [] + return + if not unsafe_hash: if not eq: return if not frozen: - return "", {}, [Nodes.SingleAssignmentNode( - node.pos, - lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")), - rhs=ExprNodes.NoneNode(node.pos), - )] + code.add_extra_statements([ + Nodes.SingleAssignmentNode( + node.pos, + lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")), + rhs=ExprNodes.NoneNode(node.pos), + ) + ]) + return names = [ name for name, field in fields.items() - if (not field.is_initvar and - (field.compare.value if field.hash.value is None else field.hash.value)) + if not field.is_initvar and ( + field.compare.value if field.hash.value is None else field.hash.value) ] - if not names: - return "", {}, [] # nothing to hash # make a tuple of the hashes - tpl = u", ".join(u"hash(self.%s)" % name for name in names ) + hash_tuple_items = u", ".join(u"self.%s" % name for name in names) + if hash_tuple_items: + hash_tuple_items += u"," # ensure that one arg form is a tuple # if we're here we want to generate a hash - code_lines = dedent(u"""\ - def __hash__(self): - return hash((%s)) - """) % tpl - - return code_lines, {}, [] + code.add_code_lines([ + "def __hash__(self):", + " return hash((%s))" % hash_tuple_items, + ]) def get_field_type(pos, entry): @@ -666,8 +721,11 @@ def _set_up_dataclass_fields(node, fields, dataclass_module): name) # create an entry in the global scope for this variable to live field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name)) - field_node.entry = global_scope.declare_var(field_node.name, type=field_default.type or PyrexTypes.unspecified_type, - pos=field_default.pos, cname=field_node.name, is_cdef=1) + field_node.entry = global_scope.declare_var( + field_node.name, type=field_default.type or PyrexTypes.unspecified_type, + pos=field_default.pos, cname=field_node.name, is_cdef=True, + # TODO: do we need to set 'pytyping_modifiers' here? + ) # replace the field so that future users just receive the namenode setattr(field, attrname, field_node) |