diff options
-rw-r--r-- | aiogreen.py | 60 | ||||
-rw-r--r-- | tests/test_thread.py | 9 |
2 files changed, 24 insertions, 45 deletions
diff --git a/aiogreen.py b/aiogreen.py index ba3c92f..829446b 100644 --- a/aiogreen.py +++ b/aiogreen.py @@ -6,6 +6,7 @@ import eventlet.hubs.hub import functools import heapq socket = eventlet.patcher.original('socket') +threading = eventlet.patcher.original('threading') try: # Python 2 import Queue as queue @@ -20,16 +21,6 @@ try: _FUTURE_CLASSES = (asyncio.Future,) - if eventlet.patcher.is_monkey_patched('socket'): - # asyncio must use call original functions socket.socket() - # and socket.socketpair() - asyncio.base_events.socket = socket - if sys.platform == 'win32': - asyncio.windows_events.socket = socket - asyncio.windows_utils.socket = socket - else: - asyncio.unix_events.socket = socket - if sys.platform == 'win32': from asyncio.windows_utils import socketpair else: @@ -48,23 +39,24 @@ except ImportError: # Trollius >= 1.0.1 _FUTURE_CLASSES = asyncio.futures._FUTURE_CLASSES - if eventlet.patcher.is_monkey_patched('socket'): - # trollius must use call original functions socket.socket() - # and socket.socketpair() - asyncio.base_events.socket = socket - if sys.platform == 'win32': - asyncio.windows_events.socket = socket - asyncio.windows_utils.socket = socket - else: - asyncio.unix_events.socket = socket - # FIXME: patch also trollius.py3_ssl - if sys.platform == 'win32': from trollius.windows_utils import socketpair else: socketpair = socket.socketpair -threading = eventlet.patcher.original('threading') +if eventlet.patcher.is_monkey_patched('socket'): + # trollius must use call original socket and threading functions. + # Examples: socket.socket(), socket.socketpair(), + # threading.current_thread(). + asyncio.base_events.socket = socket + asyncio.events.threading = threading + if sys.platform == 'win32': + asyncio.windows_events.socket = socket + asyncio.windows_utils.socket = socket + else: + asyncio.unix_events.socket = socket + asyncio.unix_events.threading = threading + # FIXME: patch also trollius.py3_ssl _READ = eventlet.hubs.hub.READ _WRITE = eventlet.hubs.hub.WRITE @@ -85,26 +77,6 @@ def _is_main_thread(): return isinstance(threading.current_thread(), threading._MainThread) -class EventLoopPolicy(asyncio.AbstractEventLoopPolicy): - def __init__(self): - self._loop = None - - def get_event_loop(self): - if not _is_main_thread(): - return None - if self._loop is None: - self._loop = EventLoop() - return self._loop - - def new_event_loop(self): - return EventLoop() - - def set_event_loop(self, loop): - if not _is_main_thread(): - raise NotImplementedError("aiogreen can only run in the main thread") - self._loop = loop - - class SocketTransport(selector_events._SelectorSocketTransport): def __repr__(self): # override repr because _SelectorSocketTransport depends on @@ -411,3 +383,7 @@ class EventLoop(base_events.BaseEventLoop): def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): return SocketTransport(self, sock, protocol, waiter, extra, server) + + +class EventLoopPolicy(asyncio.DefaultEventLoopPolicy): + _loop_factory = EventLoop diff --git a/tests/test_thread.py b/tests/test_thread.py index f6e679d..87714ba 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -44,15 +44,18 @@ class ThreadTests(tests.TestCase): self.assertEqual(result, ["run", "run"]) def test_policy(self): - result = {'loop': 42} # sentinel, different than None + result = {'loop': 'not set'} # sentinel, different than None def work(): - result['loop'] = asyncio.get_event_loop() + try: + result['loop'] = asyncio.get_event_loop() + except AssertionError as exc: + result['loop'] = exc # get_event_loop() must return None in a different thread fut = self.loop.run_in_executor(None, work) self.loop.run_until_complete(fut) - self.assertIsNone(result['loop']) + self.assertIsInstance(result['loop'], AssertionError) if __name__ == '__main__': |