summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aiogreen.py60
-rw-r--r--tests/test_thread.py9
2 files changed, 24 insertions, 45 deletions
diff --git a/aiogreen.py b/aiogreen.py
index ba3c92f..829446b 100644
--- a/aiogreen.py
+++ b/aiogreen.py
@@ -6,6 +6,7 @@ import eventlet.hubs.hub
import functools
import heapq
socket = eventlet.patcher.original('socket')
+threading = eventlet.patcher.original('threading')
try:
# Python 2
import Queue as queue
@@ -20,16 +21,6 @@ try:
_FUTURE_CLASSES = (asyncio.Future,)
- if eventlet.patcher.is_monkey_patched('socket'):
- # asyncio must use call original functions socket.socket()
- # and socket.socketpair()
- asyncio.base_events.socket = socket
- if sys.platform == 'win32':
- asyncio.windows_events.socket = socket
- asyncio.windows_utils.socket = socket
- else:
- asyncio.unix_events.socket = socket
-
if sys.platform == 'win32':
from asyncio.windows_utils import socketpair
else:
@@ -48,23 +39,24 @@ except ImportError:
# Trollius >= 1.0.1
_FUTURE_CLASSES = asyncio.futures._FUTURE_CLASSES
- if eventlet.patcher.is_monkey_patched('socket'):
- # trollius must use call original functions socket.socket()
- # and socket.socketpair()
- asyncio.base_events.socket = socket
- if sys.platform == 'win32':
- asyncio.windows_events.socket = socket
- asyncio.windows_utils.socket = socket
- else:
- asyncio.unix_events.socket = socket
- # FIXME: patch also trollius.py3_ssl
-
if sys.platform == 'win32':
from trollius.windows_utils import socketpair
else:
socketpair = socket.socketpair
-threading = eventlet.patcher.original('threading')
+if eventlet.patcher.is_monkey_patched('socket'):
+ # trollius must use call original socket and threading functions.
+ # Examples: socket.socket(), socket.socketpair(),
+ # threading.current_thread().
+ asyncio.base_events.socket = socket
+ asyncio.events.threading = threading
+ if sys.platform == 'win32':
+ asyncio.windows_events.socket = socket
+ asyncio.windows_utils.socket = socket
+ else:
+ asyncio.unix_events.socket = socket
+ asyncio.unix_events.threading = threading
+ # FIXME: patch also trollius.py3_ssl
_READ = eventlet.hubs.hub.READ
_WRITE = eventlet.hubs.hub.WRITE
@@ -85,26 +77,6 @@ def _is_main_thread():
return isinstance(threading.current_thread(), threading._MainThread)
-class EventLoopPolicy(asyncio.AbstractEventLoopPolicy):
- def __init__(self):
- self._loop = None
-
- def get_event_loop(self):
- if not _is_main_thread():
- return None
- if self._loop is None:
- self._loop = EventLoop()
- return self._loop
-
- def new_event_loop(self):
- return EventLoop()
-
- def set_event_loop(self, loop):
- if not _is_main_thread():
- raise NotImplementedError("aiogreen can only run in the main thread")
- self._loop = loop
-
-
class SocketTransport(selector_events._SelectorSocketTransport):
def __repr__(self):
# override repr because _SelectorSocketTransport depends on
@@ -411,3 +383,7 @@ class EventLoop(base_events.BaseEventLoop):
def _make_socket_transport(self, sock, protocol, waiter=None,
extra=None, server=None):
return SocketTransport(self, sock, protocol, waiter, extra, server)
+
+
+class EventLoopPolicy(asyncio.DefaultEventLoopPolicy):
+ _loop_factory = EventLoop
diff --git a/tests/test_thread.py b/tests/test_thread.py
index f6e679d..87714ba 100644
--- a/tests/test_thread.py
+++ b/tests/test_thread.py
@@ -44,15 +44,18 @@ class ThreadTests(tests.TestCase):
self.assertEqual(result, ["run", "run"])
def test_policy(self):
- result = {'loop': 42} # sentinel, different than None
+ result = {'loop': 'not set'} # sentinel, different than None
def work():
- result['loop'] = asyncio.get_event_loop()
+ try:
+ result['loop'] = asyncio.get_event_loop()
+ except AssertionError as exc:
+ result['loop'] = exc
# get_event_loop() must return None in a different thread
fut = self.loop.run_in_executor(None, work)
self.loop.run_until_complete(fut)
- self.assertIsNone(result['loop'])
+ self.assertIsInstance(result['loop'], AssertionError)
if __name__ == '__main__':