diff options
author | da-woods <dw-git@d-woods.co.uk> | 2022-09-24 13:35:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-24 13:35:25 +0100 |
commit | c2a54864217a6d4295c7f3748c15943da898b1c2 (patch) | |
tree | 015d40af661bbfd5a82218120da65ced500c6a00 | |
parent | ab1053b2b1171664038488cb6721b9e407fe5679 (diff) | |
download | cython-c2a54864217a6d4295c7f3748c15943da898b1c2.tar.gz |
Allow empty args to dataclass and field directives (#4957)
Part of the bug fixes in https://github.com/cython/cython/issues/4956
-rw-r--r-- | Cython/Compiler/Dataclass.py | 12 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.py | 2 | ||||
-rw-r--r-- | Tools/make_dataclass_tests.py | 5 | ||||
-rw-r--r-- | tests/run/test_dataclasses.pyx | 20 |
4 files changed, 29 insertions, 10 deletions
diff --git a/Cython/Compiler/Dataclass.py b/Cython/Compiler/Dataclass.py index 609520004..327f11e19 100644 --- a/Cython/Compiler/Dataclass.py +++ b/Cython/Compiler/Dataclass.py @@ -217,14 +217,16 @@ def process_class_get_fields(node): and assignment.function.as_cython_attribute() == "dataclasses.field"): # I believe most of this is well-enforced when it's treated as a directive # but it doesn't hurt to make sure - if (not isinstance(assignment, ExprNodes.GeneralCallNode) - or not isinstance(assignment.positional_args, ExprNodes.TupleNode) - or assignment.positional_args.args - or not isinstance(assignment.keyword_args, ExprNodes.DictNode)): + valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode) + and isinstance(assignment.positional_args, ExprNodes.TupleNode) + and not assignment.positional_args.args + and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode))) + valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args) + if not (valid_general_call or valid_simple_call): error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist " "of compile-time keyword arguments") continue - keyword_args = assignment.keyword_args.as_python_dict() + keyword_args = assignment.keyword_args.as_python_dict() if valid_general_call and assignment.keyword_args else {} if 'default' in keyword_args and 'default_factory' in keyword_args: error(assignment.pos, "cannot specify both default and default_factory") continue diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index baf5b4ef7..54d861d8a 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1225,7 +1225,7 @@ class InterpretCompilerDirectives(CythonTransform): return (optname, directivetype(optname, str(args[0].value))) elif directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS: # signal to pass things on without processing - return (optname, (args, kwds.as_python_dict())) + return (optname, (args, kwds.as_python_dict() if kwds else {})) else: assert False diff --git a/Tools/make_dataclass_tests.py b/Tools/make_dataclass_tests.py index 6a3cee7ac..25f43cc2d 100644 --- a/Tools/make_dataclass_tests.py +++ b/Tools/make_dataclass_tests.py @@ -111,6 +111,7 @@ skip_tests = frozenset( ("TestInit", "test_base_has_init"), # needs __dict__ for vars # Requires arbitrary attributes to be writeable ("TestCase", "test_post_init_super"), + ('TestCase', 'test_init_in_order'), # Cython being strict about argument types - expected difference ("TestDescriptors", "test_getting_field_calls_get"), ("TestDescriptors", "test_init_calls_set"), @@ -129,10 +130,6 @@ skip_tests = frozenset( # not possible to add attributes on extension types ("TestCase", "test_post_init_classmethod"), # Bugs - # ==== - ("TestCase", "test_no_options"), # @dataclass() - ("TestCase", "test_field_no_default"), # field() - ("TestCase", "test_init_in_order"), # field() ("TestCase", "test_hash_field_rules"), # compiler crash ("TestCase", "test_class_var"), # not sure but compiler crash ("TestCase", "test_field_order"), # invalid C code (__pyx_base?) diff --git a/tests/run/test_dataclasses.pyx b/tests/run/test_dataclasses.pyx index 8321b9de0..5ea83f82a 100644 --- a/tests/run/test_dataclasses.pyx +++ b/tests/run/test_dataclasses.pyx @@ -66,6 +66,11 @@ class C_TestCase_test_1_field_compare: @dataclass @cclass +class C_TestCase_test_field_no_default: + x: int = field() + +@dataclass +@cclass class C_TestCase_test_not_in_compare: x: int = 0 y: int = field(compare=False, default=4) @@ -80,6 +85,11 @@ class Mutable_TestCase_test_deliberately_mutable_defaults: class C_TestCase_test_deliberately_mutable_defaults: x: Mutable_TestCase_test_deliberately_mutable_defaults +@dataclass() +@cclass +class C_TestCase_test_no_options: + x: int + @dataclass @cclass class Point_TestCase_test_not_tuple: @@ -580,6 +590,12 @@ class TestCase(unittest.TestCase): self.assertGreaterEqual(C(1), C(0)) self.assertGreaterEqual(C(1), C(1)) + def test_field_no_default(self): + C = C_TestCase_test_field_no_default + self.assertEqual(C(5).x, 5) + with self.assertRaises(TypeError): + C() + def test_not_in_compare(self): C = C_TestCase_test_not_in_compare self.assertEqual(C(), C(0, 20)) @@ -599,6 +615,10 @@ class TestCase(unittest.TestCase): self.assertEqual(o1.x.l, [1, 2]) self.assertIs(o1.x, o2.x) + def test_no_options(self): + C = C_TestCase_test_no_options + self.assertEqual(C(42).x, 42) + def test_not_tuple(self): Point = Point_TestCase_test_not_tuple self.assertNotEqual(Point(1, 2), (1, 2)) |