diff options
Diffstat (limited to 'src/waitress/server.py')
-rw-r--r-- | src/waitress/server.py | 436 |
1 files changed, 436 insertions, 0 deletions
diff --git a/src/waitress/server.py b/src/waitress/server.py new file mode 100644 index 0000000..ae56699 --- /dev/null +++ b/src/waitress/server.py @@ -0,0 +1,436 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import os +import os.path +import socket +import time + +from waitress import trigger +from waitress.adjustments import Adjustments +from waitress.channel import HTTPChannel +from waitress.task import ThreadedTaskDispatcher +from waitress.utilities import cleanup_unix_socket + +from waitress.compat import ( + IPPROTO_IPV6, + IPV6_V6ONLY, +) +from . import wasyncore +from .proxy_headers import proxy_headers_middleware + + +def create_server( + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + _dispatcher=None, # test shim + **kw # adjustments +): + """ + if __name__ == '__main__': + server = create_server(app) + server.run() + """ + if application is None: + raise ValueError( + 'The "app" passed to ``create_server`` was ``None``. You forgot ' + "to return a WSGI app within your application." + ) + adj = Adjustments(**kw) + + if map is None: # pragma: nocover + map = {} + + dispatcher = _dispatcher + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(adj.threads) + + if adj.unix_socket and hasattr(socket, "AF_UNIX"): + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + return UnixWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + + effective_listen = [] + last_serv = None + if not adj.sockets: + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. + last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + for sock in adj.sockets: + sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) + if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + last_serv = TcpWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + last_serv = UnixWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + # We are running a single server, so we can just return the last server, + # saves us from having to create one more object + if len(effective_listen) == 1: + # In this case we have no need to use a MultiSocketServer + return last_serv + + # Return a class that has a utility function to print out the sockets it's + # listening on, and has a .run() function. All of the TcpWSGIServers + # registered themselves in the map above. + return MultiSocketServer(map, adj, effective_listen, dispatcher) + + +# This class is only ever used if we have multiple listen sockets. It allows +# the serve() API to call .run() which starts the wasyncore loop, and catches +# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. +class MultiSocketServer(object): + asyncore = wasyncore # test shim + + def __init__( + self, map=None, adj=None, effective_listen=None, dispatcher=None, + ): + self.adj = adj + self.map = map + self.effective_listen = effective_listen + self.task_dispatcher = dispatcher + + def print_listen(self, format_str): # pragma: nocover + for l in self.effective_listen: + l = list(l) + + if ":" in l[0]: + l[0] = "[{}]".format(l[0]) + + print(format_str.format(*l)) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self.map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.close() + + def close(self): + self.task_dispatcher.shutdown() + wasyncore.close_all(self.map) + + +class BaseWSGIServer(wasyncore.dispatcher, object): + + channel_class = HTTPChannel + next_channel_cleanup = 0 + socketmod = socket # test shim + asyncore = wasyncore # test shim + + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + bind_socket=True, + **kw + ): + if adj is None: + adj = Adjustments(**kw) + + if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: + # wrap the application to deal with proxy headers + # we wrap it here because webtest subclasses the TcpWSGIServer + # directly and thus doesn't run any code that's in create_server + application = proxy_headers_middleware( + application, + trusted_proxy=adj.trusted_proxy, + trusted_proxy_count=adj.trusted_proxy_count, + trusted_proxy_headers=adj.trusted_proxy_headers, + clear_untrusted=adj.clear_untrusted_proxy_headers, + log_untrusted=adj.log_untrusted_proxy_headers, + logger=self.logger, + ) + + if map is None: + # use a nonglobal socket map by default to hopefully prevent + # conflicts with apps and libs that use the wasyncore global socket + # map ala https://github.com/Pylons/waitress/issues/63 + map = {} + if sockinfo is None: + sockinfo = adj.listen[0] + + self.sockinfo = sockinfo + self.family = sockinfo[0] + self.socktype = sockinfo[1] + self.application = application + self.adj = adj + self.trigger = trigger.trigger(map) + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(self.adj.threads) + + self.task_dispatcher = dispatcher + self.asyncore.dispatcher.__init__(self, _sock, map=map) + if _sock is None: + self.create_socket(self.family, self.socktype) + if self.family == socket.AF_INET6: # pragma: nocover + self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + + self.set_reuse_addr() + + if bind_socket: + self.bind_server_socket() + + self.effective_host, self.effective_port = self.getsockname() + self.server_name = self.get_server_name(self.effective_host) + self.active_channels = {} + if _start: + self.accept_connections() + + def bind_server_socket(self): + raise NotImplementedError # pragma: no cover + + def get_server_name(self, ip): + """Given an IP or hostname, try to determine the server name.""" + + if not ip: + raise ValueError("Requires an IP to get the server name") + + server_name = str(ip) + + # If we are bound to all IP's, just return the current hostname, only + # fall-back to "localhost" if we fail to get the hostname + if server_name == "0.0.0.0" or server_name == "::": + try: + return str(self.socketmod.gethostname()) + except (socket.error, UnicodeDecodeError): # pragma: no cover + # We also deal with UnicodeDecodeError in case of Windows with + # non-ascii hostname + return "localhost" + + # Now let's try and convert the IP address to a proper hostname + try: + server_name = self.socketmod.gethostbyaddr(server_name)[0] + except (socket.error, UnicodeDecodeError): # pragma: no cover + # We also deal with UnicodeDecodeError in case of Windows with + # non-ascii hostname + pass + + # If it contains an IPv6 literal, make sure to surround it with + # brackets + if ":" in server_name and "[" not in server_name: + server_name = "[{}]".format(server_name) + + return server_name + + def getsockname(self): + raise NotImplementedError # pragma: no cover + + def accept_connections(self): + self.accepting = True + self.socket.listen(self.adj.backlog) # Get around asyncore NT limit + + def add_task(self, task): + self.task_dispatcher.add_task(task) + + def readable(self): + now = time.time() + if now >= self.next_channel_cleanup: + self.next_channel_cleanup = now + self.adj.cleanup_interval + self.maintenance(now) + return self.accepting and len(self._map) < self.adj.connection_limit + + def writable(self): + return False + + def handle_read(self): + pass + + def handle_connect(self): + pass + + def handle_accept(self): + try: + v = self.accept() + if v is None: + return + conn, addr = v + except socket.error: + # Linux: On rare occasions we get a bogus socket back from + # accept. socketmodule.c:makesockaddr complains that the + # address family is unknown. We don't want the whole server + # to shut down because of this. + if self.adj.log_socket_errors: + self.logger.warning("server accept() threw an exception", exc_info=True) + return + self.set_socket_options(conn) + addr = self.fix_addr(addr) + self.channel_class(self, conn, addr, self.adj, map=self._map) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self._map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.task_dispatcher.shutdown() + + def pull_trigger(self): + self.trigger.pull_trigger() + + def set_socket_options(self, conn): + pass + + def fix_addr(self, addr): + return addr + + def maintenance(self, now): + """ + Closes channels that have not had any activity in a while. + + The timeout is configured through adj.channel_timeout (seconds). + """ + cutoff = now - self.adj.channel_timeout + for channel in self.active_channels.values(): + if (not channel.requests) and channel.last_activity < cutoff: + channel.will_close = True + + def print_listen(self, format_str): # pragma: nocover + print(format_str.format(self.effective_host, self.effective_port)) + + def close(self): + self.trigger.close() + return wasyncore.dispatcher.close(self) + + +class TcpWSGIServer(BaseWSGIServer): + def bind_server_socket(self): + (_, _, _, sockaddr) = self.sockinfo + self.bind(sockaddr) + + def getsockname(self): + try: + return self.socketmod.getnameinfo( + self.socket.getsockname(), self.socketmod.NI_NUMERICSERV + ) + except: # pragma: no cover + # This only happens on Linux because a DNS issue is considered a + # temporary failure that will raise (even when NI_NAMEREQD is not + # set). Instead we try again, but this time we just ask for the + # numerichost and the numericserv (port) and return those. It is + # better than nothing. + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, + ) + + def set_socket_options(self, conn): + for (level, optname, value) in self.adj.socket_options: + conn.setsockopt(level, optname, value) + + +if hasattr(socket, "AF_UNIX"): + + class UnixWSGIServer(BaseWSGIServer): + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw + ): + if sockinfo is None: + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + + super(UnixWSGIServer, self).__init__( + application, + map=map, + _start=_start, + _sock=_sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + **kw + ) + + def bind_server_socket(self): + cleanup_unix_socket(self.adj.unix_socket) + self.bind(self.adj.unix_socket) + if os.path.exists(self.adj.unix_socket): + os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) + + def getsockname(self): + return ("unix", self.socket.getsockname()) + + def fix_addr(self, addr): + return ("localhost", None) + + def get_server_name(self, ip): + return "localhost" + + +# Compatibility alias. +WSGIServer = TcpWSGIServer |