summaryrefslogtreecommitdiff
path: root/Lib/test/test_socketserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_socketserver.py')
-rw-r--r--Lib/test/test_socketserver.py183
1 files changed, 179 insertions, 4 deletions
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 0d0f86fca4..140a6abf9e 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -3,12 +3,11 @@ Test suite for socketserver.
"""
import contextlib
+import io
import os
import select
import signal
import socket
-import select
-import errno
import tempfile
import unittest
import socketserver
@@ -46,7 +45,7 @@ def receive(sock, n, timeout=20):
else:
raise RuntimeError("timed out on %r" % (sock,))
-if HAVE_UNIX_SOCKETS:
+if HAVE_UNIX_SOCKETS and HAVE_FORKING:
class ForkingUnixStreamServer(socketserver.ForkingMixIn,
socketserver.UnixStreamServer):
pass
@@ -58,6 +57,7 @@ if HAVE_UNIX_SOCKETS:
@contextlib.contextmanager
def simple_subprocess(testcase):
+ """Tests that a custom child process is not waited on (Issue 1540386)"""
pid = os.fork()
if pid == 0:
# Don't raise an exception; it would be caught by the test harness.
@@ -103,7 +103,6 @@ class SocketServerTest(unittest.TestCase):
class MyServer(svrcls):
def handle_error(self, request, client_address):
self.close_request(request)
- self.server_close()
raise
class MyHandler(hdlrbase):
@@ -279,6 +278,182 @@ class SocketServerTest(unittest.TestCase):
socketserver.TCPServer((HOST, -1),
socketserver.StreamRequestHandler)
+ def test_context_manager(self):
+ with socketserver.TCPServer((HOST, 0),
+ socketserver.StreamRequestHandler) as server:
+ pass
+ self.assertEqual(-1, server.socket.fileno())
+
+
+class ErrorHandlerTest(unittest.TestCase):
+ """Test that the servers pass normal exceptions from the handler to
+ handle_error(), and that exiting exceptions like SystemExit and
+ KeyboardInterrupt are not passed."""
+
+ def tearDown(self):
+ test.support.unlink(test.support.TESTFN)
+
+ def test_sync_handled(self):
+ BaseErrorTestServer(ValueError)
+ self.check_result(handled=True)
+
+ def test_sync_not_handled(self):
+ with self.assertRaises(SystemExit):
+ BaseErrorTestServer(SystemExit)
+ self.check_result(handled=False)
+
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_threading_handled(self):
+ ThreadingErrorTestServer(ValueError)
+ self.check_result(handled=True)
+
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_threading_not_handled(self):
+ ThreadingErrorTestServer(SystemExit)
+ self.check_result(handled=False)
+
+ @requires_forking
+ def test_forking_handled(self):
+ ForkingErrorTestServer(ValueError)
+ self.check_result(handled=True)
+
+ @requires_forking
+ def test_forking_not_handled(self):
+ ForkingErrorTestServer(SystemExit)
+ self.check_result(handled=False)
+
+ def check_result(self, handled):
+ with open(test.support.TESTFN) as log:
+ expected = 'Handler called\n' + 'Error handled\n' * handled
+ self.assertEqual(log.read(), expected)
+
+
+class BaseErrorTestServer(socketserver.TCPServer):
+ def __init__(self, exception):
+ self.exception = exception
+ super().__init__((HOST, 0), BadHandler)
+ with socket.create_connection(self.server_address):
+ pass
+ try:
+ self.handle_request()
+ finally:
+ self.server_close()
+ self.wait_done()
+
+ def handle_error(self, request, client_address):
+ with open(test.support.TESTFN, 'a') as log:
+ log.write('Error handled\n')
+
+ def wait_done(self):
+ pass
+
+
+class BadHandler(socketserver.BaseRequestHandler):
+ def handle(self):
+ with open(test.support.TESTFN, 'a') as log:
+ log.write('Handler called\n')
+ raise self.server.exception('Test error')
+
+
+class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
+ BaseErrorTestServer):
+ def __init__(self, *pos, **kw):
+ self.done = threading.Event()
+ super().__init__(*pos, **kw)
+
+ def shutdown_request(self, *pos, **kw):
+ super().shutdown_request(*pos, **kw)
+ self.done.set()
+
+ def wait_done(self):
+ self.done.wait()
+
+
+if HAVE_FORKING:
+ class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
+ def wait_done(self):
+ [child] = self.active_children
+ os.waitpid(child, 0)
+ self.active_children.clear()
+
+
+class SocketWriterTest(unittest.TestCase):
+ def test_basics(self):
+ class Handler(socketserver.StreamRequestHandler):
+ def handle(self):
+ self.server.wfile = self.wfile
+ self.server.wfile_fileno = self.wfile.fileno()
+ self.server.request_fileno = self.request.fileno()
+
+ server = socketserver.TCPServer((HOST, 0), Handler)
+ self.addCleanup(server.server_close)
+ s = socket.socket(
+ server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
+ with s:
+ s.connect(server.server_address)
+ server.handle_request()
+ self.assertIsInstance(server.wfile, io.BufferedIOBase)
+ self.assertEqual(server.wfile_fileno, server.request_fileno)
+
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_write(self):
+ # Test that wfile.write() sends data immediately, and that it does
+ # not truncate sends when interrupted by a Unix signal
+ pthread_kill = test.support.get_attribute(signal, 'pthread_kill')
+
+ class Handler(socketserver.StreamRequestHandler):
+ def handle(self):
+ self.server.sent1 = self.wfile.write(b'write data\n')
+ # Should be sent immediately, without requiring flush()
+ self.server.received = self.rfile.readline()
+ big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
+ self.server.sent2 = self.wfile.write(big_chunk)
+
+ server = socketserver.TCPServer((HOST, 0), Handler)
+ self.addCleanup(server.server_close)
+ interrupted = threading.Event()
+
+ def signal_handler(signum, frame):
+ interrupted.set()
+
+ original = signal.signal(signal.SIGUSR1, signal_handler)
+ self.addCleanup(signal.signal, signal.SIGUSR1, original)
+ response1 = None
+ received2 = None
+ main_thread = threading.get_ident()
+
+ def run_client():
+ s = socket.socket(server.address_family, socket.SOCK_STREAM,
+ socket.IPPROTO_TCP)
+ with s, s.makefile('rb') as reader:
+ s.connect(server.server_address)
+ nonlocal response1
+ response1 = reader.readline()
+ s.sendall(b'client response\n')
+
+ reader.read(100)
+ # The main thread should now be blocking in a send() syscall.
+ # But in theory, it could get interrupted by other signals,
+ # and then retried. So keep sending the signal in a loop, in
+ # case an earlier signal happens to be delivered at an
+ # inconvenient moment.
+ while True:
+ pthread_kill(main_thread, signal.SIGUSR1)
+ if interrupted.wait(timeout=float(1)):
+ break
+ nonlocal received2
+ received2 = len(reader.read())
+
+ background = threading.Thread(target=run_client)
+ background.start()
+ server.handle_request()
+ background.join()
+ self.assertEqual(server.sent1, len(response1))
+ self.assertEqual(response1, b'write data\n')
+ self.assertEqual(server.received, b'client response\n')
+ self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
+ self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
+
class MiscTestCase(unittest.TestCase):