summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorda-woods <dw-git@d-woods.co.uk>2023-01-10 21:34:51 +0000
committerGitHub <noreply@github.com>2023-01-10 22:34:51 +0100
commit74bcd91438a6f7f94417dfc4d50f572cfa50f820 (patch)
treef4f3fc39f1dbec1334c6ad836d1727efc99376c7
parent17fa2640293ac264788a404efc9441afff5129d7 (diff)
downloadcython-74bcd91438a6f7f94417dfc4d50f572cfa50f820.tar.gz
Improve "cpdef enum" to Python conversion (GH-4877)
Return the Python enum instead of an int. Also, use flag enums as wrapper since they behave more like C enums in that they allow or-combination. The one corner-case that isn't perfect is cpdef enums declared in a standalone pxd file - they don't actually ever generate the python wrapper. I've made these emit a warning and return an int (which I think is a reasonable solution for the moment). Closes https://github.com/cython/cython/issues/2732 Closes https://github.com/cython/cython/issues/4633
-rw-r--r--Cython/Compiler/PyrexTypes.py57
-rw-r--r--Cython/Compiler/UtilityCode.py7
-rw-r--r--Cython/Utility/CpdefEnums.pyx86
-rw-r--r--tests/run/cpdef_enums.pyx17
-rw-r--r--tests/run/cpdef_enums_import.srctree13
-rw-r--r--tests/run/cpdef_scoped_enums.pyx22
-rw-r--r--tests/run/cpp_stl_conversion.pyx2
7 files changed, 195 insertions, 9 deletions
diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py
index c3b06ba51..7e2777397 100644
--- a/Cython/Compiler/PyrexTypes.py
+++ b/Cython/Compiler/PyrexTypes.py
@@ -2787,6 +2787,12 @@ class CArrayType(CPointerBaseType):
source_code, result_code, self.size)
return code.error_goto_if_neg(call_code, error_pos)
+ def error_condition(self, result_code):
+ # It isn't possible to use CArrays as return type so the error_condition
+ # is irrelevant. Returning a falsy value does avoid an error when getting
+ # from_py_call_code from a typedef.
+ return ""
+
class CPtrType(CPointerBaseType):
# base_type CType Reference type
@@ -4257,6 +4263,25 @@ class CppScopedEnumType(CType):
def create_to_py_utility_code(self, env):
if self.to_py_function is not None:
return True
+ if self.entry.create_wrapper:
+ from .UtilityCode import CythonUtilityCode
+ self.to_py_function = "__Pyx_Enum_%s_to_py" % self.name
+ if self.entry.scope != env.global_scope():
+ module_name = self.entry.scope.qualified_name
+ else:
+ module_name = None
+ env.use_utility_code(CythonUtilityCode.load(
+ "EnumTypeToPy", "CpdefEnums.pyx",
+ context={"funcname": self.to_py_function,
+ "name": self.name,
+ "items": tuple(self.values),
+ "underlying_type": self.underlying_type.empty_declaration_code(),
+ "module_name": module_name,
+ "is_flag": False,
+ },
+ outer_module_scope=self.entry.scope # ensure that "name" is findable
+ ))
+ return True
if self.underlying_type.create_to_py_utility_code(env):
# Using a C++11 lambda here, which is fine since
# scoped enums are a C++11 feature
@@ -4383,15 +4408,47 @@ class CEnumType(CIntLike, CType):
def create_type_wrapper(self, env):
from .UtilityCode import CythonUtilityCode
+ # Generate "int"-like conversion function
+ old_to_py_function = self.to_py_function
+ self.to_py_function = None
+ CIntLike.create_to_py_utility_code(self, env)
+ enum_to_pyint_func = self.to_py_function
+ self.to_py_function = old_to_py_function # we don't actually want to overwrite this
+
env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx",
context={"name": self.name,
"items": tuple(self.values),
"enum_doc": self.doc,
+ "enum_to_pyint_func": enum_to_pyint_func,
"static_modname": env.qualified_name,
},
outer_module_scope=env.global_scope()))
+ def create_to_py_utility_code(self, env):
+ if self.to_py_function is not None:
+ return self.to_py_function
+ if not self.entry.create_wrapper:
+ return super(CEnumType, self).create_to_py_utility_code(env)
+ from .UtilityCode import CythonUtilityCode
+ self.to_py_function = "__Pyx_Enum_%s_to_py" % self.name
+ if self.entry.scope != env.global_scope():
+ module_name = self.entry.scope.qualified_name
+ else:
+ module_name = None
+ env.use_utility_code(CythonUtilityCode.load(
+ "EnumTypeToPy", "CpdefEnums.pyx",
+ context={"funcname": self.to_py_function,
+ "name": self.name,
+ "items": tuple(self.values),
+ "underlying_type": "int",
+ "module_name": module_name,
+ "is_flag": True,
+ },
+ outer_module_scope=self.entry.scope # ensure that "name" is findable
+ ))
+ return True
+
class CTupleType(CType):
# components [PyrexType]
diff --git a/Cython/Compiler/UtilityCode.py b/Cython/Compiler/UtilityCode.py
index 870abf3e5..e2df2586b 100644
--- a/Cython/Compiler/UtilityCode.py
+++ b/Cython/Compiler/UtilityCode.py
@@ -173,8 +173,15 @@ class CythonUtilityCode(Code.UtilityCodeBase):
if self.context_types:
# inject types into module scope
def scope_transform(module_node):
+ dummy_entry = object()
for name, type in self.context_types.items():
+ # Restore the old type entry after declaring the type.
+ # We need to access types in the scope, but this shouldn't alter the entry
+ # that is visible from everywhere else
+ old_type_entry = getattr(type, "entry", dummy_entry)
entry = module_node.scope.declare_type(name, type, None, visibility='extern')
+ if old_type_entry is not dummy_entry:
+ type.entry = old_type_entry
entry.in_cinclude = True
return module_node
diff --git a/Cython/Utility/CpdefEnums.pyx b/Cython/Utility/CpdefEnums.pyx
index 10a0afc54..0d311e84d 100644
--- a/Cython/Utility/CpdefEnums.pyx
+++ b/Cython/Utility/CpdefEnums.pyx
@@ -44,18 +44,53 @@ class __Pyx_EnumBase(int, metaclass=__Pyx_EnumMeta):
if PY_VERSION_HEX >= 0x03040000:
from enum import IntEnum as __Pyx_EnumBase
+cdef object __Pyx_FlagBase
+class __Pyx_FlagBase(int, metaclass=__Pyx_EnumMeta):
+ def __new__(cls, value, name=None):
+ for v in cls:
+ if v == value:
+ return v
+ res = int.__new__(cls, value)
+ if name is None:
+ # some bitwise combination, no validation here
+ res.name = ""
+ else:
+ res.name = name
+ setattr(cls, name, res)
+ cls.__members__[name] = res
+ return res
+ def __repr__(self):
+ return "<%s.%s: %d>" % (self.__class__.__name__, self.name, self)
+ def __str__(self):
+ return "%s.%s" % (self.__class__.__name__, self.name)
+
+if PY_VERSION_HEX >= 0x03060000:
+ from enum import IntFlag as __Pyx_FlagBase
+
#################### EnumType ####################
#@requires: EnumBase
+cdef extern from *:
+ object {{enum_to_pyint_func}}({{name}} value)
+
cdef dict __Pyx_globals = globals()
-if PY_VERSION_HEX >= 0x03040000:
- # create new IntEnum()
- {{name}} = __Pyx_EnumBase('{{name}}', [
+if PY_VERSION_HEX >= 0x03060000:
+ # create new IntFlag() - the assumption is that C enums are sufficiently commonly
+ # used as flags that this is the most appropriate base class
+ {{name}} = __Pyx_FlagBase('{{name}}', [
{{for item in items}}
- ('{{item}}', {{item}}),
+ ('{{item}}', {{enum_to_pyint_func}}({{item}})),
{{endfor}}
# Try to look up the module name dynamically if possible
], module=__Pyx_globals.get("__module__", '{{static_modname}}'))
+
+ if PY_VERSION_HEX >= 0x030B0000:
+ # Python 3.11 starts making the behaviour of flags stricter
+ # (only including powers of 2 when iterating). Since we're using
+ # "flag" because C enums *might* be used as flags, not because
+ # we want strict flag behaviour, manually undo some of this.
+ {{name}}._member_names_ = list({{name}}.__members__)
+
{{if enum_doc is not None}}
{{name}}.__doc__ = {{ repr(enum_doc) }}
{{endif}}
@@ -64,10 +99,10 @@ if PY_VERSION_HEX >= 0x03040000:
__Pyx_globals['{{item}}'] = {{name}}.{{item}}
{{endfor}}
else:
- class {{name}}(__Pyx_EnumBase):
+ class {{name}}(__Pyx_FlagBase):
{{ repr(enum_doc) if enum_doc is not None else 'pass' }}
{{for item in items}}
- __Pyx_globals['{{item}}'] = {{name}}({{item}}, '{{item}}')
+ __Pyx_globals['{{item}}'] = {{name}}({{enum_to_pyint_func}}({{item}}), '{{item}}')
{{endfor}}
#################### CppScopedEnumType ####################
@@ -75,7 +110,6 @@ else:
cdef dict __Pyx_globals = globals()
if PY_VERSION_HEX >= 0x03040000:
- # create new IntEnum()
__Pyx_globals["{{name}}"] = __Pyx_EnumBase('{{name}}', [
{{for item in items}}
('{{item}}', <{{underlying_type}}>({{name}}.{{item}})),
@@ -90,3 +124,41 @@ else:
{{if enum_doc is not None}}
__Pyx_globals["{{name}}"].__doc__ = {{ repr(enum_doc) }}
{{endif}}
+
+
+#################### EnumTypeToPy ####################
+
+@cname("{{funcname}}")
+cdef {{funcname}}({{name}} c_val):
+ cdef object __pyx_enum
+ # There's a complication here: the Python enum wrapping is only generated
+ # for enums defined in the same module that they're used in. Therefore, if
+ # the enum was cimported from a different module, we try to import it.
+ # If that fails we return an int equivalent as the next best option.
+{{if module_name}}
+ try:
+ from {{module_name}} import {{name}} as __pyx_enum
+ except ImportError:
+ import warnings
+ warnings.warn(
+ f"enum class {{name}} not importable from {{module_name}}. "
+ "You are probably using a cpdef enum declared in a .pxd file that "
+ "does not have a .py or .pyx file.")
+ return <{{underlying_type}}>c_val
+{{else}}
+ __pyx_enum = {{name}}
+{{endif}}
+ # TODO - Cython only manages to optimize C enums to a switch currently
+ if 0:
+ pass
+{{for item in items}}
+ elif c_val == {{name}}.{{item}}:
+ return __pyx_enum.{{item}}
+{{endfor}}
+ else:
+ underlying_c_val = <{{underlying_type}}>c_val
+{{if is_flag}}
+ return __pyx_enum(underlying_c_val)
+{{else}}
+ raise ValueError(f"{underlying_c_val} is not a valid {{name}}")
+{{endif}}
diff --git a/tests/run/cpdef_enums.pyx b/tests/run/cpdef_enums.pyx
index 82c31fb95..4a7256531 100644
--- a/tests/run/cpdef_enums.pyx
+++ b/tests/run/cpdef_enums.pyx
@@ -132,13 +132,28 @@ def check_docs():
"""
pass
+
+def to_from_py_conversion(PxdEnum val):
+ """
+ >>> to_from_py_conversion(RANK_1) is PxdEnum.RANK_1
+ True
+
+ C enums are commonly enough used as flags that it seems reasonable
+ to allow it in Cython
+ >>> to_from_py_conversion(RANK_1 | RANK_2) == (RANK_1 | RANK_2)
+ True
+ """
+ return val
+
+
def test_pickle():
"""
>>> from pickle import loads, dumps
>>> import sys
Pickling enums won't work without the enum module, so disable the test
- >>> if sys.version_info < (3, 4):
+ (now requires 3.6 for IntFlag)
+ >>> if sys.version_info < (3, 6):
... loads = dumps = lambda x: x
>>> loads(dumps(PyxEnum.TWO)) == PyxEnum.TWO
True
diff --git a/tests/run/cpdef_enums_import.srctree b/tests/run/cpdef_enums_import.srctree
index 4cf0eb7e2..928a2d0b1 100644
--- a/tests/run/cpdef_enums_import.srctree
+++ b/tests/run/cpdef_enums_import.srctree
@@ -28,13 +28,23 @@ cpdef enum NamedEnumType:
cpdef foo()
+######## enums_without_pyx.pxd #####
+
+cpdef enum EnumTypeNotInPyx:
+ AnotherEnumValue = 500
+
######## no_enums.pyx ########
from enums cimport *
+from enums_without_pyx cimport *
def get_named_enum_value():
return NamedEnumType.NamedEnumValue
+def get_named_without_pyx():
+ # This'll generate a warning but return a c int
+ return EnumTypeNotInPyx.AnotherEnumValue
+
######## import_enums_test.py ########
# We can import enums with a star import.
@@ -51,3 +61,6 @@ assert 'FOO' not in dir(no_enums)
assert 'foo' not in dir(no_enums)
assert no_enums.get_named_enum_value() == NamedEnumType.NamedEnumValue
+# In this case the enum isn't accessible from Python (by design)
+# but the conversion to Python goes through a reasonable fallback
+assert no_enums.get_named_without_pyx() == 500
diff --git a/tests/run/cpdef_scoped_enums.pyx b/tests/run/cpdef_scoped_enums.pyx
index 4c6236db3..afae84e57 100644
--- a/tests/run/cpdef_scoped_enums.pyx
+++ b/tests/run/cpdef_scoped_enums.pyx
@@ -42,6 +42,28 @@ def test_enum_doc():
pass
+def to_from_py_conversion(Enum1 val):
+ """
+ >>> to_from_py_conversion(Enum1.Item1) is Enum1.Item1
+ True
+
+ Scoped enums should not be used as flags, and therefore attempts to set them
+ with arbitrary values should fail
+ >>> to_from_py_conversion(500)
+ Traceback (most recent call last):
+ ...
+ ValueError: 500 is not a valid Enum1
+
+ # Note that the ability to bitwise-or together the two numbers is inherited
+ from the Python enum (so not in Cython's remit to prevent)
+ >>> to_from_py_conversion(Enum1.Item1 | Enum1.Item2)
+ Traceback (most recent call last):
+ ...
+ ValueError: 3 is not a valid Enum1
+ """
+ return val
+
+
def test_pickle():
"""
>>> from pickle import loads, dumps
diff --git a/tests/run/cpp_stl_conversion.pyx b/tests/run/cpp_stl_conversion.pyx
index 07d56d68b..a5c418140 100644
--- a/tests/run/cpp_stl_conversion.pyx
+++ b/tests/run/cpp_stl_conversion.pyx
@@ -270,7 +270,7 @@ cpdef enum Color:
def test_enum_map(o):
"""
>>> test_enum_map({RED: GREEN})
- {0: 1}
+ {<Color.RED: 0>: <Color.GREEN: 1>}
"""
cdef map[Color, Color] m = o
return m