summaryrefslogtreecommitdiff
path: root/tests/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_ssl.py')
-rw-r--r--tests/test_ssl.py244
1 files changed, 244 insertions, 0 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index c32ebab..8c090aa 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -3842,3 +3842,247 @@ class TestRequires(object):
assert "Error text" in str(e.value)
assert results == []
+
+
+class TestOCSP(_LoopbackMixin):
+ """
+ Tests for PyOpenSSL's OCSP stapling support.
+ """
+ sample_ocsp_data = b"this is totally ocsp data"
+
+ def _client_connection(self, callback, data, request_ocsp=True):
+ """
+ Builds a client connection suitable for using OCSP.
+
+ :param callback: The callback to register for OCSP.
+ :param data: The opaque data object that will be handed to the
+ OCSP callback.
+ :param request_ocsp: Whether the client will actually ask for OCSP
+ stapling. Useful for testing only.
+ """
+ ctx = Context(SSLv23_METHOD)
+ ctx.set_ocsp_client_callback(callback, data)
+ client = Connection(ctx)
+
+ if request_ocsp:
+ client.request_ocsp()
+
+ client.set_connect_state()
+ return client
+
+ def _server_connection(self, callback, data):
+ """
+ Builds a server connection suitable for using OCSP.
+
+ :param callback: The callback to register for OCSP.
+ :param data: The opaque data object that will be handed to the
+ OCSP callback.
+ """
+ ctx = Context(SSLv23_METHOD)
+ ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+ ctx.set_ocsp_server_callback(callback, data)
+ server = Connection(ctx)
+ server.set_accept_state()
+ return server
+
+ def test_callbacks_arent_called_by_default(self):
+ """
+ If both the client and the server have registered OCSP callbacks, but
+ the client does not send the OCSP request, neither callback gets
+ called.
+ """
+ called = []
+
+ def ocsp_callback(*args, **kwargs):
+ called.append((args, kwargs))
+
+ client = self._client_connection(
+ callback=ocsp_callback, data=None, request_ocsp=False
+ )
+ server = self._server_connection(callback=ocsp_callback, data=None)
+ self._handshakeInMemory(client, server)
+
+ assert not called
+
+ def test_client_negotiates_without_server(self):
+ """
+ If the client wants to do OCSP but the server does not, the handshake
+ succeeds, and the client callback fires with an empty byte string.
+ """
+ called = []
+
+ def ocsp_callback(conn, ocsp_data, ignored):
+ called.append(ocsp_data)
+ return True
+
+ client = self._client_connection(callback=ocsp_callback, data=None)
+ server = self._loopbackServerFactory(socket=None)
+ self._handshakeInMemory(client, server)
+
+ assert len(called) == 1
+ assert called[0] == b''
+
+ def test_client_receives_servers_data(self):
+ """
+ The data the server sends in its callback is received by the client.
+ """
+ calls = []
+
+ def server_callback(*args, **kwargs):
+ return self.sample_ocsp_data
+
+ def client_callback(conn, ocsp_data, ignored):
+ calls.append(ocsp_data)
+ return True
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+ self._handshakeInMemory(client, server)
+
+ assert len(calls) == 1
+ assert calls[0] == self.sample_ocsp_data
+
+ def test_callbacks_are_invoked_with_connections(self):
+ """
+ The first arguments to both callbacks are their respective connections.
+ """
+ client_calls = []
+ server_calls = []
+
+ def client_callback(conn, *args, **kwargs):
+ client_calls.append(conn)
+ return True
+
+ def server_callback(conn, *args, **kwargs):
+ server_calls.append(conn)
+ return self.sample_ocsp_data
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+ self._handshakeInMemory(client, server)
+
+ assert len(client_calls) == 1
+ assert len(server_calls) == 1
+ assert client_calls[0] is client
+ assert server_calls[0] is server
+
+ def test_opaque_data_is_passed_through(self):
+ """
+ Both callbacks receive an opaque, user-provided piece of data in their
+ callbacks as the final argument.
+ """
+ calls = []
+
+ def server_callback(*args):
+ calls.append(args)
+ return self.sample_ocsp_data
+
+ def client_callback(*args):
+ calls.append(args)
+ return True
+
+ sentinel = object()
+
+ client = self._client_connection(
+ callback=client_callback, data=sentinel
+ )
+ server = self._server_connection(
+ callback=server_callback, data=sentinel
+ )
+ self._handshakeInMemory(client, server)
+
+ assert len(calls) == 2
+ assert calls[0][-1] is sentinel
+ assert calls[1][-1] is sentinel
+
+ def test_server_returns_empty_string(self):
+ """
+ If the server returns an empty bytestring from its callback, the
+ client callback is called with the empty bytestring.
+ """
+ client_calls = []
+
+ def server_callback(*args):
+ return b''
+
+ def client_callback(conn, ocsp_data, ignored):
+ client_calls.append(ocsp_data)
+ return True
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+ self._handshakeInMemory(client, server)
+
+ assert len(client_calls) == 1
+ assert client_calls[0] == b''
+
+ def test_client_returns_false_terminates_handshake(self):
+ """
+ If the client returns False from its callback, the handshake fails.
+ """
+ def server_callback(*args):
+ return self.sample_ocsp_data
+
+ def client_callback(*args):
+ return False
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+
+ with pytest.raises(Error):
+ self._handshakeInMemory(client, server)
+
+ def test_exceptions_in_client_bubble_up(self):
+ """
+ The callbacks thrown in the client callback bubble up to the caller.
+ """
+ class SentinelException(Exception):
+ pass
+
+ def server_callback(*args):
+ return self.sample_ocsp_data
+
+ def client_callback(*args):
+ raise SentinelException()
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+
+ with pytest.raises(SentinelException):
+ self._handshakeInMemory(client, server)
+
+ def test_exceptions_in_server_bubble_up(self):
+ """
+ The callbacks thrown in the server callback bubble up to the caller.
+ """
+ class SentinelException(Exception):
+ pass
+
+ def server_callback(*args):
+ raise SentinelException()
+
+ def client_callback(*args):
+ pytest.fail("Should not be called")
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+
+ with pytest.raises(SentinelException):
+ self._handshakeInMemory(client, server)
+
+ def test_server_must_return_bytes(self):
+ """
+ The server callback must return a bytestring, or a TypeError is thrown.
+ """
+ def server_callback(*args):
+ return self.sample_ocsp_data.decode('ascii')
+
+ def client_callback(*args):
+ pytest.fail("Should not be called")
+
+ client = self._client_connection(callback=client_callback, data=None)
+ server = self._server_connection(callback=server_callback, data=None)
+
+ with pytest.raises(TypeError):
+ self._handshakeInMemory(client, server)