summaryrefslogtreecommitdiff
path: root/Lib/enum.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/enum.py')
-rw-r--r--Lib/enum.py335
1 files changed, 318 insertions, 17 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index b8787d19b8..056400d04c 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -1,8 +1,20 @@
import sys
-from collections import OrderedDict
from types import MappingProxyType, DynamicClassAttribute
+from functools import reduce
+from operator import or_ as _or_
-__all__ = ['Enum', 'IntEnum', 'unique']
+# try _collections first to reduce startup cost
+try:
+ from _collections import OrderedDict
+except ImportError:
+ from collections import OrderedDict
+
+
+__all__ = [
+ 'EnumMeta',
+ 'Enum', 'IntEnum', 'Flag', 'IntFlag',
+ 'auto', 'unique',
+ ]
def _is_descriptor(obj):
@@ -28,7 +40,6 @@ def _is_sunder(name):
name[-2:-1] != '_' and
len(name) > 2)
-
def _make_class_unpicklable(cls):
"""Make the given class un-picklable."""
def _break_on_call_reduce(self, proto):
@@ -36,6 +47,13 @@ def _make_class_unpicklable(cls):
cls.__reduce_ex__ = _break_on_call_reduce
cls.__module__ = '<unknown>'
+_auto_null = object()
+class auto:
+ """
+ Instances are replaced with an appropriate value in Enum class suites.
+ """
+ value = _auto_null
+
class _EnumDict(dict):
"""Track enum member order and ensure member names are not reused.
@@ -47,6 +65,7 @@ class _EnumDict(dict):
def __init__(self):
super().__init__()
self._member_names = []
+ self._last_values = []
def __setitem__(self, key, value):
"""Changes anything not dundered or not a descriptor.
@@ -58,21 +77,32 @@ class _EnumDict(dict):
"""
if _is_sunder(key):
- raise ValueError('_names_ are reserved for future Enum use')
+ if key not in (
+ '_order_', '_create_pseudo_member_',
+ '_generate_next_value_', '_missing_',
+ ):
+ raise ValueError('_names_ are reserved for future Enum use')
+ if key == '_generate_next_value_':
+ setattr(self, '_generate_next_value', value)
elif _is_dunder(key):
- pass
+ if key == '__order__':
+ key = '_order_'
elif key in self._member_names:
# descriptor overwriting an enum?
raise TypeError('Attempted to reuse key: %r' % key)
elif not _is_descriptor(value):
if key in self:
# enum overwriting a descriptor?
- raise TypeError('Key already defined as: %r' % self[key])
+ raise TypeError('%r already defined as: %r' % (key, self[key]))
+ if isinstance(value, auto):
+ if value.value == _auto_null:
+ value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:])
+ value = value.value
self._member_names.append(key)
+ self._last_values.append(value)
super().__setitem__(key, value)
-
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course
# until EnumMeta finishes running the first time the Enum class doesn't exist.
# This is also why there are checks in EnumMeta like `if Enum is not None`
@@ -83,7 +113,13 @@ class EnumMeta(type):
"""Metaclass for Enum"""
@classmethod
def __prepare__(metacls, cls, bases):
- return _EnumDict()
+ # create the namespace dict
+ enum_dict = _EnumDict()
+ # inherit previous flags and _generate_next_value_ function
+ member_type, first_enum = metacls._get_mixins_(bases)
+ if first_enum is not None:
+ enum_dict['_generate_next_value_'] = getattr(first_enum, '_generate_next_value_', None)
+ return enum_dict
def __new__(metacls, cls, bases, classdict):
# an Enum class is final once enumeration items have been defined; it
@@ -96,12 +132,15 @@ class EnumMeta(type):
# save enum items into separate mapping so they don't get baked into
# the new class
- members = {k: classdict[k] for k in classdict._member_names}
+ enum_members = {k: classdict[k] for k in classdict._member_names}
for name in classdict._member_names:
del classdict[name]
+ # adjust the sunders
+ _order_ = classdict.pop('_order_', None)
+
# check for illegal enum names (any others?)
- invalid_names = set(members) & {'mro', }
+ invalid_names = set(enum_members) & {'mro', }
if invalid_names:
raise ValueError('Invalid enum member name: {0}'.format(
','.join(invalid_names)))
@@ -145,7 +184,7 @@ class EnumMeta(type):
# a custom __new__ is doing something funky with the values -- such as
# auto-numbering ;)
for member_name in classdict._member_names:
- value = members[member_name]
+ value = enum_members[member_name]
if not isinstance(value, tuple):
args = (value, )
else:
@@ -159,7 +198,10 @@ class EnumMeta(type):
else:
enum_member = __new__(enum_class, *args)
if not hasattr(enum_member, '_value_'):
- enum_member._value_ = member_type(*args)
+ if member_type is object:
+ enum_member._value_ = value
+ else:
+ enum_member._value_ = member_type(*args)
value = enum_member._value_
enum_member._name_ = member_name
enum_member.__objclass__ = enum_class
@@ -204,6 +246,14 @@ class EnumMeta(type):
if save_new:
enum_class.__new_member__ = __new__
enum_class.__new__ = Enum.__new__
+
+ # py3 support for definition order (helps keep py2/py3 code in sync)
+ if _order_ is not None:
+ if isinstance(_order_, str):
+ _order_ = _order_.replace(',', ' ').split()
+ if _order_ != enum_class._member_names_:
+ raise TypeError('member order does not match _order_')
+
return enum_class
def __bool__(self):
@@ -217,7 +267,7 @@ class EnumMeta(type):
This method is used both when an enum class is given a value to match
to an enumeration member (i.e. Color(3)) and for the functional API
- (i.e. Color = Enum('Color', names='red green blue')).
+ (i.e. Color = Enum('Color', names='RED GREEN BLUE')).
When used for the functional API:
@@ -325,13 +375,19 @@ class EnumMeta(type):
"""
metacls = cls.__class__
bases = (cls, ) if type is None else (type, cls)
+ _, first_enum = cls._get_mixins_(bases)
classdict = metacls.__prepare__(class_name, bases)
# special processing needed for names?
if isinstance(names, str):
names = names.replace(',', ' ').split()
if isinstance(names, (tuple, list)) and isinstance(names[0], str):
- names = [(e, i) for (i, e) in enumerate(names, start)]
+ original_names, names = names, []
+ last_values = []
+ for count, name in enumerate(original_names):
+ value = first_enum._generate_next_value_(name, start, count, last_values[:])
+ last_values.append(value)
+ names.append((name, value))
# Here, names is either an iterable of (name, value) or a mapping.
for item in names:
@@ -461,7 +517,7 @@ class Enum(metaclass=EnumMeta):
# without calling this method; this method is called by the metaclass'
# __call__ (i.e. Color(3) ), and by pickle
if type(value) is cls:
- # For lookups like Color(Color.red)
+ # For lookups like Color(Color.RED)
return value
# by-value search for a matching enum member
# see if it's in the reverse mapping (for hashable values)
@@ -473,6 +529,20 @@ class Enum(metaclass=EnumMeta):
for member in cls._member_map_.values():
if member._value_ == value:
return member
+ # still not found -- try _missing_ hook
+ return cls._missing_(value)
+
+ def _generate_next_value_(name, start, count, last_values):
+ for last_value in reversed(last_values):
+ try:
+ return last_value + 1
+ except TypeError:
+ pass
+ else:
+ return start
+
+ @classmethod
+ def _missing_(cls, value):
raise ValueError("%r is not a valid %s" % (value, cls.__name__))
def __repr__(self):
@@ -544,8 +614,21 @@ class Enum(metaclass=EnumMeta):
source = vars(source)
else:
source = module_globals
- members = {name: value for name, value in source.items()
- if filter(name)}
+ # We use an OrderedDict of sorted source keys so that the
+ # _value2member_map is populated in the same order every time
+ # for a consistent reverse mapping of number to name when there
+ # are multiple names for the same number rather than varying
+ # between runs due to hash randomization of the module dictionary.
+ members = [
+ (name, source[name])
+ for name in source.keys()
+ if filter(name)]
+ try:
+ # sort by value
+ members.sort(key=lambda t: (t[1], t[0]))
+ except TypeError:
+ # unless some values aren't comparable, in which case sort by name
+ members.sort(key=lambda t: t[0])
cls = cls(name, members, module=module)
cls.__reduce_ex__ = _reduce_ex_by_name
module_globals.update(cls.__members__)
@@ -560,6 +643,184 @@ class IntEnum(int, Enum):
def _reduce_ex_by_name(self, proto):
return self.name
+class Flag(Enum):
+ """Support for flags"""
+
+ def _generate_next_value_(name, start, count, last_values):
+ """
+ Generate the next value when not given.
+
+ name: the name of the member
+ start: the initital start value or None
+ count: the number of existing members
+ last_value: the last value assigned or None
+ """
+ if not count:
+ return start if start is not None else 1
+ for last_value in reversed(last_values):
+ try:
+ high_bit = _high_bit(last_value)
+ break
+ except Exception:
+ raise TypeError('Invalid Flag value: %r' % last_value) from None
+ return 2 ** (high_bit+1)
+
+ @classmethod
+ def _missing_(cls, value):
+ original_value = value
+ if value < 0:
+ value = ~value
+ possible_member = cls._create_pseudo_member_(value)
+ if original_value < 0:
+ possible_member = ~possible_member
+ return possible_member
+
+ @classmethod
+ def _create_pseudo_member_(cls, value):
+ """
+ Create a composite member iff value contains only members.
+ """
+ pseudo_member = cls._value2member_map_.get(value, None)
+ if pseudo_member is None:
+ # verify all bits are accounted for
+ _, extra_flags = _decompose(cls, value)
+ if extra_flags:
+ raise ValueError("%r is not a valid %s" % (value, cls.__name__))
+ # construct a singleton enum pseudo-member
+ pseudo_member = object.__new__(cls)
+ pseudo_member._name_ = None
+ pseudo_member._value_ = value
+ # use setdefault in case another thread already created a composite
+ # with this value
+ pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
+ return pseudo_member
+
+ def __contains__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other._value_ & self._value_ == other._value_
+
+ def __repr__(self):
+ cls = self.__class__
+ if self._name_ is not None:
+ return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_)
+ members, uncovered = _decompose(cls, self._value_)
+ return '<%s.%s: %r>' % (
+ cls.__name__,
+ '|'.join([str(m._name_ or m._value_) for m in members]),
+ self._value_,
+ )
+
+ def __str__(self):
+ cls = self.__class__
+ if self._name_ is not None:
+ return '%s.%s' % (cls.__name__, self._name_)
+ members, uncovered = _decompose(cls, self._value_)
+ if len(members) == 1 and members[0]._name_ is None:
+ return '%s.%r' % (cls.__name__, members[0]._value_)
+ else:
+ return '%s.%s' % (
+ cls.__name__,
+ '|'.join([str(m._name_ or m._value_) for m in members]),
+ )
+
+ def __bool__(self):
+ return bool(self._value_)
+
+ def __or__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return self.__class__(self._value_ | other._value_)
+
+ def __and__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return self.__class__(self._value_ & other._value_)
+
+ def __xor__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return self.__class__(self._value_ ^ other._value_)
+
+ def __invert__(self):
+ members, uncovered = _decompose(self.__class__, self._value_)
+ inverted_members = [
+ m for m in self.__class__
+ if m not in members and not m._value_ & self._value_
+ ]
+ inverted = reduce(_or_, inverted_members, self.__class__(0))
+ return self.__class__(inverted)
+
+
+class IntFlag(int, Flag):
+ """Support for integer-based Flags"""
+
+ @classmethod
+ def _missing_(cls, value):
+ if not isinstance(value, int):
+ raise ValueError("%r is not a valid %s" % (value, cls.__name__))
+ new_member = cls._create_pseudo_member_(value)
+ return new_member
+
+ @classmethod
+ def _create_pseudo_member_(cls, value):
+ pseudo_member = cls._value2member_map_.get(value, None)
+ if pseudo_member is None:
+ need_to_create = [value]
+ # get unaccounted for bits
+ _, extra_flags = _decompose(cls, value)
+ # timer = 10
+ while extra_flags:
+ # timer -= 1
+ bit = _high_bit(extra_flags)
+ flag_value = 2 ** bit
+ if (flag_value not in cls._value2member_map_ and
+ flag_value not in need_to_create
+ ):
+ need_to_create.append(flag_value)
+ if extra_flags == -flag_value:
+ extra_flags = 0
+ else:
+ extra_flags ^= flag_value
+ for value in reversed(need_to_create):
+ # construct singleton pseudo-members
+ pseudo_member = int.__new__(cls, value)
+ pseudo_member._name_ = None
+ pseudo_member._value_ = value
+ # use setdefault in case another thread already created a composite
+ # with this value
+ pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
+ return pseudo_member
+
+ def __or__(self, other):
+ if not isinstance(other, (self.__class__, int)):
+ return NotImplemented
+ result = self.__class__(self._value_ | self.__class__(other)._value_)
+ return result
+
+ def __and__(self, other):
+ if not isinstance(other, (self.__class__, int)):
+ return NotImplemented
+ return self.__class__(self._value_ & self.__class__(other)._value_)
+
+ def __xor__(self, other):
+ if not isinstance(other, (self.__class__, int)):
+ return NotImplemented
+ return self.__class__(self._value_ ^ self.__class__(other)._value_)
+
+ __ror__ = __or__
+ __rand__ = __and__
+ __rxor__ = __xor__
+
+ def __invert__(self):
+ result = self.__class__(~self._value_)
+ return result
+
+
+def _high_bit(value):
+ """returns index of highest bit, or -1 if value is zero or negative"""
+ return value.bit_length() - 1
+
def unique(enumeration):
"""Class decorator for enumerations ensuring unique member values."""
duplicates = []
@@ -572,3 +833,43 @@ def unique(enumeration):
raise ValueError('duplicate values found in %r: %s' %
(enumeration, alias_details))
return enumeration
+
+def _decompose(flag, value):
+ """Extract all members from the value."""
+ # _decompose is only called if the value is not named
+ not_covered = value
+ negative = value < 0
+ # issue29167: wrap accesses to _value2member_map_ in a list to avoid race
+ # conditions between iterating over it and having more psuedo-
+ # members added to it
+ if negative:
+ # only check for named flags
+ flags_to_check = [
+ (m, v)
+ for v, m in list(flag._value2member_map_.items())
+ if m.name is not None
+ ]
+ else:
+ # check for named flags and powers-of-two flags
+ flags_to_check = [
+ (m, v)
+ for v, m in list(flag._value2member_map_.items())
+ if m.name is not None or _power_of_two(v)
+ ]
+ members = []
+ for member, member_value in flags_to_check:
+ if member_value and member_value & value == member_value:
+ members.append(member)
+ not_covered &= ~member_value
+ if not members and value in flag._value2member_map_:
+ members.append(flag._value2member_map_[value])
+ members.sort(key=lambda m: m._value_, reverse=True)
+ if len(members) > 1 and members[0].value == value:
+ # we have the breakdown, don't need the value member itself
+ members.pop(0)
+ return members, not_covered
+
+def _power_of_two(value):
+ if value < 1:
+ return False
+ return value == 2 ** _high_bit(value)