From 14778962fafb44bdd1c2ce1a163d197225a6508a Mon Sep 17 00:00:00 2001 From: Chris McDonough Date: Fri, 31 Aug 2018 21:47:03 -0400 Subject: Vendor asyncore into waitress as waitress.wasyncore. (#199) Waitress has now "vendored" asyncore into itself as ``waitress.wasyncore``. This is to cope with the eventuality that asyncore will be removed from the Python standard library in 3.8 or so. --- CHANGES.txt | 6 + waitress/channel.py | 24 +- waitress/compat.py | 33 + waitress/server.py | 28 +- waitress/tests/test_server.py | 32 +- waitress/tests/test_trigger.py | 12 +- waitress/tests/test_utilities.py | 20 - waitress/tests/test_wasyncore.py | 1660 ++++++++++++++++++++++++++++++++++++++ waitress/trigger.py | 24 +- waitress/utilities.py | 11 - waitress/wasyncore.py | 664 +++++++++++++++ 11 files changed, 2437 insertions(+), 77 deletions(-) create mode 100644 waitress/tests/test_wasyncore.py create mode 100644 waitress/wasyncore.py diff --git a/CHANGES.txt b/CHANGES.txt index 184473a..e9e5fdb 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -6,7 +6,13 @@ Features - Server header can be omitted by specifying `ident=None` or `ident=''`. See https://github.com/Pylons/waitress/pull/187 + +Compatibility +~~~~~~~~~~~~~ +- Waitress has now "vendored" asyncore into itself as ``waitress.wasyncore``. + This is to cope with the eventuality that asyncore will be removed from + the Python standard library in 3.8 or so. 1.1.0 (2017-10-10) ------------------ diff --git a/waitress/channel.py b/waitress/channel.py index ca02511..7ed3461 100644 --- a/waitress/channel.py +++ b/waitress/channel.py @@ -11,7 +11,6 @@ # FOR A PARTICULAR PURPOSE. # ############################################################################## -import asyncore import socket import threading import time @@ -29,12 +28,11 @@ from waitress.task import ( WSGITask, ) -from waitress.utilities import ( - logging_dispatcher, - InternalServerError, -) +from waitress.utilities import InternalServerError + +from . import wasyncore -class HTTPChannel(logging_dispatcher, object): +class HTTPChannel(wasyncore.dispatcher, object): """ Setting self.requests = [somerequest] prevents more requests from being received until the out buffers have been flushed. @@ -76,9 +74,9 @@ class HTTPChannel(logging_dispatcher, object): # outbuf_lock used to access any outbuf self.outbuf_lock = threading.Lock() - asyncore.dispatcher.__init__(self, sock, map=map) + wasyncore.dispatcher.__init__(self, sock, map=map) - # Don't let asyncore.dispatcher throttle self.addr on us. + # Don't let wasyncore.dispatcher throttle self.addr on us. self.addr = addr def any_outbuf_has_data(self): @@ -281,23 +279,23 @@ class HTTPChannel(logging_dispatcher, object): self.logger.exception( 'Unknown exception while trying to close outbuf') self.connected = False - asyncore.dispatcher.close(self) + wasyncore.dispatcher.close(self) def add_channel(self, map=None): - """See asyncore.dispatcher + """See wasyncore.dispatcher This hook keeps track of opened channels. """ - asyncore.dispatcher.add_channel(self, map) + wasyncore.dispatcher.add_channel(self, map) self.server.active_channels[self._fileno] = self def del_channel(self, map=None): - """See asyncore.dispatcher + """See wasyncore.dispatcher This hook keeps track of closed channels. """ fd = self._fileno # next line sets this to None - asyncore.dispatcher.del_channel(self, map) + wasyncore.dispatcher.del_channel(self, map) ac = self.server.active_channels if fd in ac: del ac[fd] diff --git a/waitress/compat.py b/waitress/compat.py index 700f7a1..c758b89 100644 --- a/waitress/compat.py +++ b/waitress/compat.py @@ -1,3 +1,4 @@ +import os import sys import types import platform @@ -8,6 +9,11 @@ try: except ImportError: # pragma: no cover from urllib import parse as urlparse +try: + import fcntl +except ImportError: # pragma: no cover + fcntl = None # windows + # True if we are running on Python 3. PY2 = sys.version_info[0] == 2 PY3 = sys.version_info[0] == 3 @@ -138,3 +144,30 @@ else: # pragma: no cover RuntimeWarning ) HAS_IPV6 = False + +def set_nonblocking(fd): # pragma: no cover + if PY3 and sys.version_info[1] >= 5: + os.set_blocking(fd, False) + elif fcntl is None: + raise RuntimeError('no fcntl module present') + else: + flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + +if PY3: + ResourceWarning = ResourceWarning +else: + ResourceWarning = UserWarning + +def qualname(cls): + if PY3: + return cls.__qualname__ + return cls.__name__ + +try: + import thread +except ImportError: + # py3 + import _thread as thread + diff --git a/waitress/server.py b/waitress/server.py index 79aa9b7..7175c64 100644 --- a/waitress/server.py +++ b/waitress/server.py @@ -12,7 +12,6 @@ # ############################################################################## -import asyncore import os import os.path import socket @@ -22,14 +21,13 @@ from waitress import trigger from waitress.adjustments import Adjustments from waitress.channel import HTTPChannel from waitress.task import ThreadedTaskDispatcher -from waitress.utilities import ( - cleanup_unix_socket, - logging_dispatcher, - ) +from waitress.utilities import cleanup_unix_socket + from waitress.compat import ( IPPROTO_IPV6, IPV6_V6ONLY, ) +from . import wasyncore def create_server(application, map=None, @@ -98,10 +96,10 @@ def create_server(application, # This class is only ever used if we have multiple listen sockets. It allows -# the serve() API to call .run() which starts the asyncore loop, and catches +# the serve() API to call .run() which starts the wasyncore loop, and catches # SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. class MultiSocketServer(object): - asyncore = asyncore # test shim + asyncore = wasyncore # test shim def __init__(self, map=None, @@ -131,15 +129,19 @@ class MultiSocketServer(object): use_poll=self.adj.asyncore_use_poll, ) except (SystemExit, KeyboardInterrupt): - self.task_dispatcher.shutdown() + self.close() + def close(self): + self.task_dispatcher.shutdown() + wasyncore.close_all(self.map) -class BaseWSGIServer(logging_dispatcher, object): + +class BaseWSGIServer(wasyncore.dispatcher, object): channel_class = HTTPChannel next_channel_cleanup = 0 socketmod = socket # test shim - asyncore = asyncore # test shim + asyncore = wasyncore # test shim def __init__(self, application, @@ -155,7 +157,7 @@ class BaseWSGIServer(logging_dispatcher, object): adj = Adjustments(**kw) if map is None: # use a nonglobal socket map by default to hopefully prevent - # conflicts with apps and libs that use the asyncore global socket + # conflicts with apps and libs that use the wasyncore global socket # map ala https://github.com/Pylons/waitress/issues/63 map = {} if sockinfo is None: @@ -286,6 +288,10 @@ class BaseWSGIServer(logging_dispatcher, object): def print_listen(self, format_str): # pragma: nocover print(format_str.format(self.effective_host, self.effective_port)) + def close(self): + self.trigger.close() + return wasyncore.dispatcher.close(self) + class TcpWSGIServer(BaseWSGIServer): diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py index 39b90b3..76eade5 100644 --- a/waitress/tests/test_server.py +++ b/waitress/tests/test_server.py @@ -10,7 +10,7 @@ class TestWSGIServer(unittest.TestCase): _dispatcher=None, adj=None, map=None, _start=True, _sock=None, _server=None): from waitress.server import create_server - return create_server( + self.inst = create_server( application, host=host, port=port, @@ -18,6 +18,7 @@ class TestWSGIServer(unittest.TestCase): _dispatcher=_dispatcher, _start=_start, _sock=_sock) + return self.inst def _makeOneWithMap(self, adj=None, _start=True, host='127.0.0.1', port=0, app=dummy_app): @@ -40,15 +41,21 @@ class TestWSGIServer(unittest.TestCase): task_dispatcher = DummyTaskDispatcher() map = {} from waitress.server import create_server - return create_server( + self.inst = create_server( app, listen=listen, map=map, _dispatcher=task_dispatcher, _start=_start, _sock=sock) + return self.inst + + def tearDown(self): + if self.inst is not None: + self.inst.close() def test_ctor_app_is_None(self): + self.inst = None self.assertRaises(ValueError, self._makeOneWithMap, app=None) def test_ctor_start_true(self): @@ -105,6 +112,7 @@ class TestWSGIServer(unittest.TestCase): def test_pull_trigger(self): inst = self._makeOneWithMap(_start=False) + inst.trigger.close() inst.trigger = DummyTrigger() inst.pull_trigger() self.assertEqual(inst.trigger.pulled, True) @@ -215,10 +223,10 @@ class TestWSGIServer(unittest.TestCase): from waitress.server import WSGIServer, TcpWSGIServer from waitress.adjustments import Adjustments self.assertTrue(WSGIServer is TcpWSGIServer) - inst = WSGIServer(None, _start=False, port=1234) + self.inst = WSGIServer(None, _start=False, port=1234) # Ensure the adjustment was actually applied. self.assertNotEqual(Adjustments.port, 1234) - self.assertEqual(inst.adj.port, 1234) + self.assertEqual(self.inst.adj.port, 1234) if hasattr(socket, 'AF_UNIX'): @@ -227,7 +235,7 @@ if hasattr(socket, 'AF_UNIX'): def _makeOne(self, _start=True, _sock=None): from waitress.server import create_server - return create_server( + self.inst = create_server( dummy_app, map={}, _start=_start, @@ -236,6 +244,10 @@ if hasattr(socket, 'AF_UNIX'): unix_socket=self.unix_socket, unix_socket_perms='600' ) + return self.inst + + def tearDown(self): + self.inst.close() def _makeDummy(self, *args, **kwargs): sock = DummySock(*args, **kwargs) @@ -268,13 +280,13 @@ if hasattr(socket, 'AF_UNIX'): def test_creates_new_sockinfo(self): from waitress.server import UnixWSGIServer - inst = UnixWSGIServer( + self.inst = UnixWSGIServer( dummy_app, unix_socket=self.unix_socket, unix_socket_perms='600' ) - self.assertEqual(inst.sockinfo[0], socket.AF_UNIX) + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) class DummySock(object): accepted = False @@ -317,6 +329,9 @@ class DummySock(object): def getsockname(self): return self.bound + def close(self): + pass + class DummyTaskDispatcher(object): def __init__(self): @@ -358,6 +373,9 @@ class DummyTrigger(object): def pull_trigger(self): self.pulled = True + def close(self): + pass + class DummyLogger(object): def __init__(self): diff --git a/waitress/tests/test_trigger.py b/waitress/tests/test_trigger.py index bfff16e..6bd4824 100644 --- a/waitress/tests/test_trigger.py +++ b/waitress/tests/test_trigger.py @@ -8,15 +8,19 @@ if not sys.platform.startswith("win"): def _makeOne(self, map): from waitress.trigger import trigger - return trigger(map) + self.inst = trigger(map) + return self.inst + + def tearDown(self): + self.inst.close() # prevent __del__ warning from file_dispatcher def test__close(self): map = {} inst = self._makeOne(map) - fd = os.open(os.path.abspath(__file__), os.O_RDONLY) - inst._fds = (fd,) + fd1, fd2 = inst._fds inst.close() - self.assertRaises(OSError, os.read, fd, 1) + self.assertRaises(OSError, os.read, fd1, 1) + self.assertRaises(OSError, os.read, fd2, 1) def test__physical_pull(self): map = {} diff --git a/waitress/tests/test_utilities.py b/waitress/tests/test_utilities.py index 73f6c7b..95b39f3 100644 --- a/waitress/tests/test_utilities.py +++ b/waitress/tests/test_utilities.py @@ -89,21 +89,6 @@ class Test_find_double_newline(unittest.TestCase): def test_mixed(self): self.assertEqual(self._callFUT(b'\n\n00\r\n\r\n'), 2) -class Test_logging_dispatcher(unittest.TestCase): - - def _makeOne(self): - from waitress.utilities import logging_dispatcher - return logging_dispatcher(map={}) - - def test_log_info(self): - import logging - inst = self._makeOne() - logger = DummyLogger() - inst.logger = logger - inst.log_info('message', 'warning') - self.assertEqual(logger.severity, logging.WARN) - self.assertEqual(logger.message, 'message') - class TestBadRequest(unittest.TestCase): def _makeOne(self): @@ -114,8 +99,3 @@ class TestBadRequest(unittest.TestCase): inst = self._makeOne() self.assertEqual(inst.body, 1) -class DummyLogger(object): - - def log(self, severity, message): - self.severity = severity - self.message = message diff --git a/waitress/tests/test_wasyncore.py b/waitress/tests/test_wasyncore.py new file mode 100644 index 0000000..0d896ac --- /dev/null +++ b/waitress/tests/test_wasyncore.py @@ -0,0 +1,1660 @@ +from waitress import wasyncore as asyncore +from waitress import compat +import contextlib +import functools +import gc +import unittest +import select +import os +import socket +import sys +import time +import errno +import re +import struct +import threading +import warnings + +from io import BytesIO + +TIMEOUT = 3 +HAS_UNIX_SOCKETS = hasattr(socket, 'AF_UNIX') +HOST = 'localhost' +HOSTv4 = "127.0.0.1" +HOSTv6 = "::1" + +# Filename used for testing +if os.name == 'java': # pragma: no cover + # Jython disallows @ in module names + TESTFN = '$test' +else: + TESTFN = '@test' + +TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid()) + +class DummyLogger(object): # pragma: no cover + def __init__(self): + self.messages = [] + + def log(self, severity, message): + self.messages.append((severity, message)) + +class WarningsRecorder(object): # pragma: no cover + """Convenience wrapper for the warnings list returned on + entry to the warnings.catch_warnings() context manager. + """ + def __init__(self, warnings_list): + self._warnings = warnings_list + self._last = 0 + + @property + def warnings(self): + return self._warnings[self._last:] + + def reset(self): + self._last = len(self._warnings) + + +def _filterwarnings(filters, quiet=False): # pragma: no cover + """Catch the warnings, then check if all the expected + warnings have been raised and re-raise unexpected warnings. + If 'quiet' is True, only re-raise the unexpected warnings. + """ + # Clear the warning registry of the calling module + # in order to re-raise the warnings. + frame = sys._getframe(2) + registry = frame.f_globals.get('__warningregistry__') + if registry: + registry.clear() + with warnings.catch_warnings(record=True) as w: + # Set filter "always" to record all warnings. Because + # test_warnings swap the module, we need to look up in + # the sys.modules dictionary. + sys.modules['warnings'].simplefilter("always") + yield WarningsRecorder(w) + # Filter the recorded warnings + reraise = list(w) + missing = [] + for msg, cat in filters: + seen = False + for w in reraise[:]: + warning = w.message + # Filter out the matching messages + if (re.match(msg, str(warning), re.I) and + issubclass(warning.__class__, cat)): + seen = True + reraise.remove(w) + if not seen and not quiet: + # This filter caught nothing + missing.append((msg, cat.__name__)) + if reraise: + raise AssertionError("unhandled warning %s" % reraise[0]) + if missing: + raise AssertionError("filter (%r, %s) did not catch any warning" % + missing[0]) + + +@contextlib.contextmanager +def check_warnings(*filters, **kwargs): # pragma: no cover + """Context manager to silence warnings. + + Accept 2-tuples as positional arguments: + ("message regexp", WarningCategory) + + Optional argument: + - if 'quiet' is True, it does not fail if a filter catches nothing + (default True without argument, + default False if some filters are defined) + + Without argument, it defaults to: + check_warnings(("", Warning), quiet=True) + """ + quiet = kwargs.get('quiet') + if not filters: + filters = (("", Warning),) + # Preserve backward compatibility + if quiet is None: + quiet = True + return _filterwarnings(filters, quiet) + +def gc_collect(): # pragma: no cover + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if sys.platform.startswith('java'): + time.sleep(0.1) + gc.collect() + gc.collect() + +def threading_setup(): # pragma: no cover + return (compat.thread._count(), None) + +def threading_cleanup(*original_values): # pragma: no cover + global environment_altered + + _MAX_COUNT = 100 + + for count in range(_MAX_COUNT): + values = (compat.thread._count(), None) + if values == original_values: + break + + if not count: + # Display a warning at the first iteration + environment_altered = True + sys.stderr.write( + "Warning -- threading_cleanup() failed to cleanup " + "%s threads" % (values[0] - original_values[0]) + ) + sys.stderr.flush() + + values = None + + time.sleep(0.01) + gc_collect() + + +def reap_threads(func): # pragma: no cover + """Use this function when threads are being used. This will + ensure that the threads are cleaned up even when the test fails. + """ + @functools.wraps(func) + def decorator(*args): + key = threading_setup() + try: + return func(*args) + finally: + threading_cleanup(*key) + return decorator + +def join_thread(thread, timeout=30.0): # pragma: no cover + """Join a thread. Raise an AssertionError if the thread is still alive + after timeout seconds. + """ + thread.join(timeout) + if thread.is_alive(): + msg = "failed to join the thread in %.1f seconds" % timeout + raise AssertionError(msg) + +def bind_port(sock, host=HOST): # pragma: no cover + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, 'SO_REUSEADDR'): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise RuntimeError("tests should never set the SO_REUSEADDR " \ + "socket option on TCP/IP sockets!") + if hasattr(socket, 'SO_REUSEPORT'): + try: + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEPORT " \ + "socket option on TCP/IP sockets!") + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. + pass + if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + + sock.bind((host, 0)) + port = sock.getsockname()[1] + return port + +@contextlib.contextmanager +def closewrapper(sock): # pragma: no cover + try: + yield sock + finally: + sock.close() + +class dummysocket: # pragma: no cover + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def fileno(self): + return 42 + + def setblocking(self, yesno): + self.isblocking = yesno + + def getpeername(self): + return 'peername' + +class dummychannel: # pragma: no cover + def __init__(self): + self.socket = dummysocket() + + def close(self): + self.socket.close() + +class exitingdummy: # pragma: no cover + def __init__(self): + pass + + def handle_read_event(self): + raise asyncore.ExitNow() + + handle_write_event = handle_read_event + handle_close = handle_read_event + handle_expt_event = handle_read_event + +class crashingdummy: + def __init__(self): + self.error_handled = False + + def handle_read_event(self): + raise Exception() + + handle_write_event = handle_read_event + handle_close = handle_read_event + handle_expt_event = handle_read_event + + def handle_error(self): + self.error_handled = True + +# used when testing senders; just collects what it gets until newline is sent +def capture_server(evt, buf, serv): # pragma no cover + try: + serv.listen(0) + conn, addr = serv.accept() + except socket.timeout: + pass + else: + n = 200 + start = time.time() + while n > 0 and time.time() - start < 3.0: + r, w, e = select.select([conn], [], [], 0.1) + if r: + n -= 1 + data = conn.recv(10) + # keep everything except for the newline terminator + buf.write(data.replace(b'\n', b'')) + if b'\n' in data: + break + time.sleep(0.01) + + conn.close() + finally: + serv.close() + evt.set() + +def bind_unix_socket(sock, addr): # pragma: no cover + """Bind a unix socket, raising SkipTest if PermissionError is raised.""" + assert sock.family == socket.AF_UNIX + try: + sock.bind(addr) + except PermissionError: + sock.close() + raise unittest.SkipTest('cannot bind AF_UNIX sockets') + +def bind_af_aware(sock, addr): + """Helper function to bind a socket according to its family.""" + if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX: + # Make sure the path doesn't exist. + unlink(addr) + bind_unix_socket(sock, addr) + else: + sock.bind(addr) + +if sys.platform.startswith("win"): # pragma: no cover + def _waitfor(func, pathname, waitall=False): + # Perform the operation + func(pathname) + # Now setup the wait loop + if waitall: + dirname = pathname + else: + dirname, name = os.path.split(pathname) + dirname = dirname or '.' + # Check for `pathname` to be removed from the filesystem. + # The exponential backoff of the timeout amounts to a total + # of ~1 second after which the deletion is probably an error + # anyway. + # Testing on an i7@4.3GHz shows that usually only 1 iteration is + # required when contention occurs. + timeout = 0.001 + while timeout < 1.0: + # Note we are only testing for the existence of the file(s) in + # the contents of the directory regardless of any security or + # access rights. If we have made it this far, we have sufficient + # permissions to do that much using Python's equivalent of the + # Windows API FindFirstFile. + # Other Windows APIs can fail or give incorrect results when + # dealing with files that are pending deletion. + L = os.listdir(dirname) + if not (L if waitall else name in L): + return + # Increase the timeout and try again + time.sleep(timeout) + timeout *= 2 + warnings.warn('tests may fail, delete still pending for ' + pathname, + RuntimeWarning, stacklevel=4) + + def _unlink(filename): + _waitfor(os.unlink, filename) +else: + _unlink = os.unlink + + +def unlink(filename): + try: + _unlink(filename) + except OSError: + pass + +def _is_ipv6_enabled(): # pragma: no cover + """Check whether IPv6 is enabled on this host.""" + if compat.HAS_IPV6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(('::1', 0)) + return True + except socket.error: + pass + finally: + if sock: + sock.close() + return False + +IPV6_ENABLED = _is_ipv6_enabled() + +class HelperFunctionTests(unittest.TestCase): + def test_readwriteexc(self): + # Check exception handling behavior of read, write and _exception + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore read/write/_exception calls + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.read, tr1) + self.assertRaises(asyncore.ExitNow, asyncore.write, tr1) + self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + asyncore.read(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore.write(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore._exception(tr2) + self.assertEqual(tr2.error_handled, True) + + # asyncore.readwrite uses constants in the select module that + # are not present in Windows systems (see this thread: + # http://mail.python.org/pipermail/python-list/2001-October/109973.html) + # These constants should be present as long as poll is available + + @unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required') + def test_readwrite(self): + # Check that correct methods are called by readwrite() + + attributes = ('read', 'expt', 'write', 'closed', 'error_handled') + + expected = ( + (select.POLLIN, 'read'), + (select.POLLPRI, 'expt'), + (select.POLLOUT, 'write'), + (select.POLLERR, 'closed'), + (select.POLLHUP, 'closed'), + (select.POLLNVAL, 'closed'), + ) + + class testobj: + def __init__(self): + self.read = False + self.write = False + self.closed = False + self.expt = False + self.error_handled = False + + def handle_read_event(self): + self.read = True + + def handle_write_event(self): + self.write = True + + def handle_close(self): + self.closed = True + + def handle_expt_event(self): + self.expt = True + + # def handle_error(self): + # self.error_handled = True + + for flag, expectedattr in expected: + tobj = testobj() + self.assertEqual(getattr(tobj, expectedattr), False) + asyncore.readwrite(tobj, flag) + + # Only the attribute modified by the routine we expect to be + # called should be True. + for attr in attributes: + self.assertEqual(getattr(tobj, attr), attr==expectedattr) + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore readwrite call + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + self.assertEqual(tr2.error_handled, False) + asyncore.readwrite(tr2, flag) + self.assertEqual(tr2.error_handled, True) + + def test_closeall(self): + self.closeall_check(False) + + def test_closeall_default(self): + self.closeall_check(True) + + def closeall_check(self, usedefault): + # Check that close_all() closes everything in a given map + + l = [] + testmap = {} + for i in range(10): + c = dummychannel() + l.append(c) + self.assertEqual(c.socket.closed, False) + testmap[i] = c + + if usedefault: + socketmap = asyncore.socket_map + try: + asyncore.socket_map = testmap + asyncore.close_all() + finally: + testmap, asyncore.socket_map = asyncore.socket_map, socketmap + else: + asyncore.close_all(testmap) + + self.assertEqual(len(testmap), 0) + + for c in l: + self.assertEqual(c.socket.closed, True) + + def test_compact_traceback(self): + try: + raise Exception("I don't like spam!") + except: + real_t, real_v, real_tb = sys.exc_info() + r = asyncore.compact_traceback() + + (f, function, line), t, v, info = r + self.assertEqual(os.path.split(f)[-1], 'test_wasyncore.py') + self.assertEqual(function, 'test_compact_traceback') + self.assertEqual(t, real_t) + self.assertEqual(v, real_v) + self.assertEqual(info, '[%s|%s|%s]' % (f, function, line)) + + +class DispatcherTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + def test_basic(self): + d = asyncore.dispatcher() + self.assertEqual(d.readable(), True) + self.assertEqual(d.writable(), True) + + def test_repr(self): + d = asyncore.dispatcher() + self.assertEqual( + repr(d), + '' % id(d) + ) + + def test_log_info(self): + import logging + inst = asyncore.dispatcher(map={}) + logger = DummyLogger() + inst.logger = logger + inst.log_info('message', 'warning') + self.assertEqual(logger.messages, [(logging.WARN, 'message')]) + + def test_log(self): + import logging + inst = asyncore.dispatcher() + logger = DummyLogger() + inst.logger = logger + inst.log('message') + self.assertEqual(logger.messages, [(logging.DEBUG, 'message')]) + + def test_unhandled(self): + import logging + inst = asyncore.dispatcher() + logger = DummyLogger() + inst.logger = logger + + inst.handle_expt() + inst.handle_read() + inst.handle_write() + inst.handle_connect() + + expected = [(logging.WARN, 'unhandled incoming priority event'), + (logging.WARN, 'unhandled read event'), + (logging.WARN, 'unhandled write event'), + (logging.WARN, 'unhandled connect event')] + self.assertEqual(logger.messages, expected) + + def test_strerror(self): + # refers to bug #8573 + err = asyncore._strerror(errno.EPERM) + if hasattr(os, 'strerror'): + self.assertEqual(err, os.strerror(errno.EPERM)) + err = asyncore._strerror(-1) + self.assertTrue(err != "") + + +class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover + def readable(self): + return False + + def handle_connect(self): + pass + + +class DispatcherWithSendTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + @reap_threads + def test_send(self): + evt = threading.Event() + sock = socket.socket() + sock.settimeout(3) + port = bind_port(sock) + + cap = BytesIO() + args = (evt, cap, sock) + t = threading.Thread(target=capture_server, args=args) + t.start() + try: + # wait a little longer for the server to initialize (it sometimes + # refuses connections on slow machines without this wait) + time.sleep(0.2) + + data = b"Suppose there isn't a 16-ton weight?" + d = dispatcherwithsend_noread() + d.create_socket() + d.connect((HOST, port)) + + # give time for socket to connect + time.sleep(0.1) + + d.send(data) + d.send(data) + d.send(b'\n') + + n = 1000 + while d.out_buffer and n > 0: # pragma: no cover + asyncore.poll() + n -= 1 + + evt.wait() + + self.assertEqual(cap.getvalue(), data*2) + finally: + join_thread(t, timeout=TIMEOUT) + + +@unittest.skipUnless(hasattr(asyncore, 'file_wrapper'), + 'asyncore.file_wrapper required') +class FileWrapperTest(unittest.TestCase): + def setUp(self): + self.d = b"It's not dead, it's sleeping!" + with open(TESTFN, 'wb') as file: + file.write(self.d) + + def tearDown(self): + unlink(TESTFN) + + def test_recv(self): + fd = os.open(TESTFN, os.O_RDONLY) + w = asyncore.file_wrapper(fd) + os.close(fd) + + self.assertNotEqual(w.fd, fd) + self.assertNotEqual(w.fileno(), fd) + self.assertEqual(w.recv(13), b"It's not dead") + self.assertEqual(w.read(6), b", it's") + w.close() + self.assertRaises(OSError, w.read, 1) + + def test_send(self): + d1 = b"Come again?" + d2 = b"I want to buy some cheese." + fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND) + w = asyncore.file_wrapper(fd) + os.close(fd) + + w.write(d1) + w.send(d2) + w.close() + with open(TESTFN, 'rb') as file: + self.assertEqual(file.read(), self.d + d1 + d2) + + @unittest.skipUnless(hasattr(asyncore, 'file_dispatcher'), + 'asyncore.file_dispatcher required') + def test_dispatcher(self): + fd = os.open(TESTFN, os.O_RDONLY) + data = [] + class FileDispatcher(asyncore.file_dispatcher): + def handle_read(self): + data.append(self.recv(29)) + FileDispatcher(fd) + os.close(fd) + asyncore.loop(timeout=0.01, use_poll=True, count=2) + self.assertEqual(b"".join(data), self.d) + + def test_resource_warning(self): + # Issue #11453 + got_warning = False + while got_warning is False: + # we try until we get the outcome we want because this + # test is not deterministic (gc_collect() may not + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + + os.close(fd) + + try: + with check_warnings(('', compat.ResourceWarning)): + f = None + gc_collect() + except AssertionError: # pragma: no cover + pass + else: + got_warning = True + + def test_close_twice(self): + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + os.close(fd) + + os.close(f.fd) # file_wrapper dupped fd + with self.assertRaises(OSError): + f.close() + + self.assertEqual(f.fd, -1) + # calling close twice should not fail + f.close() + + +class BaseTestHandler(asyncore.dispatcher): # pragma: no cover + + def __init__(self, sock=None): + asyncore.dispatcher.__init__(self, sock) + self.flag = False + + def handle_accept(self): + raise Exception("handle_accept not supposed to be called") + + def handle_accepted(self): + raise Exception("handle_accepted not supposed to be called") + + def handle_connect(self): + raise Exception("handle_connect not supposed to be called") + + def handle_expt(self): + raise Exception("handle_expt not supposed to be called") + + def handle_close(self): + raise Exception("handle_close not supposed to be called") + + def handle_error(self): + raise + + +class BaseServer(asyncore.dispatcher): + """A server which listens on an address and dispatches the + connection to a handler. + """ + + def __init__(self, family, addr, handler=BaseTestHandler): + asyncore.dispatcher.__init__(self) + self.create_socket(family) + self.set_reuse_addr() + bind_af_aware(self.socket, addr) + self.listen(5) + self.handler = handler + + @property + def address(self): + return self.socket.getsockname() + + def handle_accepted(self, sock, addr): + self.handler(sock) + + def handle_error(self): # pragma: no cover + raise + + +class BaseClient(BaseTestHandler): + + def __init__(self, family, address): + BaseTestHandler.__init__(self) + self.create_socket(family) + self.connect(address) + + def handle_connect(self): + pass + + +class BaseTestAPI: + + def tearDown(self): + asyncore.close_all(ignore_all=True) + + def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover + timeout = float(timeout) / 100 + count = 100 + while asyncore.socket_map and count > 0: + asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) + if instance.flag: + return + count -= 1 + time.sleep(timeout) + self.fail("flag not set") + + def test_handle_connect(self): + # make sure handle_connect is called on connect() + + class TestClient(BaseClient): + def handle_connect(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_accept(self): + # make sure handle_accept() is called when a client connects + + class TestListener(BaseTestHandler): + + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + def test_handle_accepted(self): + # make sure handle_accepted() is called when a client connects + + class TestListener(BaseTestHandler): + + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + asyncore.dispatcher.handle_accept(self) + + def handle_accepted(self, sock, addr): + sock.close() + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + + def test_handle_read(self): + # make sure handle_read is called on data received + + class TestClient(BaseClient): + def handle_read(self): + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.send(b'x' * 1024) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_write(self): + # make sure handle_write is called + + class TestClient(BaseClient): + def handle_write(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close(self): + # make sure handle_close is called when the other end closes + # the connection + + class TestClient(BaseClient): + + def handle_read(self): + # in order to make handle_close be called we are supposed + # to make at least one recv() call + self.recv(1024) + + def handle_close(self): + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.close() + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close_after_conn_broken(self): + # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and + # #11265). + + data = b'\0' * 128 + + class TestClient(BaseClient): + + def handle_write(self): + self.send(data) + + def handle_close(self): + self.flag = True + self.close() + + # def handle_expt(self): + # self.flag = True + # self.close() + + class TestHandler(BaseTestHandler): + + def handle_read(self): + self.recv(len(data)) + self.close() + + def writable(self): + return False + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + @unittest.skipIf(sys.platform.startswith("sunos"), + "OOB support is broken on Solaris") + def test_handle_expt(self): + # Make sure handle_expt is called on OOB data received. + # Note: this might fail on some platforms as OOB data is + # tenuously supported and rarely used. + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + if sys.platform == "darwin" and self.use_poll: # pragma: no cover + self.skipTest("poll may fail on macOS; see issue #28087") + + class TestClient(BaseClient): + def handle_expt(self): + self.socket.recv(1024, socket.MSG_OOB) + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.socket.send( + compat.tobytes(chr(244)), socket.MSG_OOB + ) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_error(self): + + class TestClient(BaseClient): + def handle_write(self): + 1.0 / 0 + def handle_error(self): + self.flag = True + try: + raise + except ZeroDivisionError: + pass + else: # pragma: no cover + raise Exception("exception not raised") + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_connection_attributes(self): + server = BaseServer(self.family, self.addr) + client = BaseClient(self.family, server.address) + + # we start disconnected + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + # this can't be taken for granted across all platforms + #self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # execute some loops so that client connects to server + asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertTrue(client.connected) + self.assertFalse(client.accepting) + + # disconnect the client + client.close() + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # stop serving + server.close() + self.assertFalse(server.connected) + self.assertFalse(server.accepting) + + def test_create_socket(self): + s = asyncore.dispatcher() + s.create_socket(self.family) + #self.assertEqual(s.socket.type, socket.SOCK_STREAM) + self.assertEqual(s.socket.family, self.family) + self.assertEqual(s.socket.gettimeout(), 0) + #self.assertFalse(s.socket.get_inheritable()) + + def test_bind(self): + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + s1 = asyncore.dispatcher() + s1.create_socket(self.family) + s1.bind(self.addr) + s1.listen(5) + port = s1.socket.getsockname()[1] + + s2 = asyncore.dispatcher() + s2.create_socket(self.family) + # EADDRINUSE indicates the socket was correctly bound + self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) + + def test_set_reuse_addr(self): # pragma: no cover + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + with closewrapper(socket.socket(self.family)) as sock: + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except OSError: + unittest.skip("SO_REUSEADDR not supported on this platform") + else: + # if SO_REUSEADDR succeeded for sock we expect asyncore + # to do the same + s = asyncore.dispatcher(socket.socket(self.family)) + self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET, + socket.SO_REUSEADDR)) + s.socket.close() + s.create_socket(self.family) + s.set_reuse_addr() + self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET, + socket.SO_REUSEADDR)) + + @reap_threads + def test_quick_connect(self): # pragma: no cover + # see: http://bugs.python.org/issue10340 + if self.family not in (socket.AF_INET, + getattr(socket, "AF_INET6", object())): + self.skipTest("test specific to AF_INET and AF_INET6") + + server = BaseServer(self.family, self.addr) + # run the thread 500 ms: the socket should be connected in 200 ms + t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, + count=5)) + t.start() + try: + sock = socket.socket(self.family, socket.SOCK_STREAM) + with closewrapper(sock) as s: + s.settimeout(.2) + s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, + struct.pack('ii', 1, 0)) + + try: + s.connect(server.address) + except OSError: + pass + finally: + join_thread(t, timeout=TIMEOUT) + +class TestAPI_UseIPv4Sockets(BaseTestAPI): + family = socket.AF_INET + addr = (HOST, 0) + +@unittest.skipUnless(IPV6_ENABLED, 'IPv6 support required') +class TestAPI_UseIPv6Sockets(BaseTestAPI): + family = socket.AF_INET6 + addr = (HOSTv6, 0) + +@unittest.skipUnless(HAS_UNIX_SOCKETS, 'Unix sockets required') +class TestAPI_UseUnixSockets(BaseTestAPI): + if HAS_UNIX_SOCKETS: + family = socket.AF_UNIX + addr = TESTFN + + def tearDown(self): + unlink(self.addr) + BaseTestAPI.tearDown(self) + +class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase): + use_poll = False + +@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required') +class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase): + use_poll = True + +class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase): + use_poll = False + +@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required') +class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase): + use_poll = True + +class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase): + use_poll = False + +@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required') +class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase): + use_poll = True + +class Test__strerror(unittest.TestCase): + def _callFUT(self, err): + from waitress.wasyncore import _strerror + return _strerror(err) + + def test_gardenpath(self): + self.assertEqual(self._callFUT(1), 'Operation not permitted') + + def test_unknown(self): + self.assertEqual(self._callFUT('wut'), 'Unknown error wut') + +class Test_read(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import read + return read(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow,self._callFUT, inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + +class Test_write(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import write + return write(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow,self._callFUT, inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertTrue(inst.error_handled) + +class Test__exception(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import _exception + return _exception(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow,self._callFUT, inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertTrue(inst.error_handled) + +@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required') +class Test_readwrite(unittest.TestCase): + def _callFUT(self, obj, flags): + from waitress.wasyncore import readwrite + return readwrite(obj, flags) + + def test_handle_read_event(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_handle_write_event(self): + flags = 0 + flags |= select.POLLOUT + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.write_event_handled) + + def test_handle_expt_event(self): + flags = 0 + flags |= select.POLLPRI + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.expt_event_handled) + + def test_handle_close(self): + flags = 0 + flags |= select.POLLHUP + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.close_handled) + + def test_socketerror_not_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.EALREADY, 'EALREADY')) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + def test_socketerror_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.ECONNRESET, 'ECONNRESET')) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.close_handled) + + def test_exception_in_reraised(self): + from waitress import wasyncore + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(wasyncore.ExitNow) + self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_exception_not_in_reraised(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(ValueError) + self._callFUT(inst, flags) + self.assertTrue(inst.error_handled) + +class Test_poll(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll + return poll(timeout, map) + + def test_nothing_writable_nothing_readable_but_map_not_empty(self): + # i read the mock.patch docs. nerp. + dummy_time = DummyTime() + map = {0:DummyDispatcher()} + try: + from waitress import wasyncore + old_time = wasyncore.time + wasyncore.time = dummy_time + result = self._callFUT(map=map) + finally: + wasyncore.time = old_time + self.assertEqual(result, None) + self.assertEqual(dummy_time.sleepvals, [0.0]) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EINTR)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0:disp} + try: + from waitress import wasyncore + old_select = wasyncore.select + wasyncore.select = dummy_select + result = self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(result, None) + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EBADF)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0:disp} + try: + from waitress import wasyncore + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + +class Test_poll2(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll2 + return poll2(timeout, map) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EINTR)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0:disp} + try: + from waitress import wasyncore + old_select = wasyncore.select + wasyncore.select = dummy_select + self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EBADF)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0:disp} + try: + from waitress import wasyncore + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + +class Test_dispatcher(unittest.TestCase): + def _makeOne(self, sock=None, map=None): + from waitress.wasyncore import dispatcher + return dispatcher(sock=sock, map=map) + + def test_unexpected_getpeername_exc(self): + sock = dummysocket() + def getpeername(): + raise socket.error(errno.EBADF) + map = {} + sock.getpeername = getpeername + self.assertRaises(socket.error, self._makeOne, sock=sock, map=map) + self.assertEqual(map, {}) + + def test___repr__accepting(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.accepting = True + inst.addr = ('localhost', 8080) + result = repr(inst) + expected = ' + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +"""Basic infrastructure for asynchronous socket service clients and servers. + +There are only two ways to have a program on a single processor do "more +than one thing at a time". Multi-threaded programming is the simplest and +most popular way to do it, but there is another very different technique, +that lets you have nearly all the advantages of multi-threading, without +actually using multiple threads. it's really only practical if your program +is largely I/O bound. If your program is CPU bound, then pre-emptive +scheduled threads are probably what you really need. Network servers are +rarely CPU-bound, however. + +If your operating system supports the select() system call in its I/O +library (and nearly all do), then you can use it to juggle multiple +communication channels at once; doing other work while your I/O is taking +place in the "background." Although this strategy can seem strange and +complex, especially at first, it is in many ways easier to understand and +control than multi-threaded programming. The module documented here solves +many of the difficult problems for you, making the task of building +sophisticated high-performance network servers and clients a snap. + +NB: this is a fork of asyncore from the stdlib that we've (the waitress +developers) named 'wasyncore' to ensure forward compatibility, as asyncore +in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore +nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X. +""" + +from . import compat +from . import utilities + +import logging +import select +import socket +import sys +import time +import warnings + +import os +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ + ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, EINTR, \ + errorcode + +_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, + EBADF}) + +try: + socket_map +except NameError: + socket_map = {} + +def _strerror(err): + try: + return os.strerror(err) + except (TypeError, ValueError, OverflowError, NameError): + return "Unknown error %s" % err + +class ExitNow(Exception): + pass + +_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit) + +def read(obj): + try: + obj.handle_read_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def write(obj): + try: + obj.handle_write_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def _exception(obj): + try: + obj.handle_expt_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def readwrite(obj, flags): + try: + if flags & select.POLLIN: + obj.handle_read_event() + if flags & select.POLLOUT: + obj.handle_write_event() + if flags & select.POLLPRI: + obj.handle_expt_event() + if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + obj.handle_close() + except socket.error as e: + if e.args[0] not in _DISCONNECTED: + obj.handle_error() + else: + obj.handle_close() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def poll(timeout=0.0, map=None): + if map is None: # pragma: no cover + map = socket_map + if map: + r = []; w = []; e = [] + for fd, obj in list(map.items()): # list() call FBO py3 + is_r = obj.readable() + is_w = obj.writable() + if is_r: + r.append(fd) + # accepting sockets should not be writable + if is_w and not obj.accepting: + w.append(fd) + if is_r or is_w: + e.append(fd) + if [] == r == w == e: + time.sleep(timeout) + return + + try: + r, w, e = select.select(r, w, e, timeout) + except select.error as err: + if err.args[0] != EINTR: + raise + else: + return + + for fd in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + read(obj) + + for fd in w: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + write(obj) + + for fd in e: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + _exception(obj) + +def poll2(timeout=0.0, map=None): + # Use the poll() support added to the select module in Python 2.0 + if map is None: # pragma: no cover + map = socket_map + if timeout is not None: + # timeout is in milliseconds + timeout = int(timeout*1000) + pollster = select.poll() + if map: + for fd, obj in list(map.items()): + flags = 0 + if obj.readable(): + flags |= select.POLLIN | select.POLLPRI + # accepting sockets should not be writable + if obj.writable() and not obj.accepting: + flags |= select.POLLOUT + if flags: + pollster.register(fd, flags) + + try: + r = pollster.poll(timeout) + except select.error as err: + if err.args[0] != EINTR: + raise + r = [] + + for fd, flags in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + readwrite(obj, flags) + +poll3 = poll2 # Alias for backward compatibility + +def loop(timeout=30.0, use_poll=False, map=None, count=None): + if map is None: # pragma: no cover + map = socket_map + + if use_poll and hasattr(select, 'poll'): + poll_fun = poll2 + else: + poll_fun = poll + + if count is None: # pragma: no cover + while map: + poll_fun(timeout, map) + + else: + while map and count > 0: + poll_fun(timeout, map) + count = count - 1 + +def compact_traceback(): + t, v, tb = sys.exc_info() + tbinfo = [] + if not tb: # pragma: no cover + raise AssertionError("traceback does not exist") + while tb: + tbinfo.append(( + tb.tb_frame.f_code.co_filename, + tb.tb_frame.f_code.co_name, + str(tb.tb_lineno) + )) + tb = tb.tb_next + + # just to be safe + del tb + + file, function, line = tbinfo[-1] + info = ' '.join(['[%s|%s|%s]' % x for x in tbinfo]) + return (file, function, line), t, v, info + +class dispatcher: + + debug = False + connected = False + accepting = False + connecting = False + closing = False + addr = None + ignore_log_types = frozenset({'warning'}) + logger = utilities.logger + compact_traceback = staticmethod(compact_traceback) # for testing + + def __init__(self, sock=None, map=None): + if map is None: # pragma: no cover + self._map = socket_map + else: + self._map = map + + self._fileno = None + + if sock: + # Set to nonblocking just to make sure for cases where we + # get a socket from a blocking source. + sock.setblocking(0) + self.set_socket(sock, map) + self.connected = True + # The constructor no longer requires that the socket + # passed be connected. + try: + self.addr = sock.getpeername() + except socket.error as err: + if err.args[0] in (ENOTCONN, EINVAL): + # To handle the case where we got an unconnected + # socket. + self.connected = False + else: + # The socket is broken in some unknown way, alert + # the user and remove it from the map (to prevent + # polling of broken sockets). + self.del_channel(map) + raise + else: + self.socket = None + + def __repr__(self): + status = [self.__class__.__module__+"."+compat.qualname(self.__class__)] + if self.accepting and self.addr: + status.append('listening') + elif self.connected: + status.append('connected') + if self.addr is not None: + try: + status.append('%s:%d' % self.addr) + except TypeError: # pragma: no cover + status.append(repr(self.addr)) + return '<%s at %#x>' % (' '.join(status), id(self)) + + __str__ = __repr__ + + def add_channel(self, map=None): + #self.log_info('adding channel %s' % self) + if map is None: + map = self._map + map[self._fileno] = self + + def del_channel(self, map=None): + fd = self._fileno + if map is None: + map = self._map + if fd in map: + #self.log_info('closing channel %d:%s' % (fd, self)) + del map[fd] + self._fileno = None + + def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM): + self.family_and_type = family, type + sock = socket.socket(family, type) + sock.setblocking(0) + self.set_socket(sock) + + def set_socket(self, sock, map=None): + self.socket = sock + self._fileno = sock.fileno() + self.add_channel(map) + + def set_reuse_addr(self): + # try to re-use a server port if possible + try: + self.socket.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, + self.socket.getsockopt(socket.SOL_SOCKET, + socket.SO_REUSEADDR) | 1 + ) + except socket.error: + pass + + # ================================================== + # predicates for select() + # these are used as filters for the lists of sockets + # to pass to select(). + # ================================================== + + def readable(self): + return True + + def writable(self): + return True + + # ================================================== + # socket object methods. + # ================================================== + + def listen(self, num): + self.accepting = True + if os.name == 'nt' and num > 5: # pragma: no cover + num = 5 + return self.socket.listen(num) + + def bind(self, addr): + self.addr = addr + return self.socket.bind(addr) + + def connect(self, address): + self.connected = False + self.connecting = True + err = self.socket.connect_ex(address) + if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \ + or err == EINVAL and os.name == 'nt': # pragma: no cover + self.addr = address + return + if err in (0, EISCONN): + self.addr = address + self.handle_connect_event() + else: + raise socket.error(err, errorcode[err]) + + def accept(self): + # XXX can return either an address pair or None + try: + conn, addr = self.socket.accept() + except TypeError: + return None + except socket.error as why: + if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN): + return None + else: + raise + else: + return conn, addr + + def send(self, data): + try: + result = self.socket.send(data) + return result + except socket.error as why: + if why.args[0] == EWOULDBLOCK: + return 0 + elif why.args[0] in _DISCONNECTED: + self.handle_close() + return 0 + else: + raise + + def recv(self, buffer_size): + try: + data = self.socket.recv(buffer_size) + if not data: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.handle_close() + return b'' + else: + return data + except socket.error as why: + # winsock sometimes raises ENOTCONN + if why.args[0] in _DISCONNECTED: + self.handle_close() + return b'' + else: + raise + + def close(self): + self.connected = False + self.accepting = False + self.connecting = False + self.del_channel() + if self.socket is not None: + try: + self.socket.close() + except socket.error as why: + if why.args[0] not in (ENOTCONN, EBADF): + raise + + # log and log_info may be overridden to provide more sophisticated + # logging and warning methods. In general, log is for 'hit' logging + # and 'log_info' is for informational, warning and error logging. + + def log(self, message): + self.logger.log(logging.DEBUG, message) + + def log_info(self, message, type='info'): + severity = { + 'info': logging.INFO, + 'warning': logging.WARN, + 'error': logging.ERROR, + } + self.logger.log(severity.get(type, logging.INFO), message) + + def handle_read_event(self): + if self.accepting: + # accepting sockets are never connected, they "spawn" new + # sockets that are connected + self.handle_accept() + elif not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_read() + else: + self.handle_read() + + def handle_connect_event(self): + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise socket.error(err, _strerror(err)) + self.handle_connect() + self.connected = True + self.connecting = False + + def handle_write_event(self): + if self.accepting: + # Accepting sockets shouldn't get a write event. + # We will pretend it didn't happen. + return + + if not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_write() + + def handle_expt_event(self): + # handle_expt_event() is called if there might be an error on the + # socket, or if there is OOB data + # check for the error condition first + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # we can get here when select.select() says that there is an + # exceptional condition on the socket + # since there is an error, we'll go ahead and close the socket + # like we would in a subclassed handle_read() that received no + # data + self.handle_close() + else: + self.handle_expt() + + def handle_error(self): + nil, t, v, tbinfo = self.compact_traceback() + + # sometimes a user repr method will crash. + try: + self_repr = repr(self) + except: # pragma: no cover + self_repr = '<__repr__(self) failed for object at %0x>' % id(self) + + self.log_info( + 'uncaptured python exception, closing channel %s (%s:%s %s)' % ( + self_repr, + t, + v, + tbinfo + ), + 'error' + ) + self.handle_close() + + def handle_expt(self): + self.log_info('unhandled incoming priority event', 'warning') + + def handle_read(self): + self.log_info('unhandled read event', 'warning') + + def handle_write(self): + self.log_info('unhandled write event', 'warning') + + def handle_connect(self): + self.log_info('unhandled connect event', 'warning') + + def handle_accept(self): + pair = self.accept() + if pair is not None: + self.handle_accepted(*pair) + + def handle_accepted(self, sock, addr): + sock.close() + self.log_info('unhandled accepted event', 'warning') + + def handle_close(self): + self.log_info('unhandled close event', 'warning') + self.close() + +# --------------------------------------------------------------------------- +# adds simple buffered output capability, useful for simple clients. +# [for more sophisticated usage use asynchat.async_chat] +# --------------------------------------------------------------------------- + +class dispatcher_with_send(dispatcher): + + def __init__(self, sock=None, map=None): + dispatcher.__init__(self, sock, map) + self.out_buffer = b'' + + def initiate_send(self): + num_sent = 0 + num_sent = dispatcher.send(self, self.out_buffer[:65536]) + self.out_buffer = self.out_buffer[num_sent:] + + handle_write = initiate_send + + def writable(self): + return (not self.connected) or len(self.out_buffer) + + def send(self, data): + if self.debug: # pragma: no cover + self.log_info('sending %s' % repr(data)) + self.out_buffer = self.out_buffer + data + self.initiate_send() + +def close_all(map=None, ignore_all=False): + if map is None: # pragma: no cover + map = socket_map + for x in list(map.values()): # list() FBO py3 + try: + x.close() + except socket.error as x: + if x.args[0] == EBADF: + pass + elif not ignore_all: + raise + except _reraised_exceptions: + raise + except: + if not ignore_all: + raise + map.clear() + +# Asynchronous File I/O: +# +# After a little research (reading man pages on various unixen, and +# digging through the linux kernel), I've determined that select() +# isn't meant for doing asynchronous file i/o. +# Heartening, though - reading linux/mm/filemap.c shows that linux +# supports asynchronous read-ahead. So _MOST_ of the time, the data +# will be sitting in memory for us already when we go to read it. +# +# What other OS's (besides NT) support async file i/o? [VMS?] +# +# Regardless, this is useful for pipes, and stdin/stdout... + +if os.name == 'posix': + class file_wrapper: + # Here we override just enough to make a file + # look like a socket for the purposes of asyncore. + # The passed fd is automatically os.dup()'d + + def __init__(self, fd): + self.fd = os.dup(fd) + + def __del__(self): + if self.fd >= 0: + warnings.warn("unclosed file %r" % self, compat.ResourceWarning) + self.close() + + def recv(self, *args): + return os.read(self.fd, *args) + + def send(self, *args): + return os.write(self.fd, *args) + + def getsockopt(self, level, optname, buflen=None): # pragma: no cover + if (level == socket.SOL_SOCKET and + optname == socket.SO_ERROR and + not buflen): + return 0 + raise NotImplementedError("Only asyncore specific behaviour " + "implemented.") + + read = recv + write = send + + def close(self): + if self.fd < 0: + return + fd = self.fd + self.fd = -1 + os.close(fd) + + def fileno(self): + return self.fd + + class file_dispatcher(dispatcher): + + def __init__(self, fd, map=None): + dispatcher.__init__(self, None, map) + self.connected = True + try: + fd = fd.fileno() + except AttributeError: + pass + self.set_file(fd) + # set it to non-blocking mode + compat.set_nonblocking(fd) + + def set_file(self, fd): + self.socket = file_wrapper(fd) + self._fileno = self.socket.fileno() + self.add_channel() + -- cgit v1.2.1