diff options
author | da-woods <dw-git@d-woods.co.uk> | 2023-01-10 21:34:51 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-10 22:34:51 +0100 |
commit | 74bcd91438a6f7f94417dfc4d50f572cfa50f820 (patch) | |
tree | f4f3fc39f1dbec1334c6ad836d1727efc99376c7 | |
parent | 17fa2640293ac264788a404efc9441afff5129d7 (diff) | |
download | cython-74bcd91438a6f7f94417dfc4d50f572cfa50f820.tar.gz |
Improve "cpdef enum" to Python conversion (GH-4877)
Return the Python enum instead of an int.
Also, use flag enums as wrapper since they behave more like C enums in that they allow or-combination.
The one corner-case that isn't perfect is cpdef enums declared
in a standalone pxd file - they don't actually ever generate
the python wrapper. I've made these emit a warning and return
an int (which I think is a reasonable solution for the moment).
Closes https://github.com/cython/cython/issues/2732
Closes https://github.com/cython/cython/issues/4633
-rw-r--r-- | Cython/Compiler/PyrexTypes.py | 57 | ||||
-rw-r--r-- | Cython/Compiler/UtilityCode.py | 7 | ||||
-rw-r--r-- | Cython/Utility/CpdefEnums.pyx | 86 | ||||
-rw-r--r-- | tests/run/cpdef_enums.pyx | 17 | ||||
-rw-r--r-- | tests/run/cpdef_enums_import.srctree | 13 | ||||
-rw-r--r-- | tests/run/cpdef_scoped_enums.pyx | 22 | ||||
-rw-r--r-- | tests/run/cpp_stl_conversion.pyx | 2 |
7 files changed, 195 insertions, 9 deletions
diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index c3b06ba51..7e2777397 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -2787,6 +2787,12 @@ class CArrayType(CPointerBaseType): source_code, result_code, self.size) return code.error_goto_if_neg(call_code, error_pos) + def error_condition(self, result_code): + # It isn't possible to use CArrays as return type so the error_condition + # is irrelevant. Returning a falsy value does avoid an error when getting + # from_py_call_code from a typedef. + return "" + class CPtrType(CPointerBaseType): # base_type CType Reference type @@ -4257,6 +4263,25 @@ class CppScopedEnumType(CType): def create_to_py_utility_code(self, env): if self.to_py_function is not None: return True + if self.entry.create_wrapper: + from .UtilityCode import CythonUtilityCode + self.to_py_function = "__Pyx_Enum_%s_to_py" % self.name + if self.entry.scope != env.global_scope(): + module_name = self.entry.scope.qualified_name + else: + module_name = None + env.use_utility_code(CythonUtilityCode.load( + "EnumTypeToPy", "CpdefEnums.pyx", + context={"funcname": self.to_py_function, + "name": self.name, + "items": tuple(self.values), + "underlying_type": self.underlying_type.empty_declaration_code(), + "module_name": module_name, + "is_flag": False, + }, + outer_module_scope=self.entry.scope # ensure that "name" is findable + )) + return True if self.underlying_type.create_to_py_utility_code(env): # Using a C++11 lambda here, which is fine since # scoped enums are a C++11 feature @@ -4383,15 +4408,47 @@ class CEnumType(CIntLike, CType): def create_type_wrapper(self, env): from .UtilityCode import CythonUtilityCode + # Generate "int"-like conversion function + old_to_py_function = self.to_py_function + self.to_py_function = None + CIntLike.create_to_py_utility_code(self, env) + enum_to_pyint_func = self.to_py_function + self.to_py_function = old_to_py_function # we don't actually want to overwrite this + env.use_utility_code(CythonUtilityCode.load( "EnumType", "CpdefEnums.pyx", context={"name": self.name, "items": tuple(self.values), "enum_doc": self.doc, + "enum_to_pyint_func": enum_to_pyint_func, "static_modname": env.qualified_name, }, outer_module_scope=env.global_scope())) + def create_to_py_utility_code(self, env): + if self.to_py_function is not None: + return self.to_py_function + if not self.entry.create_wrapper: + return super(CEnumType, self).create_to_py_utility_code(env) + from .UtilityCode import CythonUtilityCode + self.to_py_function = "__Pyx_Enum_%s_to_py" % self.name + if self.entry.scope != env.global_scope(): + module_name = self.entry.scope.qualified_name + else: + module_name = None + env.use_utility_code(CythonUtilityCode.load( + "EnumTypeToPy", "CpdefEnums.pyx", + context={"funcname": self.to_py_function, + "name": self.name, + "items": tuple(self.values), + "underlying_type": "int", + "module_name": module_name, + "is_flag": True, + }, + outer_module_scope=self.entry.scope # ensure that "name" is findable + )) + return True + class CTupleType(CType): # components [PyrexType] diff --git a/Cython/Compiler/UtilityCode.py b/Cython/Compiler/UtilityCode.py index 870abf3e5..e2df2586b 100644 --- a/Cython/Compiler/UtilityCode.py +++ b/Cython/Compiler/UtilityCode.py @@ -173,8 +173,15 @@ class CythonUtilityCode(Code.UtilityCodeBase): if self.context_types: # inject types into module scope def scope_transform(module_node): + dummy_entry = object() for name, type in self.context_types.items(): + # Restore the old type entry after declaring the type. + # We need to access types in the scope, but this shouldn't alter the entry + # that is visible from everywhere else + old_type_entry = getattr(type, "entry", dummy_entry) entry = module_node.scope.declare_type(name, type, None, visibility='extern') + if old_type_entry is not dummy_entry: + type.entry = old_type_entry entry.in_cinclude = True return module_node diff --git a/Cython/Utility/CpdefEnums.pyx b/Cython/Utility/CpdefEnums.pyx index 10a0afc54..0d311e84d 100644 --- a/Cython/Utility/CpdefEnums.pyx +++ b/Cython/Utility/CpdefEnums.pyx @@ -44,18 +44,53 @@ class __Pyx_EnumBase(int, metaclass=__Pyx_EnumMeta): if PY_VERSION_HEX >= 0x03040000: from enum import IntEnum as __Pyx_EnumBase +cdef object __Pyx_FlagBase +class __Pyx_FlagBase(int, metaclass=__Pyx_EnumMeta): + def __new__(cls, value, name=None): + for v in cls: + if v == value: + return v + res = int.__new__(cls, value) + if name is None: + # some bitwise combination, no validation here + res.name = "" + else: + res.name = name + setattr(cls, name, res) + cls.__members__[name] = res + return res + def __repr__(self): + return "<%s.%s: %d>" % (self.__class__.__name__, self.name, self) + def __str__(self): + return "%s.%s" % (self.__class__.__name__, self.name) + +if PY_VERSION_HEX >= 0x03060000: + from enum import IntFlag as __Pyx_FlagBase + #################### EnumType #################### #@requires: EnumBase +cdef extern from *: + object {{enum_to_pyint_func}}({{name}} value) + cdef dict __Pyx_globals = globals() -if PY_VERSION_HEX >= 0x03040000: - # create new IntEnum() - {{name}} = __Pyx_EnumBase('{{name}}', [ +if PY_VERSION_HEX >= 0x03060000: + # create new IntFlag() - the assumption is that C enums are sufficiently commonly + # used as flags that this is the most appropriate base class + {{name}} = __Pyx_FlagBase('{{name}}', [ {{for item in items}} - ('{{item}}', {{item}}), + ('{{item}}', {{enum_to_pyint_func}}({{item}})), {{endfor}} # Try to look up the module name dynamically if possible ], module=__Pyx_globals.get("__module__", '{{static_modname}}')) + + if PY_VERSION_HEX >= 0x030B0000: + # Python 3.11 starts making the behaviour of flags stricter + # (only including powers of 2 when iterating). Since we're using + # "flag" because C enums *might* be used as flags, not because + # we want strict flag behaviour, manually undo some of this. + {{name}}._member_names_ = list({{name}}.__members__) + {{if enum_doc is not None}} {{name}}.__doc__ = {{ repr(enum_doc) }} {{endif}} @@ -64,10 +99,10 @@ if PY_VERSION_HEX >= 0x03040000: __Pyx_globals['{{item}}'] = {{name}}.{{item}} {{endfor}} else: - class {{name}}(__Pyx_EnumBase): + class {{name}}(__Pyx_FlagBase): {{ repr(enum_doc) if enum_doc is not None else 'pass' }} {{for item in items}} - __Pyx_globals['{{item}}'] = {{name}}({{item}}, '{{item}}') + __Pyx_globals['{{item}}'] = {{name}}({{enum_to_pyint_func}}({{item}}), '{{item}}') {{endfor}} #################### CppScopedEnumType #################### @@ -75,7 +110,6 @@ else: cdef dict __Pyx_globals = globals() if PY_VERSION_HEX >= 0x03040000: - # create new IntEnum() __Pyx_globals["{{name}}"] = __Pyx_EnumBase('{{name}}', [ {{for item in items}} ('{{item}}', <{{underlying_type}}>({{name}}.{{item}})), @@ -90,3 +124,41 @@ else: {{if enum_doc is not None}} __Pyx_globals["{{name}}"].__doc__ = {{ repr(enum_doc) }} {{endif}} + + +#################### EnumTypeToPy #################### + +@cname("{{funcname}}") +cdef {{funcname}}({{name}} c_val): + cdef object __pyx_enum + # There's a complication here: the Python enum wrapping is only generated + # for enums defined in the same module that they're used in. Therefore, if + # the enum was cimported from a different module, we try to import it. + # If that fails we return an int equivalent as the next best option. +{{if module_name}} + try: + from {{module_name}} import {{name}} as __pyx_enum + except ImportError: + import warnings + warnings.warn( + f"enum class {{name}} not importable from {{module_name}}. " + "You are probably using a cpdef enum declared in a .pxd file that " + "does not have a .py or .pyx file.") + return <{{underlying_type}}>c_val +{{else}} + __pyx_enum = {{name}} +{{endif}} + # TODO - Cython only manages to optimize C enums to a switch currently + if 0: + pass +{{for item in items}} + elif c_val == {{name}}.{{item}}: + return __pyx_enum.{{item}} +{{endfor}} + else: + underlying_c_val = <{{underlying_type}}>c_val +{{if is_flag}} + return __pyx_enum(underlying_c_val) +{{else}} + raise ValueError(f"{underlying_c_val} is not a valid {{name}}") +{{endif}} diff --git a/tests/run/cpdef_enums.pyx b/tests/run/cpdef_enums.pyx index 82c31fb95..4a7256531 100644 --- a/tests/run/cpdef_enums.pyx +++ b/tests/run/cpdef_enums.pyx @@ -132,13 +132,28 @@ def check_docs(): """ pass + +def to_from_py_conversion(PxdEnum val): + """ + >>> to_from_py_conversion(RANK_1) is PxdEnum.RANK_1 + True + + C enums are commonly enough used as flags that it seems reasonable + to allow it in Cython + >>> to_from_py_conversion(RANK_1 | RANK_2) == (RANK_1 | RANK_2) + True + """ + return val + + def test_pickle(): """ >>> from pickle import loads, dumps >>> import sys Pickling enums won't work without the enum module, so disable the test - >>> if sys.version_info < (3, 4): + (now requires 3.6 for IntFlag) + >>> if sys.version_info < (3, 6): ... loads = dumps = lambda x: x >>> loads(dumps(PyxEnum.TWO)) == PyxEnum.TWO True diff --git a/tests/run/cpdef_enums_import.srctree b/tests/run/cpdef_enums_import.srctree index 4cf0eb7e2..928a2d0b1 100644 --- a/tests/run/cpdef_enums_import.srctree +++ b/tests/run/cpdef_enums_import.srctree @@ -28,13 +28,23 @@ cpdef enum NamedEnumType: cpdef foo() +######## enums_without_pyx.pxd ##### + +cpdef enum EnumTypeNotInPyx: + AnotherEnumValue = 500 + ######## no_enums.pyx ######## from enums cimport * +from enums_without_pyx cimport * def get_named_enum_value(): return NamedEnumType.NamedEnumValue +def get_named_without_pyx(): + # This'll generate a warning but return a c int + return EnumTypeNotInPyx.AnotherEnumValue + ######## import_enums_test.py ######## # We can import enums with a star import. @@ -51,3 +61,6 @@ assert 'FOO' not in dir(no_enums) assert 'foo' not in dir(no_enums) assert no_enums.get_named_enum_value() == NamedEnumType.NamedEnumValue +# In this case the enum isn't accessible from Python (by design) +# but the conversion to Python goes through a reasonable fallback +assert no_enums.get_named_without_pyx() == 500 diff --git a/tests/run/cpdef_scoped_enums.pyx b/tests/run/cpdef_scoped_enums.pyx index 4c6236db3..afae84e57 100644 --- a/tests/run/cpdef_scoped_enums.pyx +++ b/tests/run/cpdef_scoped_enums.pyx @@ -42,6 +42,28 @@ def test_enum_doc(): pass +def to_from_py_conversion(Enum1 val): + """ + >>> to_from_py_conversion(Enum1.Item1) is Enum1.Item1 + True + + Scoped enums should not be used as flags, and therefore attempts to set them + with arbitrary values should fail + >>> to_from_py_conversion(500) + Traceback (most recent call last): + ... + ValueError: 500 is not a valid Enum1 + + # Note that the ability to bitwise-or together the two numbers is inherited + from the Python enum (so not in Cython's remit to prevent) + >>> to_from_py_conversion(Enum1.Item1 | Enum1.Item2) + Traceback (most recent call last): + ... + ValueError: 3 is not a valid Enum1 + """ + return val + + def test_pickle(): """ >>> from pickle import loads, dumps diff --git a/tests/run/cpp_stl_conversion.pyx b/tests/run/cpp_stl_conversion.pyx index 07d56d68b..a5c418140 100644 --- a/tests/run/cpp_stl_conversion.pyx +++ b/tests/run/cpp_stl_conversion.pyx @@ -270,7 +270,7 @@ cpdef enum Color: def test_enum_map(o): """ >>> test_enum_map({RED: GREEN}) - {0: 1} + {<Color.RED: 0>: <Color.GREEN: 1>} """ cdef map[Color, Color] m = o return m |