summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEthan Furman <ethan@stoneleaf.us>2015-03-18 18:19:30 -0700
committerEthan Furman <ethan@stoneleaf.us>2015-03-18 18:19:30 -0700
commitb1aadfb025b77d3ec50627703b5a56cc1e548a9e (patch)
tree83b09736637f97784324e6ba8a1670d9fe715fc8
parent65db4820ae5baa32d52b85cd9310863c73cb3852 (diff)
downloadcpython-b1aadfb025b77d3ec50627703b5a56cc1e548a9e.tar.gz
issue23673
add private method to enum to support replacing global constants with Enum members: - search for candidate constants via supplied filter - create new enum class and members - insert enum class and replace constants with members via supplied module name - replace __reduce_ex__ with function that returns member name, so previous Python versions can unpickle modify IntEnum classes to use new method
-rw-r--r--Lib/enum.py26
-rw-r--r--Lib/socket.py18
-rw-r--r--Lib/test/test_enum.py8
-rw-r--r--Lib/test/test_socket.py5
4 files changed, 48 insertions, 9 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index 9b19c1d34b..3cd3df8428 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -511,11 +511,37 @@ class Enum(metaclass=EnumMeta):
"""The value of the Enum member."""
return self._value_
+ @classmethod
+ def _convert(cls, name, module, filter, source=None):
+ """
+ Create a new Enum subclass that replaces a collection of global constants
+ """
+ # convert all constants from source (or module) that pass filter() to
+ # a new Enum called name, and export the enum and its members back to
+ # module;
+ # also, replace the __reduce_ex__ method so unpickling works in
+ # previous Python versions
+ module_globals = vars(sys.modules[module])
+ if source:
+ source = vars(source)
+ else:
+ source = module_globals
+ members = {name: value for name, value in source.items()
+ if filter(name)}
+ cls = cls(name, members, module=module)
+ cls.__reduce_ex__ = _reduce_ex_by_name
+ module_globals.update(cls.__members__)
+ module_globals[name] = cls
+ return cls
+
class IntEnum(int, Enum):
"""Enum where members are also (and must be) ints"""
+def _reduce_ex_by_name(self, proto):
+ return self.name
+
def unique(enumeration):
"""Class decorator for enumerations ensuring unique member values."""
duplicates = []
diff --git a/Lib/socket.py b/Lib/socket.py
index 8efd760696..004588671c 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -69,15 +69,15 @@ __all__.extend(os._get_exports_list(_socket))
# Note that _socket only knows about the integer values. The public interface
# in this module understands the enums and translates them back from integers
# where needed (e.g. .family property of a socket object).
-AddressFamily = IntEnum('AddressFamily',
- {name: value for name, value in globals().items()
- if name.isupper() and name.startswith('AF_')})
-globals().update(AddressFamily.__members__)
-
-SocketKind = IntEnum('SocketKind',
- {name: value for name, value in globals().items()
- if name.isupper() and name.startswith('SOCK_')})
-globals().update(SocketKind.__members__)
+IntEnum._convert(
+ 'AddressFamily',
+ __name__,
+ lambda C: C.isupper() and C.startswith('AF_'))
+
+IntEnum._convert(
+ 'SocketKind',
+ __name__,
+ lambda C: C.isupper() and C.startswith('SOCK_'))
def _intenum_converter(value, enum_klass):
"""Convert a numeric family value to an IntEnum member.
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index dccaa4ffaa..5db40403f3 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -581,6 +581,14 @@ class TestEnum(unittest.TestCase):
test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs,
protocol=(4, HIGHEST_PROTOCOL))
+ def test_pickle_by_name(self):
+ class ReplaceGlobalInt(IntEnum):
+ ONE = 1
+ TWO = 2
+ ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_name
+ for proto in range(HIGHEST_PROTOCOL):
+ self.assertEqual(ReplaceGlobalInt.TWO.__reduce_ex__(proto), 'TWO')
+
def test_exploding_pickle(self):
BadPickle = Enum(
'BadPickle', 'dill sweet bread-n-butter', module=__name__)
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index b412386c43..0db760f64e 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -1375,6 +1375,11 @@ class GeneralModuleTests(unittest.TestCase):
with sock:
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
self.assertRaises(TypeError, pickle.dumps, sock, protocol)
+ for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+ family = pickle.loads(pickle.dumps(socket.AF_INET, protocol))
+ self.assertEqual(family, socket.AF_INET)
+ type = pickle.loads(pickle.dumps(socket.SOCK_STREAM, protocol))
+ self.assertEqual(type, socket.SOCK_STREAM)
def test_listen_backlog(self):
for backlog in 0, -1: