diff options
author | da-woods <dw-git@d-woods.co.uk> | 2022-09-27 18:04:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-27 18:04:07 +0100 |
commit | 1ba6a55579e57aa9f104d01d1f962886d447ec9a (patch) | |
tree | 46a99cb29770f3a17334c4f0639ca43b0b2b4049 /Tools | |
parent | 849d14785367e8e74faf0ecc579430007cf1c6e6 (diff) | |
download | cython-1ba6a55579e57aa9f104d01d1f962886d447ec9a.tar.gz |
Fix recursive repr on cdef dataclasses (#5045)
The dataclass module specifically guards repr from being
invoked recursively. I use a slightly different method here
to do the same thing.
Part of https://github.com/cython/cython/issues/4956
Diffstat (limited to 'Tools')
-rw-r--r-- | Tools/make_dataclass_tests.py | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/Tools/make_dataclass_tests.py b/Tools/make_dataclass_tests.py index e8bd1b188..c39a4b2db 100644 --- a/Tools/make_dataclass_tests.py +++ b/Tools/make_dataclass_tests.py @@ -139,11 +139,6 @@ skip_tests = frozenset( "TestCase", "test_overwrite_fields_in_derived_class", ), # invalid C code (__pyx_base?) - ("TestReplace", "test_recursive_repr"), # recursion error - ("TestReplace", "test_recursive_repr_two_attrs"), # recursion error - ("TestReplace", "test_recursive_repr_misc_attrs"), # recursion error - ("TestReplace", "test_recursive_repr_indirection"), # recursion error - ("TestReplace", "test_recursive_repr_indirection_two"), # recursion error ( "TestCase", "test_intermediate_non_dataclass", @@ -233,12 +228,11 @@ class SubstituteNameString(ast.NodeTransformer): if node.value.find("<locals>") != -1: import re - new_value = re.sub("[\w.]*<locals>", "", node.value) + new_value = new_value2 = re.sub("[\w.]*<locals>", "", node.value) for key, value in self.substitutions.items(): - new_value2 = re.sub(f"(?<![\w])[.]{key}(?![\w])", value, new_value) - if new_value != new_value2: - node.value = new_value2 - break + new_value2 = re.sub(f"(?<![\w])[.]{key}(?![\w])", value, new_value2) + if new_value != new_value2: + node.value = new_value2 return node @@ -323,6 +317,7 @@ class ExtractDataclassesToTopLevel(ast.NodeTransformer): self.used_names.add(new_name) # hmmmm... possibly there's a few cases where there's more than one name? self.collected_substitutions[old_name] = node.name + return ast.Assign( targets=[ast.Name(id=old_name, ctx=ast.Store())], value=ast.Name(id=new_name, ctx=ast.Load()), @@ -409,6 +404,17 @@ class ExtractDataclassesToTopLevel(ast.NodeTransformer): node.body[0:0] = self.global_classes return node + def visit_AnnAssign(self, node): + # string annotations are forward declarations but the string will be wrong + # (because we're renaming the class) + if (isinstance(node.annotation, ast.Constant) and + isinstance(node.annotation.value, str)): + # although it'd be good to resolve these declarations, for the + # sake of the tests they only need to be "object" + node.annotation = ast.Name(id="object", ctx=ast.Load) + + return node + def main(): script_path = os.path.split(sys.argv[0])[0] |