diff options
Diffstat (limited to 'src/waitress/task.py')
-rw-r--r-- | src/waitress/task.py | 570 |
1 files changed, 570 insertions, 0 deletions
diff --git a/src/waitress/task.py b/src/waitress/task.py new file mode 100644 index 0000000..8e7ab18 --- /dev/null +++ b/src/waitress/task.py @@ -0,0 +1,570 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import socket +import sys +import threading +import time +from collections import deque + +from .buffers import ReadOnlyFileBasedBuffer +from .compat import reraise, tobytes +from .utilities import build_http_date, logger, queue_logger + +rename_headers = { # or keep them without the HTTP_ prefix added + "CONTENT_LENGTH": "CONTENT_LENGTH", + "CONTENT_TYPE": "CONTENT_TYPE", +} + +hop_by_hop = frozenset( + ( + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ) +) + + +class ThreadedTaskDispatcher(object): + """A Task Dispatcher that creates a thread for each task. + """ + + stop_count = 0 # Number of threads that will stop soon. + active_count = 0 # Number of currently active threads + logger = logger + queue_logger = queue_logger + + def __init__(self): + self.threads = set() + self.queue = deque() + self.lock = threading.Lock() + self.queue_cv = threading.Condition(self.lock) + self.thread_exit_cv = threading.Condition(self.lock) + + def start_new_thread(self, target, args): + t = threading.Thread(target=target, name="waitress", args=args) + t.daemon = True + t.start() + + def handler_thread(self, thread_no): + while True: + with self.lock: + while not self.queue and self.stop_count == 0: + # Mark ourselves as idle before waiting to be + # woken up, then we will once again be active + self.active_count -= 1 + self.queue_cv.wait() + self.active_count += 1 + + if self.stop_count > 0: + self.active_count -= 1 + self.stop_count -= 1 + self.threads.discard(thread_no) + self.thread_exit_cv.notify() + break + + task = self.queue.popleft() + try: + task.service() + except BaseException: + self.logger.exception("Exception when servicing %r", task) + + def set_thread_count(self, count): + with self.lock: + threads = self.threads + thread_no = 0 + running = len(threads) - self.stop_count + while running < count: + # Start threads. + while thread_no in threads: + thread_no = thread_no + 1 + threads.add(thread_no) + running += 1 + self.start_new_thread(self.handler_thread, (thread_no,)) + self.active_count += 1 + thread_no = thread_no + 1 + if running > count: + # Stop threads. + self.stop_count += running - count + self.queue_cv.notify_all() + + def add_task(self, task): + with self.lock: + self.queue.append(task) + self.queue_cv.notify() + queue_size = len(self.queue) + idle_threads = len(self.threads) - self.stop_count - self.active_count + if queue_size > idle_threads: + self.queue_logger.warning( + "Task queue depth is %d", queue_size - idle_threads + ) + + def shutdown(self, cancel_pending=True, timeout=5): + self.set_thread_count(0) + # Ensure the threads shut down. + threads = self.threads + expiration = time.time() + timeout + with self.lock: + while threads: + if time.time() >= expiration: + self.logger.warning("%d thread(s) still running", len(threads)) + break + self.thread_exit_cv.wait(0.1) + if cancel_pending: + # Cancel remaining tasks. + queue = self.queue + if len(queue) > 0: + self.logger.warning("Canceling %d pending task(s)", len(queue)) + while queue: + task = queue.popleft() + task.cancel() + self.queue_cv.notify_all() + return True + return False + + +class Task(object): + close_on_finish = False + status = "200 OK" + wrote_header = False + start_time = 0 + content_length = None + content_bytes_written = 0 + logged_write_excess = False + logged_write_no_body = False + complete = False + chunked_response = False + logger = logger + + def __init__(self, channel, request): + self.channel = channel + self.request = request + self.response_headers = [] + version = request.version + if version not in ("1.0", "1.1"): + # fall back to a version we support. + version = "1.0" + self.version = version + + def service(self): + try: + try: + self.start() + self.execute() + self.finish() + except socket.error: + self.close_on_finish = True + if self.channel.adj.log_socket_errors: + raise + finally: + pass + + @property + def has_body(self): + return not ( + self.status.startswith("1") + or self.status.startswith("204") + or self.status.startswith("304") + ) + + def build_response_header(self): + version = self.version + # Figure out whether the connection should be closed. + connection = self.request.headers.get("CONNECTION", "").lower() + response_headers = [] + content_length_header = None + date_header = None + server_header = None + connection_close_header = None + + for (headername, headerval) in self.response_headers: + headername = "-".join([x.capitalize() for x in headername.split("-")]) + + if headername == "Content-Length": + if self.has_body: + content_length_header = headerval + else: + continue # pragma: no cover + + if headername == "Date": + date_header = headerval + + if headername == "Server": + server_header = headerval + + if headername == "Connection": + connection_close_header = headerval.lower() + # replace with properly capitalized version + response_headers.append((headername, headerval)) + + if ( + content_length_header is None + and self.content_length is not None + and self.has_body + ): + content_length_header = str(self.content_length) + response_headers.append(("Content-Length", content_length_header)) + + def close_on_finish(): + if connection_close_header is None: + response_headers.append(("Connection", "close")) + self.close_on_finish = True + + if version == "1.0": + if connection == "keep-alive": + if not content_length_header: + close_on_finish() + else: + response_headers.append(("Connection", "Keep-Alive")) + else: + close_on_finish() + + elif version == "1.1": + if connection == "close": + close_on_finish() + + if not content_length_header: + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + + if self.has_body: + response_headers.append(("Transfer-Encoding", "chunked")) + self.chunked_response = True + + if not self.close_on_finish: + close_on_finish() + + # under HTTP 1.1 keep-alive is default, no need to set the header + else: + raise AssertionError("neither HTTP/1.0 or HTTP/1.1") + + # Set the Server and Date field, if not yet specified. This is needed + # if the server is used as a proxy. + ident = self.channel.server.adj.ident + + if not server_header: + if ident: + response_headers.append(("Server", ident)) + else: + response_headers.append(("Via", ident or "waitress")) + + if not date_header: + response_headers.append(("Date", build_http_date(self.start_time))) + + self.response_headers = response_headers + + first_line = "HTTP/%s %s" % (self.version, self.status) + # NB: sorting headers needs to preserve same-named-header order + # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; + # rely on stable sort to keep relative position of same-named headers + next_lines = [ + "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) + ] + lines = [first_line] + next_lines + res = "%s\r\n\r\n" % "\r\n".join(lines) + + return tobytes(res) + + def remove_content_length_header(self): + response_headers = [] + + for header_name, header_value in self.response_headers: + if header_name.lower() == "content-length": + continue # pragma: nocover + response_headers.append((header_name, header_value)) + + self.response_headers = response_headers + + def start(self): + self.start_time = time.time() + + def finish(self): + if not self.wrote_header: + self.write(b"") + if self.chunked_response: + # not self.write, it will chunk it! + self.channel.write_soon(b"0\r\n\r\n") + + def write(self, data): + if not self.complete: + raise RuntimeError("start_response was not called before body written") + channel = self.channel + if not self.wrote_header: + rh = self.build_response_header() + channel.write_soon(rh) + self.wrote_header = True + + if data and self.has_body: + towrite = data + cl = self.content_length + if self.chunked_response: + # use chunked encoding response + towrite = tobytes(hex(len(data))[2:].upper()) + b"\r\n" + towrite += data + b"\r\n" + elif cl is not None: + towrite = data[: cl - self.content_bytes_written] + self.content_bytes_written += len(towrite) + if towrite != data and not self.logged_write_excess: + self.logger.warning( + "application-written content exceeded the number of " + "bytes specified by Content-Length header (%s)" % cl + ) + self.logged_write_excess = True + if towrite: + channel.write_soon(towrite) + elif data: + # Cheat, and tell the application we have written all of the bytes, + # even though the response shouldn't have a body and we are + # ignoring it entirely. + self.content_bytes_written += len(data) + + if not self.logged_write_no_body: + self.logger.warning( + "application-written content was ignored due to HTTP " + "response that may not contain a message-body: (%s)" % self.status + ) + self.logged_write_no_body = True + + +class ErrorTask(Task): + """ An error task produces an error response + """ + + complete = True + + def execute(self): + e = self.request.error + status, headers, body = e.to_response() + self.status = status + self.response_headers.extend(headers) + # We need to explicitly tell the remote client we are closing the + # connection, because self.close_on_finish is set, and we are going to + # slam the door in the clients face. + self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + self.content_length = len(body) + self.write(tobytes(body)) + + +class WSGITask(Task): + """A WSGI task produces a response from a WSGI application. + """ + + environ = None + + def execute(self): + environ = self.get_environment() + + def start_response(status, headers, exc_info=None): + if self.complete and not exc_info: + raise AssertionError( + "start_response called a second time without providing exc_info." + ) + if exc_info: + try: + if self.wrote_header: + # higher levels will catch and handle raised exception: + # 1. "service" method in task.py + # 2. "service" method in channel.py + # 3. "handler_thread" method in task.py + reraise(exc_info[0], exc_info[1], exc_info[2]) + else: + # As per WSGI spec existing headers must be cleared + self.response_headers = [] + finally: + exc_info = None + + self.complete = True + + if not status.__class__ is str: + raise AssertionError("status %s is not a string" % status) + if "\n" in status or "\r" in status: + raise ValueError( + "carriage return/line feed character present in status" + ) + + self.status = status + + # Prepare the headers for output + for k, v in headers: + if not k.__class__ is str: + raise AssertionError( + "Header name %r is not a string in %r" % (k, (k, v)) + ) + if not v.__class__ is str: + raise AssertionError( + "Header value %r is not a string in %r" % (v, (k, v)) + ) + + if "\n" in v or "\r" in v: + raise ValueError( + "carriage return/line feed character present in header value" + ) + if "\n" in k or "\r" in k: + raise ValueError( + "carriage return/line feed character present in header name" + ) + + kl = k.lower() + if kl == "content-length": + self.content_length = int(v) + elif kl in hop_by_hop: + raise AssertionError( + '%s is a "hop-by-hop" header; it cannot be used by ' + "a WSGI application (see PEP 3333)" % k + ) + + self.response_headers.extend(headers) + + # Return a method used to write the response data. + return self.write + + # Call the application to handle the request and write a response + app_iter = self.channel.server.application(environ, start_response) + + can_close_app_iter = True + try: + if app_iter.__class__ is ReadOnlyFileBasedBuffer: + cl = self.content_length + size = app_iter.prepare(cl) + if size: + if cl != size: + if cl is not None: + self.remove_content_length_header() + self.content_length = size + self.write(b"") # generate headers + # if the write_soon below succeeds then the channel will + # take over closing the underlying file via the channel's + # _flush_some or handle_close so we intentionally avoid + # calling close in the finally block + self.channel.write_soon(app_iter) + can_close_app_iter = False + return + + first_chunk_len = None + for chunk in app_iter: + if first_chunk_len is None: + first_chunk_len = len(chunk) + # Set a Content-Length header if one is not supplied. + # start_response may not have been called until first + # iteration as per PEP, so we must reinterrogate + # self.content_length here + if self.content_length is None: + app_iter_len = None + if hasattr(app_iter, "__len__"): + app_iter_len = len(app_iter) + if app_iter_len == 1: + self.content_length = first_chunk_len + # transmit headers only after first iteration of the iterable + # that returns a non-empty bytestring (PEP 3333) + if chunk: + self.write(chunk) + + cl = self.content_length + if cl is not None: + if self.content_bytes_written != cl: + # close the connection so the client isn't sitting around + # waiting for more data when there are too few bytes + # to service content-length + self.close_on_finish = True + if self.request.command != "HEAD": + self.logger.warning( + "application returned too few bytes (%s) " + "for specified Content-Length (%s) via app_iter" + % (self.content_bytes_written, cl), + ) + finally: + if can_close_app_iter and hasattr(app_iter, "close"): + app_iter.close() + + def get_environment(self): + """Returns a WSGI environment.""" + environ = self.environ + if environ is not None: + # Return the cached copy. + return environ + + request = self.request + path = request.path + channel = self.channel + server = channel.server + url_prefix = server.adj.url_prefix + + if path.startswith("/"): + # strip extra slashes at the beginning of a path that starts + # with any number of slashes + path = "/" + path.lstrip("/") + + if url_prefix: + # NB: url_prefix is guaranteed by the configuration machinery to + # be either the empty string or a string that starts with a single + # slash and ends without any slashes + if path == url_prefix: + # if the path is the same as the url prefix, the SCRIPT_NAME + # should be the url_prefix and PATH_INFO should be empty + path = "" + else: + # if the path starts with the url prefix plus a slash, + # the SCRIPT_NAME should be the url_prefix and PATH_INFO should + # the value of path from the slash until its end + url_prefix_with_trailing_slash = url_prefix + "/" + if path.startswith(url_prefix_with_trailing_slash): + path = path[len(url_prefix) :] + + environ = { + "REMOTE_ADDR": channel.addr[0], + # Nah, we aren't actually going to look up the reverse DNS for + # REMOTE_ADDR, but we will happily set this environment variable + # for the WSGI application. Spec says we can just set this to + # REMOTE_ADDR, so we do. + "REMOTE_HOST": channel.addr[0], + # try and set the REMOTE_PORT to something useful, but maybe None + "REMOTE_PORT": str(channel.addr[1]), + "REQUEST_METHOD": request.command.upper(), + "SERVER_PORT": str(server.effective_port), + "SERVER_NAME": server.server_name, + "SERVER_SOFTWARE": server.adj.ident, + "SERVER_PROTOCOL": "HTTP/%s" % self.version, + "SCRIPT_NAME": url_prefix, + "PATH_INFO": path, + "QUERY_STRING": request.query, + "wsgi.url_scheme": request.url_scheme, + # the following environment variables are required by the WSGI spec + "wsgi.version": (1, 0), + # apps should use the logging module + "wsgi.errors": sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "wsgi.input": request.get_body_stream(), + "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, + "wsgi.input_terminated": True, # wsgi.input is EOF terminated + } + + for key, value in dict(request.headers).items(): + value = value.strip() + mykey = rename_headers.get(key, None) + if mykey is None: + mykey = "HTTP_" + key + if mykey not in environ: + environ[mykey] = value + + # cache the environ for this request + self.environ = environ + return environ |