summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAllan Crooks <allan.crooks@yougov.com>2014-06-14 21:24:51 +0100
committerAllan Crooks <allan.crooks@yougov.com>2014-06-14 21:24:51 +0100
commit9102184e4fd3cbedf050db43c984e09bcc488203 (patch)
tree75c13c62c669d2408e459ebc002b38f5995a46f4
parent908962dc4dc00a63b940c329bca862bf827dea8f (diff)
downloadcherrypy-9102184e4fd3cbedf050db43c984e09bcc488203.tar.gz
Reimplementation of code to allow bounded request queues - rejecting requests when the queue is full. Fixes #1301.
-rw-r--r--cherrypy/_cpserver.py8
-rw-r--r--cherrypy/_cpwsgi_server.py2
-rw-r--r--cherrypy/test/test_conn.py182
-rw-r--r--cherrypy/wsgiserver/wsgiserver2.py22
-rw-r--r--cherrypy/wsgiserver/wsgiserver3.py22
5 files changed, 182 insertions, 54 deletions
diff --git a/cherrypy/_cpserver.py b/cherrypy/_cpserver.py
index 1ee3ee2b..a31e7428 100644
--- a/cherrypy/_cpserver.py
+++ b/cherrypy/_cpserver.py
@@ -61,6 +61,14 @@ class Server(ServerAdapter):
socket_timeout = 10
"""The timeout in seconds for accepted connections (default 10)."""
+
+ accepted_queue_size = -1
+ """The maximum number of requests which will be queued up before
+ the server refuses to accept it (default -1, meaning no limit)."""
+
+ accepted_queue_timeout = 10
+ """The timeout in seconds for attempting to add a request to the
+ queue when the queue is full (default 10)."""
shutdown_timeout = 5
"""The time to wait for HTTP worker threads to clean up."""
diff --git a/cherrypy/_cpwsgi_server.py b/cherrypy/_cpwsgi_server.py
index b648bab1..874e2e9f 100644
--- a/cherrypy/_cpwsgi_server.py
+++ b/cherrypy/_cpwsgi_server.py
@@ -39,6 +39,8 @@ class CPWSGIServer(wsgiserver.CherryPyWSGIServer):
request_queue_size=self.server_adapter.socket_queue_size,
timeout=self.server_adapter.socket_timeout,
shutdown_timeout=self.server_adapter.shutdown_timeout,
+ accepted_queue_size=self.server_adapter.accepted_queue_size,
+ accepted_queue_timeout=self.server_adapter.accepted_queue_timeout,
)
self.protocol = self.server_adapter.protocol_version
self.nodelay = self.server_adapter.nodelay
diff --git a/cherrypy/test/test_conn.py b/cherrypy/test/test_conn.py
index 9aa071e4..00adcca5 100644
--- a/cherrypy/test/test_conn.py
+++ b/cherrypy/test/test_conn.py
@@ -8,7 +8,7 @@ timeout = 1
import cherrypy
from cherrypy._cpcompat import HTTPConnection, HTTPSConnection, NotConnected
-from cherrypy._cpcompat import BadStatusLine, ntob, urlopen, unicodestr
+from cherrypy._cpcompat import BadStatusLine, ntob, tonative, urlopen, unicodestr
from cherrypy.test import webtest
from cherrypy import _cperror
@@ -133,10 +133,22 @@ class ConnectionCloseTests(helper.CPWebCase):
self.assertRaises(NotConnected, self.getPage, "/")
def test_Streaming_no_len(self):
- self._streaming(set_cl=False)
+ try:
+ self._streaming(set_cl=False)
+ finally:
+ try:
+ self.HTTP_CONN.close()
+ except (TypeError, AttributeError):
+ pass
def test_Streaming_with_len(self):
- self._streaming(set_cl=True)
+ try:
+ self._streaming(set_cl=True)
+ finally:
+ try:
+ self.HTTP_CONN.close()
+ except (TypeError, AttributeError):
+ pass
def _streaming(self, set_cl):
if cherrypy.server.protocol_version == "HTTP/1.1":
@@ -448,49 +460,53 @@ class PipelineTests(helper.CPWebCase):
# Try a page without an Expect request header first.
# Note that httplib's response.begin automatically ignores
# 100 Continue responses, so we must manually check for it.
- conn.putrequest("POST", "/upload", skip_host=True)
- conn.putheader("Host", self.HOST)
- conn.putheader("Content-Type", "text/plain")
- conn.putheader("Content-Length", "4")
- conn.endheaders()
- conn.send(ntob("d'oh"))
- response = conn.response_class(conn.sock, method="POST")
- version, status, reason = response._read_status()
- self.assertNotEqual(status, 100)
- conn.close()
+ try:
+ conn.putrequest("POST", "/upload", skip_host=True)
+ conn.putheader("Host", self.HOST)
+ conn.putheader("Content-Type", "text/plain")
+ conn.putheader("Content-Length", "4")
+ conn.endheaders()
+ conn.send(ntob("d'oh"))
+ response = conn.response_class(conn.sock, method="POST")
+ version, status, reason = response._read_status()
+ self.assertNotEqual(status, 100)
+ finally:
+ conn.close()
# Now try a page with an Expect header...
- conn.connect()
- conn.putrequest("POST", "/upload", skip_host=True)
- conn.putheader("Host", self.HOST)
- conn.putheader("Content-Type", "text/plain")
- conn.putheader("Content-Length", "17")
- conn.putheader("Expect", "100-continue")
- conn.endheaders()
- response = conn.response_class(conn.sock, method="POST")
-
- # ...assert and then skip the 100 response
- version, status, reason = response._read_status()
- self.assertEqual(status, 100)
- while True:
- line = response.fp.readline().strip()
- if line:
- self.fail(
- "100 Continue should not output any headers. Got %r" %
- line)
- else:
- break
+ try:
+ conn.connect()
+ conn.putrequest("POST", "/upload", skip_host=True)
+ conn.putheader("Host", self.HOST)
+ conn.putheader("Content-Type", "text/plain")
+ conn.putheader("Content-Length", "17")
+ conn.putheader("Expect", "100-continue")
+ conn.endheaders()
+ response = conn.response_class(conn.sock, method="POST")
- # ...send the body
- body = ntob("I am a small file")
- conn.send(body)
+ # ...assert and then skip the 100 response
+ version, status, reason = response._read_status()
+ self.assertEqual(status, 100)
+ while True:
+ line = response.fp.readline().strip()
+ if line:
+ self.fail(
+ "100 Continue should not output any headers. Got %r" %
+ line)
+ else:
+ break
- # ...get the final response
- response.begin()
- self.status, self.headers, self.body = webtest.shb(response)
- self.assertStatus(200)
- self.assertBody("thanks for '%s'" % body)
- conn.close()
+ # ...send the body
+ body = ntob("I am a small file")
+ conn.send(body)
+
+ # ...get the final response
+ response.begin()
+ self.status, self.headers, self.body = webtest.shb(response)
+ self.assertStatus(200)
+ self.assertBody("thanks for '%s'" % body)
+ finally:
+ conn.close()
class ConnectionTests(helper.CPWebCase):
@@ -717,6 +733,88 @@ class ConnectionTests(helper.CPWebCase):
remote_data_conn.close()
+def setup_upload_server():
+
+ class Root:
+ def upload(self):
+ if not cherrypy.request.method == 'POST':
+ raise AssertionError("'POST' != request.method %r" %
+ cherrypy.request.method)
+ return "thanks for '%s'" % tonative(cherrypy.request.body.read())
+ upload.exposed = True
+
+ cherrypy.tree.mount(Root())
+ cherrypy.config.update({
+ 'server.max_request_body_size': 1001,
+ 'server.socket_timeout': 10,
+ 'server.accepted_queue_size': 5,
+ 'server.accepted_queue_timeout': 0.1,
+ })
+
+import errno
+socket_reset_errors = []
+# Not all of these names will be defined for every platform.
+for _ in ("ECONNRESET", "WSAECONNRESET"):
+ if _ in dir(errno):
+ socket_reset_errors.append(getattr(errno, _))
+
+class LimitedRequestQueueTests(helper.CPWebCase):
+ setup_server = staticmethod(setup_upload_server)
+
+ def test_queue_full(self):
+ conns = []
+ overflow_conn = None
+
+ try:
+ # Make 15 initial requests and leave them open, which should use
+ # all of wsgiserver's WorkerThreads and fill its Queue.
+ import time
+ for i in range(15):
+ conn = self.HTTP_CONN(self.HOST, self.PORT)
+ conn.putrequest("POST", "/upload", skip_host=True)
+ conn.putheader("Host", self.HOST)
+ conn.putheader("Content-Type", "text/plain")
+ conn.putheader("Content-Length", "4")
+ conn.endheaders()
+ conns.append(conn)
+
+ # Now try a 16th conn, which should be closed by the server immediately.
+ overflow_conn = self.HTTP_CONN(self.HOST, self.PORT)
+ # Manually connect since httplib won't let us set a timeout
+ for res in socket.getaddrinfo(self.HOST, self.PORT, 0,
+ socket.SOCK_STREAM):
+ af, socktype, proto, canonname, sa = res
+ overflow_conn.sock = socket.socket(af, socktype, proto)
+ overflow_conn.sock.settimeout(5)
+ overflow_conn.sock.connect(sa)
+ break
+
+ overflow_conn.putrequest("GET", "/", skip_host=True)
+ overflow_conn.putheader("Host", self.HOST)
+ overflow_conn.endheaders()
+ response = overflow_conn.response_class(overflow_conn.sock, method="GET")
+ try:
+ response.begin()
+ except socket.error as exc:
+ if exc.args[0] in socket_reset_errors:
+ pass # Expected.
+ else:
+ raise AssertionError("Overflow conn did not get RST. "
+ "Got %s instead" % repr(exc.args))
+ else:
+ raise AssertionError("Overflow conn did not get RST ")
+ finally:
+ for conn in conns:
+ conn.send(ntob("done"))
+ response = conn.response_class(conn.sock, method="POST")
+ response.begin()
+ self.body = response.read()
+ self.assertBody("thanks for 'done'")
+ self.assertEqual(response.status, 200)
+ conn.close()
+ if overflow_conn:
+ overflow_conn.close()
+
class BadRequestTests(helper.CPWebCase):
setup_server = staticmethod(setup_server)
diff --git a/cherrypy/wsgiserver/wsgiserver2.py b/cherrypy/wsgiserver/wsgiserver2.py
index a304922a..ddd07387 100644
--- a/cherrypy/wsgiserver/wsgiserver2.py
+++ b/cherrypy/wsgiserver/wsgiserver2.py
@@ -1545,12 +1545,14 @@ class ThreadPool(object):
and stop(timeout) attributes.
"""
- def __init__(self, server, min=10, max=-1):
+ def __init__(self, server, min=10, max=-1,
+ accepted_queue_size=-1, accepted_queue_timeout=10):
self.server = server
self.min = min
self.max = max
self._threads = []
- self._queue = queue.Queue()
+ self._queue = queue.Queue(maxsize=accepted_queue_size)
+ self._queue_put_timeout = accepted_queue_timeout
self.get = self._queue.get
def start(self):
@@ -1570,7 +1572,7 @@ class ThreadPool(object):
idle = property(_get_idle, doc=_get_idle.__doc__)
def put(self, obj):
- self._queue.put(obj)
+ self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
if obj is _SHUTDOWNREQUEST:
return
@@ -2072,7 +2074,12 @@ class HTTPServer(object):
conn.ssl_env = ssl_env
- self.requests.put(conn)
+ try:
+ self.requests.put(conn)
+ except queue.Full:
+ # Just drop the conn. TODO: write 503 back?
+ conn.close()
+ return
except socket.timeout:
# The only reason for the timeout in start() is so we can
# notice keyboard interrupts on Win32, which don't interrupt
@@ -2215,8 +2222,11 @@ class CherryPyWSGIServer(HTTPServer):
"""The version of WSGI to produce."""
def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None,
- max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5):
- self.requests = ThreadPool(self, min=numthreads or 1, max=max)
+ max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5,
+ accepted_queue_size=-1, accepted_queue_timeout=10):
+ self.requests = ThreadPool(self, min=numthreads or 1, max=max,
+ accepted_queue_size=accepted_queue_size,
+ accepted_queue_timeout=accepted_queue_timeout)
self.wsgi_app = wsgi_app
self.gateway = wsgi_gateways[self.wsgi_version]
diff --git a/cherrypy/wsgiserver/wsgiserver3.py b/cherrypy/wsgiserver/wsgiserver3.py
index 88acff9a..ef5e4542 100644
--- a/cherrypy/wsgiserver/wsgiserver3.py
+++ b/cherrypy/wsgiserver/wsgiserver3.py
@@ -1261,12 +1261,14 @@ class ThreadPool(object):
and stop(timeout) attributes.
"""
- def __init__(self, server, min=10, max=-1):
+ def __init__(self, server, min=10, max=-1,
+ accepted_queue_size=-1, accepted_queue_timeout=10):
self.server = server
self.min = min
self.max = max
self._threads = []
- self._queue = queue.Queue()
+ self._queue = queue.Queue(maxsize=accepted_queue_size)
+ self._queue_put_timeout = accepted_queue_timeout
self.get = self._queue.get
def start(self):
@@ -1286,7 +1288,7 @@ class ThreadPool(object):
idle = property(_get_idle, doc=_get_idle.__doc__)
def put(self, obj):
- self._queue.put(obj)
+ self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
if obj is _SHUTDOWNREQUEST:
return
@@ -1764,7 +1766,12 @@ class HTTPServer(object):
conn.ssl_env = ssl_env
- self.requests.put(conn)
+ try:
+ self.requests.put(conn)
+ except queue.Full:
+ # Just drop the conn. TODO: write 503 back?
+ conn.close()
+ return
except socket.timeout:
# The only reason for the timeout in start() is so we can
# notice keyboard interrupts on Win32, which don't interrupt
@@ -1906,8 +1913,11 @@ class CherryPyWSGIServer(HTTPServer):
"""The version of WSGI to produce."""
def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None,
- max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5):
- self.requests = ThreadPool(self, min=numthreads or 1, max=max)
+ max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5,
+ accepted_queue_size=-1, accepted_queue_timeout=10):
+ self.requests = ThreadPool(self, min=numthreads or 1, max=max,
+ accepted_queue_size=accepted_queue_size,
+ accepted_queue_timeout=accepted_queue_timeout)
self.wsgi_app = wsgi_app
self.gateway = wsgi_gateways[self.wsgi_version]