diff options
Diffstat (limited to 'bzrlib/tests/test_test_server.py')
-rw-r--r-- | bzrlib/tests/test_test_server.py | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/bzrlib/tests/test_test_server.py b/bzrlib/tests/test_test_server.py new file mode 100644 index 0000000..bd05a62 --- /dev/null +++ b/bzrlib/tests/test_test_server.py @@ -0,0 +1,336 @@ +# Copyright (C) 2010, 2011 Canonical Ltd +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +import errno +import socket +import SocketServer +import threading + + +from bzrlib import ( + osutils, + tests, + ) +from bzrlib.tests import test_server +from bzrlib.tests.scenarios import load_tests_apply_scenarios + + +load_tests = load_tests_apply_scenarios + + +def portable_socket_pair(): + """Return a pair of TCP sockets connected to each other. + + Unlike socket.socketpair, this should work on Windows. + """ + listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(1) + client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_sock.connect(listen_sock.getsockname()) + server_sock, addr = listen_sock.accept() + listen_sock.close() + return server_sock, client_sock + + +class TCPClient(object): + + def __init__(self): + self.sock = None + + def connect(self, addr): + if self.sock is not None: + raise AssertionError('Already connected to %r' + % (self.sock.getsockname(),)) + self.sock = osutils.connect_socket(addr) + + def disconnect(self): + if self.sock is not None: + try: + self.sock.shutdown(socket.SHUT_RDWR) + self.sock.close() + except socket.error, e: + if e[0] in (errno.EBADF, errno.ENOTCONN): + # Right, the socket is already down + pass + else: + raise + self.sock = None + + def write(self, s): + return self.sock.sendall(s) + + def read(self, bufsize=4096): + return self.sock.recv(bufsize) + + +class TCPConnectionHandler(SocketServer.BaseRequestHandler): + + def handle(self): + self.done = False + self.handle_connection() + while not self.done: + self.handle_connection() + + def readline(self): + # TODO: We should be buffering any extra data sent, etc. However, in + # practice, we don't send extra content, so we haven't bothered + # to implement it yet. + req = self.request.recv(4096) + # An empty string is allowed, to indicate the end of the connection + if not req or (req.endswith('\n') and req.count('\n') == 1): + return req + raise ValueError('[%r] not a simple line' % (req,)) + + def handle_connection(self): + req = self.readline() + if not req: + self.done = True + elif req == 'ping\n': + self.request.sendall('pong\n') + else: + raise ValueError('[%s] not understood' % req) + + +class TestTCPServerInAThread(tests.TestCase): + + scenarios = [ + (name, {'server_class': getattr(test_server, name)}) + for name in + ('TestingTCPServer', 'TestingThreadingTCPServer')] + + def get_server(self, server_class=None, connection_handler_class=None): + if server_class is not None: + self.server_class = server_class + if connection_handler_class is None: + connection_handler_class = TCPConnectionHandler + server = test_server.TestingTCPServerInAThread( + ('localhost', 0), self.server_class, connection_handler_class) + server.start_server() + self.addCleanup(server.stop_server) + return server + + def get_client(self): + client = TCPClient() + self.addCleanup(client.disconnect) + return client + + def get_server_connection(self, server, conn_rank): + return server.server.clients[conn_rank] + + def assertClientAddr(self, client, server, conn_rank): + conn = self.get_server_connection(server, conn_rank) + self.assertEquals(client.sock.getsockname(), conn[1]) + + def test_start_stop(self): + server = self.get_server() + client = self.get_client() + server.stop_server() + # since the server doesn't accept connections anymore attempting to + # connect should fail + client = self.get_client() + self.assertRaises(socket.error, + client.connect, (server.host, server.port)) + + def test_client_talks_server_respond(self): + server = self.get_server() + client = self.get_client() + client.connect((server.host, server.port)) + self.assertIs(None, client.write('ping\n')) + resp = client.read() + self.assertClientAddr(client, server, 0) + self.assertEquals('pong\n', resp) + + def test_server_fails_to_start(self): + class CantStart(Exception): + pass + + class CantStartServer(test_server.TestingTCPServer): + + def server_bind(self): + raise CantStart() + + # The exception is raised in the main thread + self.assertRaises(CantStart, + self.get_server, server_class=CantStartServer) + + def test_server_fails_while_serving_or_stopping(self): + class CantConnect(Exception): + pass + + class FailingConnectionHandler(TCPConnectionHandler): + + def handle(self): + raise CantConnect() + + server = self.get_server( + connection_handler_class=FailingConnectionHandler) + # The server won't fail until a client connect + client = self.get_client() + client.connect((server.host, server.port)) + # We make sure the server wants to handle a request, but the request is + # guaranteed to fail. However, the server should make sure that the + # connection gets closed, and stop_server should then raise the + # original exception. + client.write('ping\n') + try: + self.assertEqual('', client.read()) + except socket.error, e: + # On Windows, failing during 'handle' means we get + # 'forced-close-of-connection'. Possibly because we haven't + # processed the write request before we close the socket. + WSAECONNRESET = 10054 + if e.errno in (WSAECONNRESET,): + pass + # Now the server has raised the exception in its own thread + self.assertRaises(CantConnect, server.stop_server) + + def test_server_crash_while_responding(self): + # We want to ensure the exception has been caught + caught = threading.Event() + caught.clear() + # The thread that will serve the client, this needs to be an attribute + # so the handler below can modify it when it's executed (it's + # instantiated when the request is processed) + self.connection_thread = None + + class FailToRespond(Exception): + pass + + class FailingDuringResponseHandler(TCPConnectionHandler): + + # We use 'request' instead of 'self' below because the test matters + # more and we need a container to properly set connection_thread. + def handle_connection(request): + req = request.readline() + # Capture the thread and make it use 'caught' so we can wait on + # the event that will be set when the exception is caught. We + # also capture the thread to know where to look. + self.connection_thread = threading.currentThread() + self.connection_thread.set_sync_event(caught) + raise FailToRespond() + + server = self.get_server( + connection_handler_class=FailingDuringResponseHandler) + client = self.get_client() + client.connect((server.host, server.port)) + client.write('ping\n') + # Wait for the exception to be caught + caught.wait() + self.assertEqual('', client.read()) # connection closed + # Check that the connection thread did catch the exception, + # http://pad.lv/869366 was wrongly checking the server thread which + # works for TestingTCPServer where the connection is handled in the + # same thread than the server one but was racy for + # TestingThreadingTCPServer. Since the connection thread detaches + # itself before handling the request, we are guaranteed that the + # exception won't leak into the server thread anymore. + self.assertRaises(FailToRespond, + self.connection_thread.pending_exception) + + def test_exception_swallowed_while_serving(self): + # We need to ensure the exception has been caught + caught = threading.Event() + caught.clear() + # The thread that will serve the client, this needs to be an attribute + # so the handler below can access it when it's executed (it's + # instantiated when the request is processed) + self.connection_thread = None + class CantServe(Exception): + pass + + class FailingWhileServingConnectionHandler(TCPConnectionHandler): + + # We use 'request' instead of 'self' below because the test matters + # more and we need a container to properly set connection_thread. + def handle(request): + # Capture the thread and make it use 'caught' so we can wait on + # the event that will be set when the exception is caught. We + # also capture the thread to know where to look. + self.connection_thread = threading.currentThread() + self.connection_thread.set_sync_event(caught) + raise CantServe() + + server = self.get_server( + connection_handler_class=FailingWhileServingConnectionHandler) + self.assertEquals(True, server.server.serving) + # Install the exception swallower + server.set_ignored_exceptions(CantServe) + client = self.get_client() + # Connect to the server so the exception is raised there + client.connect((server.host, server.port)) + # Wait for the exception to be caught + caught.wait() + self.assertEqual('', client.read()) # connection closed + # The connection wasn't served properly but the exception should have + # been swallowed (see test_server_crash_while_responding remark about + # http://pad.lv/869366 explaining why we can't check the server thread + # here). More precisely, the exception *has* been caught and captured + # but it is cleared when joining the thread (or trying to acquire the + # exception) and as such won't propagate to the server thread. + self.assertIs(None, self.connection_thread.pending_exception()) + self.assertIs(None, server.pending_exception()) + + def test_handle_request_closes_if_it_doesnt_process(self): + server = self.get_server() + client = self.get_client() + server.server.serving = False + client.connect((server.host, server.port)) + self.assertEqual('', client.read()) + + +class TestTestingSmartServer(tests.TestCase): + + def test_sets_client_timeout(self): + server = test_server.TestingSmartServer(('localhost', 0), None, None, + root_client_path='/no-such-client/path') + self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT, + server._client_timeout) + sock = socket.socket() + h = server._make_handler(sock) + self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT, + h._client_timeout) + + +class FakeServer(object): + """Minimal implementation to pass to TestingSmartConnectionHandler""" + backing_transport = None + root_client_path = '/' + + +class TestTestingSmartConnectionHandler(tests.TestCase): + + def test_connection_timeout_suppressed(self): + self.overrideAttr(test_server, '_DEFAULT_TESTING_CLIENT_TIMEOUT', 0.01) + s = FakeServer() + server_sock, client_sock = portable_socket_pair() + # This should timeout quickly, but not generate an exception. + handler = test_server.TestingSmartConnectionHandler(server_sock, + server_sock.getpeername(), s) + + def test_connection_shutdown_while_serving_no_error(self): + s = FakeServer() + server_sock, client_sock = portable_socket_pair() + class ShutdownConnectionHandler( + test_server.TestingSmartConnectionHandler): + + def _build_protocol(self): + self.finished = True + return super(ShutdownConnectionHandler, self)._build_protocol() + # This should trigger shutdown after the entering _build_protocol, and + # we should exit cleanly, without raising an exception. + handler = ShutdownConnectionHandler(server_sock, + server_sock.getpeername(), s) |