diff options
Diffstat (limited to 'tests/test_ssl.py')
-rw-r--r-- | tests/test_ssl.py | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py index c32ebab..8c090aa 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -3842,3 +3842,247 @@ class TestRequires(object): assert "Error text" in str(e.value) assert results == [] + + +class TestOCSP(_LoopbackMixin): + """ + Tests for PyOpenSSL's OCSP stapling support. + """ + sample_ocsp_data = b"this is totally ocsp data" + + def _client_connection(self, callback, data, request_ocsp=True): + """ + Builds a client connection suitable for using OCSP. + + :param callback: The callback to register for OCSP. + :param data: The opaque data object that will be handed to the + OCSP callback. + :param request_ocsp: Whether the client will actually ask for OCSP + stapling. Useful for testing only. + """ + ctx = Context(SSLv23_METHOD) + ctx.set_ocsp_client_callback(callback, data) + client = Connection(ctx) + + if request_ocsp: + client.request_ocsp() + + client.set_connect_state() + return client + + def _server_connection(self, callback, data): + """ + Builds a server connection suitable for using OCSP. + + :param callback: The callback to register for OCSP. + :param data: The opaque data object that will be handed to the + OCSP callback. + """ + ctx = Context(SSLv23_METHOD) + ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + ctx.set_ocsp_server_callback(callback, data) + server = Connection(ctx) + server.set_accept_state() + return server + + def test_callbacks_arent_called_by_default(self): + """ + If both the client and the server have registered OCSP callbacks, but + the client does not send the OCSP request, neither callback gets + called. + """ + called = [] + + def ocsp_callback(*args, **kwargs): + called.append((args, kwargs)) + + client = self._client_connection( + callback=ocsp_callback, data=None, request_ocsp=False + ) + server = self._server_connection(callback=ocsp_callback, data=None) + self._handshakeInMemory(client, server) + + assert not called + + def test_client_negotiates_without_server(self): + """ + If the client wants to do OCSP but the server does not, the handshake + succeeds, and the client callback fires with an empty byte string. + """ + called = [] + + def ocsp_callback(conn, ocsp_data, ignored): + called.append(ocsp_data) + return True + + client = self._client_connection(callback=ocsp_callback, data=None) + server = self._loopbackServerFactory(socket=None) + self._handshakeInMemory(client, server) + + assert len(called) == 1 + assert called[0] == b'' + + def test_client_receives_servers_data(self): + """ + The data the server sends in its callback is received by the client. + """ + calls = [] + + def server_callback(*args, **kwargs): + return self.sample_ocsp_data + + def client_callback(conn, ocsp_data, ignored): + calls.append(ocsp_data) + return True + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + self._handshakeInMemory(client, server) + + assert len(calls) == 1 + assert calls[0] == self.sample_ocsp_data + + def test_callbacks_are_invoked_with_connections(self): + """ + The first arguments to both callbacks are their respective connections. + """ + client_calls = [] + server_calls = [] + + def client_callback(conn, *args, **kwargs): + client_calls.append(conn) + return True + + def server_callback(conn, *args, **kwargs): + server_calls.append(conn) + return self.sample_ocsp_data + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + self._handshakeInMemory(client, server) + + assert len(client_calls) == 1 + assert len(server_calls) == 1 + assert client_calls[0] is client + assert server_calls[0] is server + + def test_opaque_data_is_passed_through(self): + """ + Both callbacks receive an opaque, user-provided piece of data in their + callbacks as the final argument. + """ + calls = [] + + def server_callback(*args): + calls.append(args) + return self.sample_ocsp_data + + def client_callback(*args): + calls.append(args) + return True + + sentinel = object() + + client = self._client_connection( + callback=client_callback, data=sentinel + ) + server = self._server_connection( + callback=server_callback, data=sentinel + ) + self._handshakeInMemory(client, server) + + assert len(calls) == 2 + assert calls[0][-1] is sentinel + assert calls[1][-1] is sentinel + + def test_server_returns_empty_string(self): + """ + If the server returns an empty bytestring from its callback, the + client callback is called with the empty bytestring. + """ + client_calls = [] + + def server_callback(*args): + return b'' + + def client_callback(conn, ocsp_data, ignored): + client_calls.append(ocsp_data) + return True + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + self._handshakeInMemory(client, server) + + assert len(client_calls) == 1 + assert client_calls[0] == b'' + + def test_client_returns_false_terminates_handshake(self): + """ + If the client returns False from its callback, the handshake fails. + """ + def server_callback(*args): + return self.sample_ocsp_data + + def client_callback(*args): + return False + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + + with pytest.raises(Error): + self._handshakeInMemory(client, server) + + def test_exceptions_in_client_bubble_up(self): + """ + The callbacks thrown in the client callback bubble up to the caller. + """ + class SentinelException(Exception): + pass + + def server_callback(*args): + return self.sample_ocsp_data + + def client_callback(*args): + raise SentinelException() + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + + with pytest.raises(SentinelException): + self._handshakeInMemory(client, server) + + def test_exceptions_in_server_bubble_up(self): + """ + The callbacks thrown in the server callback bubble up to the caller. + """ + class SentinelException(Exception): + pass + + def server_callback(*args): + raise SentinelException() + + def client_callback(*args): + pytest.fail("Should not be called") + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + + with pytest.raises(SentinelException): + self._handshakeInMemory(client, server) + + def test_server_must_return_bytes(self): + """ + The server callback must return a bytestring, or a TypeError is thrown. + """ + def server_callback(*args): + return self.sample_ocsp_data.decode('ascii') + + def client_callback(*args): + pytest.fail("Should not be called") + + client = self._client_connection(callback=client_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) + + with pytest.raises(TypeError): + self._handshakeInMemory(client, server) |