summaryrefslogtreecommitdiff
path: root/Cython/Utility/CpdefEnums.pyx
blob: 39377b30ccedaf6c57e3162b4af8a944fdd50270 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#################### EnumBase ####################

cimport cython

cdef extern from *:
    int PY_VERSION_HEX

cdef object __Pyx_OrderedDict

if PY_VERSION_HEX >= 0x03060000:
    __Pyx_OrderedDict = dict
else:
    from collections import OrderedDict as __Pyx_OrderedDict

@cython.internal
cdef class __Pyx_EnumMeta(type):
    def __init__(cls, name, parents, dct):
        type.__init__(cls, name, parents, dct)
        cls.__members__ = __Pyx_OrderedDict()
    def __iter__(cls):
        return iter(cls.__members__.values())
    def __getitem__(cls, name):
        return cls.__members__[name]

# @cython.internal
cdef object __Pyx_EnumBase
class __Pyx_EnumBase(int, metaclass=__Pyx_EnumMeta):
    def __new__(cls, value, name=None):
        for v in cls:
            if v == value:
                return v
        if name is None:
            raise ValueError("Unknown enum value: '%s'" % value)
        res = int.__new__(cls, value)
        res.name = name
        setattr(cls, name, res)
        cls.__members__[name] = res
        return res
    def __repr__(self):
        return "<%s.%s: %d>" % (self.__class__.__name__, self.name, self)
    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self.name)

if PY_VERSION_HEX >= 0x03040000:
    from enum import IntEnum as __Pyx_EnumBase

cdef object __Pyx_FlagBase
class __Pyx_FlagBase(int, metaclass=__Pyx_EnumMeta):
    def __new__(cls, value, name=None):
        for v in cls:
            if v == value:
                return v
        res = int.__new__(cls, value)
        if name is None:
            # some bitwise combination, no validation here
            res.name = ""
        else:
            res.name = name
            setattr(cls, name, res)
            cls.__members__[name] = res
        return res
    def __repr__(self):
        return "<%s.%s: %d>" % (self.__class__.__name__, self.name, self)
    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self.name)

if PY_VERSION_HEX >= 0x03060000:
    from enum import IntFlag as __Pyx_FlagBase

#################### EnumType ####################
#@requires: EnumBase

cdef extern from *:
    object {{enum_to_pyint_func}}({{name}} value)

cdef dict __Pyx_globals = globals()
if PY_VERSION_HEX >= 0x03060000:
    # create new IntFlag() - the assumption is that C enums are sufficiently commonly
    # used as flags that this is the most appropriate base class
    {{name}} = __Pyx_FlagBase('{{name}}', [
        {{for item in items}}
        ('{{item}}', {{enum_to_pyint_func}}({{item}})),
        {{endfor}}
        # Try to look up the module name dynamically if possible
    ], module=__Pyx_globals.get("__module__", '{{static_modname}}'))

    if PY_VERSION_HEX >= 0x030B0000:
        # Python 3.11 starts making the behaviour of flags stricter
        # (only including powers of 2 when iterating). Since we're using
        # "flag" because C enums *might* be used as flags, not because
        # we want strict flag behaviour, manually undo some of this.
        {{name}}._member_names_ = list({{name}}.__members__)

    {{if enum_doc is not None}}
    {{name}}.__doc__ = {{ repr(enum_doc) }}
    {{endif}}

    {{for item in items}}
    __Pyx_globals['{{item}}'] = {{name}}.{{item}}
    {{endfor}}
else:
    class {{name}}(__Pyx_FlagBase):
        {{ repr(enum_doc) if enum_doc is not None else 'pass' }}
    {{for item in items}}
    __Pyx_globals['{{item}}'] = {{name}}({{enum_to_pyint_func}}({{item}}), '{{item}}')
    {{endfor}}

#################### CppScopedEnumType ####################
#@requires: EnumBase
cdef dict __Pyx_globals = globals()

if PY_VERSION_HEX >= 0x03040000:
    __Pyx_globals["{{name}}"] = __Pyx_EnumBase('{{name}}', [
        {{for item in items}}
        ('{{item}}', <{{underlying_type}}>({{name}}.{{item}})),
        {{endfor}}
    ], module=__Pyx_globals.get("__module__", '{{static_modname}}'))
else:
    __Pyx_globals["{{name}}"] = type('{{name}}', (__Pyx_EnumBase,), {})
    {{for item in items}}
    __Pyx_globals["{{name}}"](<{{underlying_type}}>({{name}}.{{item}}), '{{item}}')
    {{endfor}}

{{if enum_doc is not None}}
__Pyx_globals["{{name}}"].__doc__ = {{ repr(enum_doc) }}
{{endif}}


#################### EnumTypeToPy ####################

@cname("{{funcname}}")
cdef {{funcname}}({{name}} c_val):
    cdef object __pyx_enum
    # There's a complication here: the Python enum wrapping is only generated
    # for enums defined in the same module that they're used in. Therefore, if
    # the enum was cimported from a different module, we try to import it.
    # If that fails we return an int equivalent as the next best option.
{{if module_name}}
    try:
        from {{module_name}} import {{name}} as __pyx_enum
    except ImportError:
        import warnings
        warnings.warn(
            f"enum class {{name}} not importable from {{module_name}}. "
            "You are probably using a cpdef enum declared in a .pxd file that "
            "does not have a .py  or .pyx file.")
        return <{{underlying_type}}>c_val
{{else}}
    __pyx_enum = {{name}}
{{endif}}
    # TODO - Cython only manages to optimize C enums to a switch currently
    if 0:
        pass
{{for item in items}}
    elif c_val == {{name}}.{{item}}:
        return __pyx_enum.{{item}}
{{endfor}}
    else:
        underlying_c_val = <{{underlying_type}}>c_val
{{if is_flag}}
        return __pyx_enum(underlying_c_val)
{{else}}
        raise ValueError(f"{underlying_c_val} is not a valid {{name}}")
{{endif}}