summary refs log tree commit diff
path: root/Cython/Compiler/Dataclass.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/Compiler/Dataclass.py')
-rw-r--r-- Cython/Compiler/Dataclass.py 328
1 files changed, 193 insertions, 135 deletions
diff --git a/Cython/Compiler/Dataclass.py b/Cython/Compiler/Dataclass.py
index 48c1888d6..7cbbab954 100644
--- a/Cython/Compiler/Dataclass.py
+++ b/Cython/Compiler/Dataclass.py
@@ -81,6 +81,59 @@ class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
return node
+class TemplateCode(object):
+ _placeholder_count = 0
+
+ def __init__(self):
+ self.code_lines = []
+ self.placeholders = {}
+ self.extra_stats = []
+
+ def insertion_point(self):
+ return len(self.code_lines)
+
+ def insert_code_line(self, insertion_point, code_line):
+ self.code_lines.insert(insertion_point, code_line)
+
+ def reset(self, insertion_point=0):
+ del self.code_lines[insertion_point:]
+
+ def add_code_line(self, code_line):
+ self.code_lines.append(code_line)
+
+ def add_code_lines(self, code_lines):
+ self.code_lines.extend(code_lines)
+
+ def new_placeholder(self, field_names, value):
+ name = self._new_placeholder_name(field_names)
+ self.placeholders[name] = value
+ return name
+
+ def add_extra_statements(self, statements):
+ self.extra_stats.extend(statements)
+
+ def _new_placeholder_name(self, field_names):
+ while True:
+ name = "INIT_PLACEHOLDER_%d" % self._placeholder_count
+ if (name not in self.placeholders
+ and name not in field_names):
+ # make sure name isn't already used and doesn't
+ # conflict with a variable name (which is unlikely but possible)
+ break
+ self._placeholder_count += 1
+ return name
+
+ def generate_tree(self, level='c_class'):
+ stat_list_node = TreeFragment(
+ "\n".join(self.code_lines),
+ level=level,
+ pipeline=[NormalizeTree(None)],
+ ).substitute(self.placeholders)
+
+ stat_list_node.stats += self.extra_stats
+ return stat_list_node
+
+
class _MISSING_TYPE(object):
pass
MISSING = _MISSING_TYPE()
@@ -147,33 +200,39 @@ def process_class_get_fields(node):
transform(node)
default_value_assignments = transform.removed_assignments
- if node.base_type and node.base_type.dataclass_fields:
- fields = node.base_type.dataclass_fields.copy()
- else:
- fields = OrderedDict()
+ base_type = node.base_type
+ fields = OrderedDict()
+ while base_type:
+ if base_type.is_external or not base_type.scope.implemented:
+ warning(node.pos, "Cannot reliably handle Cython dataclasses with base types "
+ "in external modules since it is not possible to tell what fields they have", 2)
+ if base_type.dataclass_fields:
+ fields = base_type.dataclass_fields.copy()
+ break
+ base_type = base_type.base_type
for entry in var_entries:
name = entry.name
- is_initvar = (entry.type.python_type_constructor_name == "dataclasses.InitVar")
+ is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar")
# TODO - classvars aren't included in "var_entries" so are missed here
# and thus this code is never triggered
- is_classvar = (entry.type.python_type_constructor_name == "typing.ClassVar")
- if is_initvar or is_classvar:
- entry.type = entry.type.resolve() # no longer need the special type
+ is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar")
if name in default_value_assignments:
assignment = default_value_assignments[name]
if (isinstance(assignment, ExprNodes.CallNode)
and assignment.function.as_cython_attribute() == "dataclasses.field"):
# I believe most of this is well-enforced when it's treated as a directive
# but it doesn't hurt to make sure
- if (not isinstance(assignment, ExprNodes.GeneralCallNode)
- or not isinstance(assignment.positional_args, ExprNodes.TupleNode)
- or assignment.positional_args.args
- or not isinstance(assignment.keyword_args, ExprNodes.DictNode)):
+ valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode)
+ and isinstance(assignment.positional_args, ExprNodes.TupleNode)
+ and not assignment.positional_args.args
+ and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode)))
+ valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args)
+ if not (valid_general_call or valid_simple_call):
error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist "
"of compile-time keyword arguments")
continue
- keyword_args = assignment.keyword_args.as_python_dict()
+ keyword_args = assignment.keyword_args.as_python_dict() if valid_general_call and assignment.keyword_args else {}
if 'default' in keyword_args and 'default_factory' in keyword_args:
error(assignment.pos, "cannot specify both default and default_factory")
continue
@@ -218,7 +277,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
if not isinstance(v, ExprNodes.BoolNode):
error(node.pos,
"Arguments passed to cython.dataclasses.dataclass must be True or False")
- kwargs[k] = v
+ kwargs[k] = v.value
# remove everything that does not belong into _DataclassParams()
kw_only = kwargs.pop("kw_only")
@@ -251,23 +310,14 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
stats = Nodes.StatListNode(node.pos,
stats=[dataclass_params_assignment] + dataclass_fields_stats)
- code_lines = []
- placeholders = {}
- extra_stats = []
- for cl, ph, es in [ generate_init_code(kwargs['init'], node, fields, kw_only),
- generate_repr_code(kwargs['repr'], node, fields),
- generate_eq_code(kwargs['eq'], node, fields),
- generate_order_code(kwargs['order'], node, fields),
- generate_hash_code(kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields) ]:
- code_lines.append(cl)
- placeholders.update(ph)
- extra_stats.extend(extra_stats)
-
- code_lines = "\n".join(code_lines)
- code_tree = TreeFragment(code_lines, level='c_class', pipeline=[NormalizeTree(node.scope)]
- ).substitute(placeholders)
-
- stats.stats += (code_tree.stats + extra_stats)
+ code = TemplateCode()
+ generate_init_code(code, kwargs['init'], node, fields, kw_only)
+ generate_repr_code(code, kwargs['repr'], node, fields)
+ generate_eq_code(code, kwargs['eq'], node, fields)
+ generate_order_code(code, kwargs['order'], node, fields)
+ generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields)
+
+ stats.stats += code.generate_tree().stats
# turn off annotation typing, so all arguments to __init__ are accepted as
# generic objects and thus can accept _HAS_DEFAULT_FACTORY.
@@ -285,14 +335,8 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
node.body.stats.append(comp_directives)
-def generate_init_code(init, node, fields, kw_only):
+def generate_init_code(code, init, node, fields, kw_only):
"""
- All of these "generate_*_code" functions return a tuple of:
- - code string
- - placeholder dict (often empty)
- - stat list (often empty)
- which can then be combined later and processed once.
-
Notes on CPython generated "__init__":
* Implemented in `_init_fn`.
* The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
@@ -304,9 +348,15 @@ def generate_init_code(init, node, fields, kw_only):
* seen_default and the associated error message are copied directly from Python
* Call to user-defined __post_init__ function (if it exists) is copied from
CPython.
+
+ Cython behaviour deviates a little here (to be decided if this is right...):
+ because the class variable from the assignment does not exist, Cython fields will
+ return None (or whatever their type default is) if not initialized, while Python
+ dataclasses will fall back to looking up the class variable.
"""
if not init or node.scope.lookup_here("__init__"):
- return "", {}, []
+ return
+
# selfname behaviour copied from the cpython module
selfname = "__dataclass_self__" if "self" in fields else "self"
args = [selfname]
@@ -314,8 +364,7 @@ def generate_init_code(init, node, fields, kw_only):
if kw_only:
args.append("*")
- placeholders = {}
- placeholder_count = [0]
+ function_start_point = code.insertion_point()
# create a temp to get _HAS_DEFAULT_FACTORY
dataclass_module = make_dataclasses_module_callnode(node.pos)
@@ -325,26 +374,10 @@ def generate_init_code(init, node, fields, kw_only):
attribute=EncodedString("_HAS_DEFAULT_FACTORY")
)
- def get_placeholder_name():
- while True:
- name = "INIT_PLACEHOLDER_%d" % placeholder_count[0]
- if (name not in placeholders
- and name not in fields):
- # make sure name isn't already used and doesn't
- # conflict with a variable name (which is unlikely but possible)
- break
- placeholder_count[0] += 1
- return name
-
- default_factory_placeholder = get_placeholder_name()
- placeholders[default_factory_placeholder] = has_default_factory
-
- function_body_code_lines = []
+ default_factory_placeholder = code.new_placeholder(fields, has_default_factory)
seen_default = False
for name, field in fields.items():
- if not field.init.value:
- continue
entry = node.scope.lookup(name)
if entry.annotation:
annotation = u": %s" % entry.annotation.string.value
@@ -356,50 +389,53 @@ def generate_init_code(init, node, fields, kw_only):
if field.default_factory is not MISSING:
ph_name = default_factory_placeholder
else:
- ph_name = get_placeholder_name()
- placeholders[ph_name] = field.default # should be a node
+ ph_name = code.new_placeholder(fields, field.default) # 'default' should be a node
assignment = u" = %s" % ph_name
- elif seen_default and not kw_only:
+ elif seen_default and not kw_only and field.init.value:
error(entry.pos, ("non-default argument '%s' follows default argument "
"in dataclass __init__") % name)
- return "", {}, []
+ code.reset(function_start_point)
+ return
- args.append(u"%s%s%s" % (name, annotation, assignment))
+ if field.init.value:
+ args.append(u"%s%s%s" % (name, annotation, assignment))
if field.is_initvar:
continue
elif field.default_factory is MISSING:
if field.init.value:
- function_body_code_lines.append(u" %s.%s = %s" % (selfname, name, name))
+ code.add_code_line(u" %s.%s = %s" % (selfname, name, name))
+ elif assignment:
+ # not an argument to the function, but is still initialized
+ code.add_code_line(u" %s.%s%s" % (selfname, name, assignment))
else:
- ph_name = get_placeholder_name()
- placeholders[ph_name] = field.default_factory
+ ph_name = code.new_placeholder(fields, field.default_factory)
if field.init.value:
# close to:
# def __init__(self, name=_PLACEHOLDER_VALUE):
# self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
- function_body_code_lines.append(u" %s.%s = %s() if %s is %s else %s" % (
+ code.add_code_line(u" %s.%s = %s() if %s is %s else %s" % (
selfname, name, ph_name, name, default_factory_placeholder, name))
else:
# still need to use the default factory to initialize
- function_body_code_lines.append(u" %s.%s = %s()"
- % (selfname, name, ph_name))
-
- args = u", ".join(args)
- func_def = u"def __init__(%s):" % args
-
- code_lines = [func_def] + (function_body_code_lines or ["pass"])
+ code.add_code_line(u" %s.%s = %s()" % (
+ selfname, name, ph_name))
if node.scope.lookup("__post_init__"):
post_init_vars = ", ".join(name for name, field in fields.items()
if field.is_initvar)
- code_lines.append(" %s.__post_init__(%s)" % (selfname, post_init_vars))
- return u"\n".join(code_lines), placeholders, []
+ code.add_code_line(" %s.__post_init__(%s)" % (selfname, post_init_vars))
+ if function_start_point == code.insertion_point():
+ code.add_code_line(" pass")
-def generate_repr_code(repr, node, fields):
+ args = u", ".join(args)
+ code.insert_code_line(function_start_point, u"def __init__(%s):" % args)
+
+
+def generate_repr_code(code, repr, node, fields):
"""
- The CPython implementation is just:
+ The core of the CPython implementation is just:
['return self.__class__.__qualname__ + f"(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
@@ -407,38 +443,65 @@ def generate_repr_code(repr, node, fields):
The only notable difference here is self.__class__.__qualname__ -> type(self).__name__
which is because Cython currently supports Python 2.
+
+ However, it also has some guards for recursive repr invocations. In the standard
+ library implementation they're done with a wrapper decorator that captures a set
+ (with the set keyed by id and thread). Here we create a set as a thread local
+ variable and key only by id.
"""
if not repr or node.scope.lookup("__repr__"):
- return "", {}, []
- code_lines = ["def __repr__(self):"]
+ return
+
+ # The recursive guard is likely a little costly, so skip it if possible.
+ # is_gc_simple defines whether it can contain recursive objects
+ needs_recursive_guard = False
+ for name in fields.keys():
+ entry = node.scope.lookup(name)
+ type_ = entry.type
+ if type_.is_memoryviewslice:
+ type_ = type_.dtype
+ if not type_.is_pyobject:
+ continue # no GC
+ if not type_.is_gc_simple:
+ needs_recursive_guard = True
+ break
+
+ if needs_recursive_guard:
+ code.add_code_line("__pyx_recursive_repr_guard = __import__('threading').local()")
+ code.add_code_line("__pyx_recursive_repr_guard.running = set()")
+ code.add_code_line("def __repr__(self):")
+ if needs_recursive_guard:
+ code.add_code_line(" key = id(self)")
+ code.add_code_line(" guard_set = self.__pyx_recursive_repr_guard.running")
+ code.add_code_line(" if key in guard_set: return '...'")
+ code.add_code_line(" guard_set.add(key)")
+ code.add_code_line(" try:")
strs = [u"%s={self.%s!r}" % (name, name)
for name, field in fields.items()
if field.repr.value and not field.is_initvar]
format_string = u", ".join(strs)
- code_lines.append(u' name = getattr(type(self), "__qualname__", type(self).__name__)')
- code_lines.append(u" return f'{name}(%s)'" % format_string)
- code_lines = u"\n".join(code_lines)
- return code_lines, {}, []
+ code.add_code_line(u' name = getattr(type(self), "__qualname__", type(self).__name__)')
+ code.add_code_line(u" return f'{name}(%s)'" % format_string)
+ if needs_recursive_guard:
+ code.add_code_line(" finally:")
+ code.add_code_line(" guard_set.remove(key)")
-def generate_cmp_code(op, funcname, node, fields):
+def generate_cmp_code(code, op, funcname, node, fields):
if node.scope.lookup_here(funcname):
- return "", {}, []
+ return
names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]
- if not names:
- return "", {}, [] # no comparable types
-
- code_lines = [
+ code.add_code_lines([
"def %s(self, other):" % funcname,
+ " if not isinstance(other, %s):" % node.class_name,
+ " return NotImplemented",
+ #
" cdef %s other_cast" % node.class_name,
- " if isinstance(other, %s):" % node.class_name,
- " other_cast = <%s>other" % node.class_name,
- " else:",
- " return NotImplemented"
- ]
+ " other_cast = <%s>other" % node.class_name,
+ ])
# The Python implementation of dataclasses.py does a tuple comparison
# (roughly):
@@ -456,42 +519,32 @@ def generate_cmp_code(op, funcname, node, fields):
name, op, name))
if checks:
- code_lines.append(" return " + " and ".join(checks))
+ code.add_code_line(" return " + " and ".join(checks))
else:
if "=" in op:
- code_lines.append(" return True") # "() == ()" is True
+ code.add_code_line(" return True") # "() == ()" is True
else:
- code_lines.append(" return False")
+ code.add_code_line(" return False")
- code_lines = u"\n".join(code_lines)
- return code_lines, {}, []
-
-
-def generate_eq_code(eq, node, fields):
+def generate_eq_code(code, eq, node, fields):
if not eq:
- return code_lines, {}, []
- return generate_cmp_code("==", "__eq__", node, fields)
+ return
+ generate_cmp_code(code, "==", "__eq__", node, fields)
-def generate_order_code(order, node, fields):
+def generate_order_code(code, order, node, fields):
if not order:
- return "", {}, []
- code_lines = []
- placeholders = {}
- stats = []
+ return
+
for op, name in [("<", "__lt__"),
("<=", "__le__"),
(">", "__gt__"),
(">=", "__ge__")]:
- res = generate_cmp_code(op, name, node, fields)
- code_lines.append(res[0])
- placeholders.update(res[1])
- stats.extend(res[2])
- return "\n".join(code_lines), placeholders, stats
+ generate_cmp_code(code, op, name, node, fields)
-def generate_hash_code(unsafe_hash, eq, frozen, node, fields):
+def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
"""
Copied from CPython implementation - the intention is to follow this as far as
is possible:
@@ -536,35 +589,37 @@ def generate_hash_code(unsafe_hash, eq, frozen, node, fields):
if unsafe_hash:
# error message taken from CPython dataclasses module
error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
- return "", {}, []
+ return
+
if not unsafe_hash:
if not eq:
return
if not frozen:
- return "", {}, [Nodes.SingleAssignmentNode(
- node.pos,
- lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
- rhs=ExprNodes.NoneNode(node.pos),
- )]
+ code.add_extra_statements([
+ Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
+ rhs=ExprNodes.NoneNode(node.pos),
+ )
+ ])
+ return
names = [
name for name, field in fields.items()
- if (not field.is_initvar and
- (field.compare.value if field.hash.value is None else field.hash.value))
+ if not field.is_initvar and (
+ field.compare.value if field.hash.value is None else field.hash.value)
]
- if not names:
- return "", {}, [] # nothing to hash
# make a tuple of the hashes
- tpl = u", ".join(u"hash(self.%s)" % name for name in names )
+ hash_tuple_items = u", ".join(u"self.%s" % name for name in names)
+ if hash_tuple_items:
+ hash_tuple_items += u"," # ensure that one arg form is a tuple
# if we're here we want to generate a hash
- code_lines = dedent(u"""\
- def __hash__(self):
- return hash((%s))
- """) % tpl
-
- return code_lines, {}, []
+ code.add_code_lines([
+ "def __hash__(self):",
+ " return hash((%s))" % hash_tuple_items,
+ ])
def get_field_type(pos, entry):
@@ -666,8 +721,11 @@ def _set_up_dataclass_fields(node, fields, dataclass_module):
name)
# create an entry in the global scope for this variable to live
field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name))
- field_node.entry = global_scope.declare_var(field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
- pos=field_default.pos, cname=field_node.name, is_cdef=1)
+ field_node.entry = global_scope.declare_var(
+ field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
+ pos=field_default.pos, cname=field_node.name, is_cdef=True,
+ # TODO: do we need to set 'pytyping_modifiers' here?
+ )
# replace the field so that future users just receive the namenode
setattr(field, attrname, field_node)