summaryrefslogtreecommitdiff
path: root/Cython/Compiler/Dataclass.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/Compiler/Dataclass.py')
-rw-r--r--Cython/Compiler/Dataclass.py840
1 files changed, 840 insertions, 0 deletions
diff --git a/Cython/Compiler/Dataclass.py b/Cython/Compiler/Dataclass.py
new file mode 100644
index 000000000..e775e9182
--- /dev/null
+++ b/Cython/Compiler/Dataclass.py
@@ -0,0 +1,840 @@
+# functions to transform a c class into a dataclass
+
+from collections import OrderedDict
+from textwrap import dedent
+import operator
+
+from . import ExprNodes
+from . import Nodes
+from . import PyrexTypes
+from . import Builtin
+from . import Naming
+from .Errors import error, warning
+from .Code import UtilityCode, TempitaUtilityCode, PyxCodeWriter
+from .Visitor import VisitorTransform
+from .StringEncoding import EncodedString
+from .TreeFragment import TreeFragment
+from .ParseTreeTransforms import NormalizeTree, SkipDeclarations
+from .Options import copy_inherited_directives
+
+_dataclass_loader_utilitycode = None
+
+def make_dataclasses_module_callnode(pos):
+ global _dataclass_loader_utilitycode
+ if not _dataclass_loader_utilitycode:
+ python_utility_code = UtilityCode.load_cached("Dataclasses_fallback", "Dataclasses.py")
+ python_utility_code = EncodedString(python_utility_code.impl)
+ _dataclass_loader_utilitycode = TempitaUtilityCode.load(
+ "SpecificModuleLoader", "Dataclasses.c",
+ context={'cname': "dataclasses", 'py_code': python_utility_code.as_c_string_literal()})
+ return ExprNodes.PythonCapiCallNode(
+ pos, "__Pyx_Load_dataclasses_Module",
+ PyrexTypes.CFuncType(PyrexTypes.py_object_type, []),
+ utility_code=_dataclass_loader_utilitycode,
+ args=[],
+ )
+
+def make_dataclass_call_helper(pos, callable, kwds):
+ utility_code = UtilityCode.load_cached("DataclassesCallHelper", "Dataclasses.c")
+ func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("callable", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("kwds", PyrexTypes.py_object_type, None)
+ ],
+ )
+ return ExprNodes.PythonCapiCallNode(
+ pos,
+ function_name="__Pyx_DataclassesCallHelper",
+ func_type=func_type,
+ utility_code=utility_code,
+ args=[callable, kwds],
+ )
+
+
+class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
+ """
+ Cython (and Python) normally treats
+
+ class A:
+ x = 1
+
+ as generating a class attribute. However for dataclasses the `= 1` should be interpreted as
+ a default value to initialize an instance attribute with.
+ This transform therefore removes the `x=1` assignment so that the class attribute isn't
+ generated, while recording what it has removed so that it can be used in the initialization.
+ """
+ def __init__(self, names):
+ super(RemoveAssignmentsToNames, self).__init__()
+ self.names = names
+ self.removed_assignments = {}
+
+ def visit_CClassNode(self, node):
+ self.visitchildren(node)
+ return node
+
+ def visit_PyClassNode(self, node):
+ return node # go no further
+
+ def visit_FuncDefNode(self, node):
+ return node # go no further
+
+ def visit_SingleAssignmentNode(self, node):
+ if node.lhs.is_name and node.lhs.name in self.names:
+ if node.lhs.name in self.removed_assignments:
+ warning(node.pos, ("Multiple assignments for '%s' in dataclass; "
+ "using most recent") % node.lhs.name, 1)
+ self.removed_assignments[node.lhs.name] = node.rhs
+ return []
+ return node
+
+ # I believe cascaded assignment is always a syntax error with annotations
+ # so there's no need to define visit_CascadedAssignmentNode
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+ return node
+
+
+class TemplateCode(object):
+ """
+ Adds the ability to keep track of placeholder argument names to PyxCodeWriter.
+
+ Also adds extra_stats which are nodes bundled at the end when this
+ is converted to a tree.
+ """
+ _placeholder_count = 0
+
+ def __init__(self, writer=None, placeholders=None, extra_stats=None):
+ self.writer = PyxCodeWriter() if writer is None else writer
+ self.placeholders = {} if placeholders is None else placeholders
+ self.extra_stats = [] if extra_stats is None else extra_stats
+
+ def add_code_line(self, code_line):
+ self.writer.putln(code_line)
+
+ def add_code_lines(self, code_lines):
+ for line in code_lines:
+ self.writer.putln(line)
+
+ def reset(self):
+ # don't attempt to reset placeholders - it really doesn't matter if
+ # we have unused placeholders
+ self.writer.reset()
+
+ def empty(self):
+ return self.writer.empty()
+
+ def indenter(self):
+ return self.writer.indenter()
+
+ def new_placeholder(self, field_names, value):
+ name = self._new_placeholder_name(field_names)
+ self.placeholders[name] = value
+ return name
+
+ def add_extra_statements(self, statements):
+ if self.extra_stats is None:
+ assert False, "Can only use add_extra_statements on top-level writer"
+ self.extra_stats.extend(statements)
+
+ def _new_placeholder_name(self, field_names):
+ while True:
+ name = "DATACLASS_PLACEHOLDER_%d" % self._placeholder_count
+ if (name not in self.placeholders
+ and name not in field_names):
+ # make sure name isn't already used and doesn't
+ # conflict with a variable name (which is unlikely but possible)
+ break
+ self._placeholder_count += 1
+ return name
+
+ def generate_tree(self, level='c_class'):
+ stat_list_node = TreeFragment(
+ self.writer.getvalue(),
+ level=level,
+ pipeline=[NormalizeTree(None)],
+ ).substitute(self.placeholders)
+
+ stat_list_node.stats += self.extra_stats
+ return stat_list_node
+
+ def insertion_point(self):
+ new_writer = self.writer.insertion_point()
+ return TemplateCode(
+ writer=new_writer,
+ placeholders=self.placeholders,
+ extra_stats=self.extra_stats
+ )
+
+
+class _MISSING_TYPE(object):
+ pass
+MISSING = _MISSING_TYPE()
+
+
+class Field(object):
+ """
+ Field is based on the dataclasses.field class from the standard library module.
+ It is used internally during the generation of Cython dataclasses to keep track
+ of the settings for individual attributes.
+
+ Attributes of this class are stored as nodes so they can be used in code construction
+ more readily (i.e. we store BoolNode rather than bool)
+ """
+ default = MISSING
+ default_factory = MISSING
+ private = False
+
+ literal_keys = ("repr", "hash", "init", "compare", "metadata")
+
+ # default values are defined by the CPython dataclasses.field
+ def __init__(self, pos, default=MISSING, default_factory=MISSING,
+ repr=None, hash=None, init=None,
+ compare=None, metadata=None,
+ is_initvar=False, is_classvar=False,
+ **additional_kwds):
+ if default is not MISSING:
+ self.default = default
+ if default_factory is not MISSING:
+ self.default_factory = default_factory
+ self.repr = repr or ExprNodes.BoolNode(pos, value=True)
+ self.hash = hash or ExprNodes.NoneNode(pos)
+ self.init = init or ExprNodes.BoolNode(pos, value=True)
+ self.compare = compare or ExprNodes.BoolNode(pos, value=True)
+ self.metadata = metadata or ExprNodes.NoneNode(pos)
+ self.is_initvar = is_initvar
+ self.is_classvar = is_classvar
+
+ for k, v in additional_kwds.items():
+ # There should not be any additional keywords!
+ error(v.pos, "cython.dataclasses.field() got an unexpected keyword argument '%s'" % k)
+
+ for field_name in self.literal_keys:
+ field_value = getattr(self, field_name)
+ if not field_value.is_literal:
+ error(field_value.pos,
+ "cython.dataclasses.field parameter '%s' must be a literal value" % field_name)
+
+ def iterate_record_node_arguments(self):
+ for key in (self.literal_keys + ('default', 'default_factory')):
+ value = getattr(self, key)
+ if value is not MISSING:
+ yield key, value
+
+
+def process_class_get_fields(node):
+ var_entries = node.scope.var_entries
+ # order of definition is used in the dataclass
+ var_entries = sorted(var_entries, key=operator.attrgetter('pos'))
+ var_names = [entry.name for entry in var_entries]
+
+ # don't treat `x = 1` as an assignment of a class attribute within the dataclass
+ transform = RemoveAssignmentsToNames(var_names)
+ transform(node)
+ default_value_assignments = transform.removed_assignments
+
+ base_type = node.base_type
+ fields = OrderedDict()
+ while base_type:
+ if base_type.is_external or not base_type.scope.implemented:
+ warning(node.pos, "Cannot reliably handle Cython dataclasses with base types "
+ "in external modules since it is not possible to tell what fields they have", 2)
+ if base_type.dataclass_fields:
+ fields = base_type.dataclass_fields.copy()
+ break
+ base_type = base_type.base_type
+
+ for entry in var_entries:
+ name = entry.name
+ is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar")
+ # TODO - classvars aren't included in "var_entries" so are missed here
+ # and thus this code is never triggered
+ is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar")
+ if name in default_value_assignments:
+ assignment = default_value_assignments[name]
+ if (isinstance(assignment, ExprNodes.CallNode)
+ and assignment.function.as_cython_attribute() == "dataclasses.field"):
+ # I believe most of this is well-enforced when it's treated as a directive
+ # but it doesn't hurt to make sure
+ valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode)
+ and isinstance(assignment.positional_args, ExprNodes.TupleNode)
+ and not assignment.positional_args.args
+ and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode)))
+ valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args)
+ if not (valid_general_call or valid_simple_call):
+ error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist "
+ "of compile-time keyword arguments")
+ continue
+ keyword_args = assignment.keyword_args.as_python_dict() if valid_general_call and assignment.keyword_args else {}
+ if 'default' in keyword_args and 'default_factory' in keyword_args:
+ error(assignment.pos, "cannot specify both default and default_factory")
+ continue
+ field = Field(node.pos, **keyword_args)
+ else:
+ if isinstance(assignment, ExprNodes.CallNode):
+ func = assignment.function
+ if ((func.is_name and func.name == "field")
+ or (func.is_attribute and func.attribute == "field")):
+ warning(assignment.pos, "Do you mean cython.dataclasses.field instead?", 1)
+ if assignment.type in [Builtin.list_type, Builtin.dict_type, Builtin.set_type]:
+ # The standard library module generates a TypeError at runtime
+ # in this situation.
+ # Error message is copied from CPython
+ error(assignment.pos, "mutable default <class '{0}'> for field {1} is not allowed: "
+ "use default_factory".format(assignment.type.name, name))
+
+ field = Field(node.pos, default=assignment)
+ else:
+ field = Field(node.pos)
+ field.is_initvar = is_initvar
+ field.is_classvar = is_classvar
+ if entry.visibility == "private":
+ field.private = True
+ fields[name] = field
+ node.entry.type.dataclass_fields = fields
+ return fields
+
+
+def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
+ # default argument values from https://docs.python.org/3/library/dataclasses.html
+ kwargs = dict(init=True, repr=True, eq=True,
+ order=False, unsafe_hash=False,
+ frozen=False, kw_only=False)
+ if dataclass_args is not None:
+ if dataclass_args[0]:
+ error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
+ for k, v in dataclass_args[1].items():
+ if k not in kwargs:
+ error(node.pos,
+ "cython.dataclasses.dataclass() got an unexpected keyword argument '%s'" % k)
+ if not isinstance(v, ExprNodes.BoolNode):
+ error(node.pos,
+ "Arguments passed to cython.dataclasses.dataclass must be True or False")
+ kwargs[k] = v.value
+
+ kw_only = kwargs['kw_only']
+
+ fields = process_class_get_fields(node)
+
+ dataclass_module = make_dataclasses_module_callnode(node.pos)
+
+ # create __dataclass_params__ attribute. I try to use the exact
+ # `_DataclassParams` class defined in the standard library module if at all possible
+ # for maximum duck-typing compatibility.
+ dataclass_params_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
+ attribute=EncodedString("_DataclassParams"))
+ dataclass_params_keywords = ExprNodes.DictNode.from_pairs(
+ node.pos,
+ [ (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
+ ExprNodes.BoolNode(node.pos, value=v))
+ for k, v in kwargs.items() ] +
+ [ (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
+ ExprNodes.BoolNode(node.pos, value=v))
+ for k, v in [('kw_only', kw_only), ('match_args', False),
+ ('slots', False), ('weakref_slot', False)]
+ ])
+ dataclass_params = make_dataclass_call_helper(
+ node.pos, dataclass_params_func, dataclass_params_keywords)
+ dataclass_params_assignment = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs = ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_params__")),
+ rhs = dataclass_params)
+
+ dataclass_fields_stats = _set_up_dataclass_fields(node, fields, dataclass_module)
+
+ stats = Nodes.StatListNode(node.pos,
+ stats=[dataclass_params_assignment] + dataclass_fields_stats)
+
+ code = TemplateCode()
+ generate_init_code(code, kwargs['init'], node, fields, kw_only)
+ generate_repr_code(code, kwargs['repr'], node, fields)
+ generate_eq_code(code, kwargs['eq'], node, fields)
+ generate_order_code(code, kwargs['order'], node, fields)
+ generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields)
+
+ stats.stats += code.generate_tree().stats
+
+ # turn off annotation typing, so all arguments to __init__ are accepted as
+ # generic objects and thus can accept _HAS_DEFAULT_FACTORY.
+ # Type conversion comes later
+ comp_directives = Nodes.CompilerDirectivesNode(node.pos,
+ directives=copy_inherited_directives(node.scope.directives, annotation_typing=False),
+ body=stats)
+
+ comp_directives.analyse_declarations(node.scope)
+ # probably already in this scope, but it doesn't hurt to make sure
+ analyse_decs_transform.enter_scope(node, node.scope)
+ analyse_decs_transform.visit(comp_directives)
+ analyse_decs_transform.exit_scope()
+
+ node.body.stats.append(comp_directives)
+
+
+def generate_init_code(code, init, node, fields, kw_only):
+ """
+ Notes on CPython generated "__init__":
+ * Implemented in `_init_fn`.
+ * The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
+ the default argument for fields that need constructing with a factory
+ function is copied from the CPython implementation. (`None` isn't
+ suitable because it could also be a value for the user to pass.)
+ There's no real reason why it needs importing from the dataclasses module
+ though - it could equally be a value generated by Cython when the module loads.
+ * seen_default and the associated error message are copied directly from Python
+ * Call to user-defined __post_init__ function (if it exists) is copied from
+ CPython.
+
+ Cython behaviour deviates a little here (to be decided if this is right...)
+ Because the class variable from the assignment does not exist Cython fields will
+ return None (or whatever their type default is) if not initialized while Python
+ dataclasses will fall back to looking up the class variable.
+ """
+ if not init or node.scope.lookup_here("__init__"):
+ return
+
+ # selfname behaviour copied from the cpython module
+ selfname = "__dataclass_self__" if "self" in fields else "self"
+ args = [selfname]
+
+ if kw_only:
+ args.append("*")
+
+ function_start_point = code.insertion_point()
+ code = code.insertion_point()
+
+ # create a temp to get _HAS_DEFAULT_FACTORY
+ dataclass_module = make_dataclasses_module_callnode(node.pos)
+ has_default_factory = ExprNodes.AttributeNode(
+ node.pos,
+ obj=dataclass_module,
+ attribute=EncodedString("_HAS_DEFAULT_FACTORY")
+ )
+
+ default_factory_placeholder = code.new_placeholder(fields, has_default_factory)
+
+ seen_default = False
+ for name, field in fields.items():
+ entry = node.scope.lookup(name)
+ if entry.annotation:
+ annotation = u": %s" % entry.annotation.string.value
+ else:
+ annotation = u""
+ assignment = u''
+ if field.default is not MISSING or field.default_factory is not MISSING:
+ seen_default = True
+ if field.default_factory is not MISSING:
+ ph_name = default_factory_placeholder
+ else:
+ ph_name = code.new_placeholder(fields, field.default) # 'default' should be a node
+ assignment = u" = %s" % ph_name
+ elif seen_default and not kw_only and field.init.value:
+ error(entry.pos, ("non-default argument '%s' follows default argument "
+ "in dataclass __init__") % name)
+ code.reset()
+ return
+
+ if field.init.value:
+ args.append(u"%s%s%s" % (name, annotation, assignment))
+
+ if field.is_initvar:
+ continue
+ elif field.default_factory is MISSING:
+ if field.init.value:
+ code.add_code_line(u" %s.%s = %s" % (selfname, name, name))
+ elif assignment:
+ # not an argument to the function, but is still initialized
+ code.add_code_line(u" %s.%s%s" % (selfname, name, assignment))
+ else:
+ ph_name = code.new_placeholder(fields, field.default_factory)
+ if field.init.value:
+ # close to:
+ # def __init__(self, name=_PLACEHOLDER_VALUE):
+ # self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
+ code.add_code_line(u" %s.%s = %s() if %s is %s else %s" % (
+ selfname, name, ph_name, name, default_factory_placeholder, name))
+ else:
+ # still need to use the default factory to initialize
+ code.add_code_line(u" %s.%s = %s()" % (
+ selfname, name, ph_name))
+
+ if node.scope.lookup("__post_init__"):
+ post_init_vars = ", ".join(name for name, field in fields.items()
+ if field.is_initvar)
+ code.add_code_line(" %s.__post_init__(%s)" % (selfname, post_init_vars))
+
+ if code.empty():
+ code.add_code_line(" pass")
+
+ args = u", ".join(args)
+ function_start_point.add_code_line(u"def __init__(%s):" % args)
+
+
+def generate_repr_code(code, repr, node, fields):
+ """
+ The core of the CPython implementation is just:
+ ['return self.__class__.__qualname__ + f"(' +
+ ', '.join([f"{f.name}={{self.{f.name}!r}}"
+ for f in fields]) +
+ ')"'],
+
+ The only notable difference here is self.__class__.__qualname__ -> type(self).__name__
+ which is because Cython currently supports Python 2.
+
+ However, it also has some guards for recursive repr invokations. In the standard
+ library implementation they're done with a wrapper decorator that captures a set
+ (with the set keyed by id and thread). Here we create a set as a thread local
+ variable and key only by id.
+ """
+ if not repr or node.scope.lookup("__repr__"):
+ return
+
+ # The recursive guard is likely a little costly, so skip it if possible.
+ # is_gc_simple defines where it can contain recursive objects
+ needs_recursive_guard = False
+ for name in fields.keys():
+ entry = node.scope.lookup(name)
+ type_ = entry.type
+ if type_.is_memoryviewslice:
+ type_ = type_.dtype
+ if not type_.is_pyobject:
+ continue # no GC
+ if not type_.is_gc_simple:
+ needs_recursive_guard = True
+ break
+
+ if needs_recursive_guard:
+ code.add_code_line("__pyx_recursive_repr_guard = __import__('threading').local()")
+ code.add_code_line("__pyx_recursive_repr_guard.running = set()")
+ code.add_code_line("def __repr__(self):")
+ if needs_recursive_guard:
+ code.add_code_line(" key = id(self)")
+ code.add_code_line(" guard_set = self.__pyx_recursive_repr_guard.running")
+ code.add_code_line(" if key in guard_set: return '...'")
+ code.add_code_line(" guard_set.add(key)")
+ code.add_code_line(" try:")
+ strs = [u"%s={self.%s!r}" % (name, name)
+ for name, field in fields.items()
+ if field.repr.value and not field.is_initvar]
+ format_string = u", ".join(strs)
+
+ code.add_code_line(u' name = getattr(type(self), "__qualname__", type(self).__name__)')
+ code.add_code_line(u" return f'{name}(%s)'" % format_string)
+ if needs_recursive_guard:
+ code.add_code_line(" finally:")
+ code.add_code_line(" guard_set.remove(key)")
+
+
+def generate_cmp_code(code, op, funcname, node, fields):
+ if node.scope.lookup_here(funcname):
+ return
+
+ names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]
+
+ code.add_code_lines([
+ "def %s(self, other):" % funcname,
+ " if not isinstance(other, %s):" % node.class_name,
+ " return NotImplemented",
+ #
+ " cdef %s other_cast" % node.class_name,
+ " other_cast = <%s>other" % node.class_name,
+ ])
+
+ # The Python implementation of dataclasses.py does a tuple comparison
+ # (roughly):
+ # return self._attributes_to_tuple() {op} other._attributes_to_tuple()
+ #
+ # For the Cython implementation a tuple comparison isn't an option because
+ # not all attributes can be converted to Python objects and stored in a tuple
+ #
+ # TODO - better diagnostics of whether the types support comparison before
+ # generating the code. Plus, do we want to convert C structs to dicts and
+ # compare them that way (I think not, but it might be in demand)?
+ checks = []
+ for name in names:
+ checks.append("(self.%s %s other_cast.%s)" % (
+ name, op, name))
+
+ if checks:
+ code.add_code_line(" return " + " and ".join(checks))
+ else:
+ if "=" in op:
+ code.add_code_line(" return True") # "() == ()" is True
+ else:
+ code.add_code_line(" return False")
+
+
+def generate_eq_code(code, eq, node, fields):
+ if not eq:
+ return
+ generate_cmp_code(code, "==", "__eq__", node, fields)
+
+
+def generate_order_code(code, order, node, fields):
+ if not order:
+ return
+
+ for op, name in [("<", "__lt__"),
+ ("<=", "__le__"),
+ (">", "__gt__"),
+ (">=", "__ge__")]:
+ generate_cmp_code(code, op, name, node, fields)
+
+
+def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
+ """
+ Copied from CPython implementation - the intention is to follow this as far as
+ is possible:
+ # +------------------- unsafe_hash= parameter
+ # | +----------- eq= parameter
+ # | | +--- frozen= parameter
+ # | | |
+ # v v v | | |
+ # | no | yes | <--- class has explicitly defined __hash__
+ # +=======+=======+=======+========+========+
+ # | False | False | False | | | No __eq__, use the base class __hash__
+ # +-------+-------+-------+--------+--------+
+ # | False | False | True | | | No __eq__, use the base class __hash__
+ # +-------+-------+-------+--------+--------+
+ # | False | True | False | None | | <-- the default, not hashable
+ # +-------+-------+-------+--------+--------+
+ # | False | True | True | add | | Frozen, so hashable, allows override
+ # +-------+-------+-------+--------+--------+
+ # | True | False | False | add | raise | Has no __eq__, but hashable
+ # +-------+-------+-------+--------+--------+
+ # | True | False | True | add | raise | Has no __eq__, but hashable
+ # +-------+-------+-------+--------+--------+
+ # | True | True | False | add | raise | Not frozen, but hashable
+ # +-------+-------+-------+--------+--------+
+ # | True | True | True | add | raise | Frozen, so hashable
+ # +=======+=======+=======+========+========+
+ # For boxes that are blank, __hash__ is untouched and therefore
+ # inherited from the base class. If the base is object, then
+ # id-based hashing is used.
+
+ The Python implementation creates a tuple of all the fields, then hashes them.
+ This implementation creates a tuple of all the hashes of all the fields and hashes that.
+ The reason for this slight difference is to avoid to-Python conversions for anything
+ that Cython knows how to hash directly (It doesn't look like this currently applies to
+ anything though...).
+ """
+
+ hash_entry = node.scope.lookup_here("__hash__")
+ if hash_entry:
+ # TODO ideally assignment of __hash__ to None shouldn't trigger this
+ # but difficult to get the right information here
+ if unsafe_hash:
+ # error message taken from CPython dataclasses module
+ error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
+ return
+
+ if not unsafe_hash:
+ if not eq:
+ return
+ if not frozen:
+ code.add_extra_statements([
+ Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
+ rhs=ExprNodes.NoneNode(node.pos),
+ )
+ ])
+ return
+
+ names = [
+ name for name, field in fields.items()
+ if not field.is_initvar and (
+ field.compare.value if field.hash.value is None else field.hash.value)
+ ]
+
+ # make a tuple of the hashes
+ hash_tuple_items = u", ".join(u"self.%s" % name for name in names)
+ if hash_tuple_items:
+ hash_tuple_items += u"," # ensure that one arg form is a tuple
+
+ # if we're here we want to generate a hash
+ code.add_code_lines([
+ "def __hash__(self):",
+ " return hash((%s))" % hash_tuple_items,
+ ])
+
+
+def get_field_type(pos, entry):
+ """
+ sets the .type attribute for a field
+
+ Returns the annotation if possible (since this is what the dataclasses
+ module does). If not (for example, attributes defined with cdef) then
+ it creates a string fallback.
+ """
+ if entry.annotation:
+ # Right now it doesn't look like cdef classes generate an
+ # __annotations__ dict, therefore it's safe to just return
+ # entry.annotation
+ # (TODO: remove .string if we ditch PEP563)
+ return entry.annotation.string
+ # If they do in future then we may need to look up into that
+ # to duplicating the node. The code below should do this:
+ #class_name_node = ExprNodes.NameNode(pos, name=entry.scope.name)
+ #annotations = ExprNodes.AttributeNode(
+ # pos, obj=class_name_node,
+ # attribute=EncodedString("__annotations__")
+ #)
+ #return ExprNodes.IndexNode(
+ # pos, base=annotations,
+ # index=ExprNodes.StringNode(pos, value=entry.name)
+ #)
+ else:
+ # it's slightly unclear what the best option is here - we could
+ # try to return PyType_Type. This case should only happen with
+ # attributes defined with cdef so Cython is free to make it's own
+ # decision
+ s = EncodedString(entry.type.declaration_code("", for_display=1))
+ return ExprNodes.StringNode(pos, value=s)
+
+
+class FieldRecordNode(ExprNodes.ExprNode):
+ """
+ __dataclass_fields__ contains a bunch of field objects recording how each field
+ of the dataclass was initialized (mainly corresponding to the arguments passed to
+ the "field" function). This node is used for the attributes of these field objects.
+
+ If possible, coerces `arg` to a Python object.
+ Otherwise, generates a sensible backup string.
+ """
+ subexprs = ['arg']
+
+ def __init__(self, pos, arg):
+ super(FieldRecordNode, self).__init__(pos, arg=arg)
+
+ def analyse_types(self, env):
+ self.arg.analyse_types(env)
+ self.type = self.arg.type
+ return self
+
+ def coerce_to_pyobject(self, env):
+ if self.arg.type.can_coerce_to_pyobject(env):
+ return self.arg.coerce_to_pyobject(env)
+ else:
+ # A string representation of the code that gave the field seems like a reasonable
+ # fallback. This'll mostly happen for "default" and "default_factory" where the
+ # type may be a C-type that can't be converted to Python.
+ return self._make_string()
+
+ def _make_string(self):
+ from .AutoDocTransforms import AnnotationWriter
+ writer = AnnotationWriter(description="Dataclass field")
+ string = writer.write(self.arg)
+ return ExprNodes.StringNode(self.pos, value=EncodedString(string))
+
+ def generate_evaluation_code(self, code):
+ return self.arg.generate_evaluation_code(code)
+
+
+def _set_up_dataclass_fields(node, fields, dataclass_module):
+ # For defaults and default_factories containing things like lambda,
+ # they're already declared in the class scope, and it creates a big
+ # problem if multiple copies are floating around in both the __init__
+ # function, and in the __dataclass_fields__ structure.
+ # Therefore, create module-level constants holding these values and
+ # pass those around instead
+ #
+ # If possible we use the `Field` class defined in the standard library
+ # module so that the information stored here is as close to a regular
+ # dataclass as is possible.
+ variables_assignment_stats = []
+ for name, field in fields.items():
+ if field.private:
+ continue # doesn't appear in the public interface
+ for attrname in [ "default", "default_factory" ]:
+ field_default = getattr(field, attrname)
+ if field_default is MISSING or field_default.is_literal or field_default.is_name:
+ # some simple cases where we don't need to set up
+ # the variable as a module-level constant
+ continue
+ global_scope = node.scope.global_scope()
+ module_field_name = global_scope.mangle(
+ global_scope.mangle(Naming.dataclass_field_default_cname, node.class_name),
+ name)
+ # create an entry in the global scope for this variable to live
+ field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name))
+ field_node.entry = global_scope.declare_var(
+ field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
+ pos=field_default.pos, cname=field_node.name, is_cdef=True,
+ # TODO: do we need to set 'pytyping_modifiers' here?
+ )
+ # replace the field so that future users just receive the namenode
+ setattr(field, attrname, field_node)
+
+ variables_assignment_stats.append(
+ Nodes.SingleAssignmentNode(field_default.pos, lhs=field_node, rhs=field_default))
+
+ placeholders = {}
+ field_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
+ attribute=EncodedString("field"))
+ dc_fields = ExprNodes.DictNode(node.pos, key_value_pairs=[])
+ dc_fields_namevalue_assignments = []
+
+ for name, field in fields.items():
+ if field.private:
+ continue # doesn't appear in the public interface
+ type_placeholder_name = "PLACEHOLDER_%s" % name
+ placeholders[type_placeholder_name] = get_field_type(
+ node.pos, node.scope.entries[name]
+ )
+
+ # defining these make the fields introspect more like a Python dataclass
+ field_type_placeholder_name = "PLACEHOLDER_FIELD_TYPE_%s" % name
+ if field.is_initvar:
+ placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
+ node.pos, obj=dataclass_module,
+ attribute=EncodedString("_FIELD_INITVAR")
+ )
+ elif field.is_classvar:
+ # TODO - currently this isn't triggered
+ placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
+ node.pos, obj=dataclass_module,
+ attribute=EncodedString("_FIELD_CLASSVAR")
+ )
+ else:
+ placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
+ node.pos, obj=dataclass_module,
+ attribute=EncodedString("_FIELD")
+ )
+
+ dc_field_keywords = ExprNodes.DictNode.from_pairs(
+ node.pos,
+ [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
+ FieldRecordNode(node.pos, arg=v))
+ for k, v in field.iterate_record_node_arguments()]
+
+ )
+ dc_field_call = make_dataclass_call_helper(
+ node.pos, field_func, dc_field_keywords
+ )
+ dc_fields.key_value_pairs.append(
+ ExprNodes.DictItemNode(
+ node.pos,
+ key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
+ value=dc_field_call))
+ dc_fields_namevalue_assignments.append(
+ dedent(u"""\
+ __dataclass_fields__[{0!r}].name = {0!r}
+ __dataclass_fields__[{0!r}].type = {1}
+ __dataclass_fields__[{0!r}]._field_type = {2}
+ """).format(name, type_placeholder_name, field_type_placeholder_name))
+
+ dataclass_fields_assignment = \
+ Nodes.SingleAssignmentNode(node.pos,
+ lhs = ExprNodes.NameNode(node.pos,
+ name=EncodedString("__dataclass_fields__")),
+ rhs = dc_fields)
+
+ dc_fields_namevalue_assignments = u"\n".join(dc_fields_namevalue_assignments)
+ dc_fields_namevalue_assignments = TreeFragment(dc_fields_namevalue_assignments,
+ level="c_class",
+ pipeline=[NormalizeTree(None)])
+ dc_fields_namevalue_assignments = dc_fields_namevalue_assignments.substitute(placeholders)
+
+ return (variables_assignment_stats
+ + [dataclass_fields_assignment]
+ + dc_fields_namevalue_assignments.stats)