diff options
author | Jean-Paul Calderone <exarkun@divmod.com> | 2011-05-26 18:47:00 -0400 |
---|---|---|
committer | Jean-Paul Calderone <exarkun@divmod.com> | 2011-05-26 18:47:00 -0400 |
commit | c4cb658c516e20cad3e8707ba66dab78ce0bd1e8 (patch) | |
tree | 4495cb872f64d0d667d5f571a6f3efc23e3fe319 /OpenSSL/test | |
parent | 95613b7e9461a34db446501ce565c41feb7f5c6d (diff) | |
download | pyopenssl-c4cb658c516e20cad3e8707ba66dab78ce0bd1e8.tar.gz |
And SSL_get_servername, SSL_set_tlsext_host_name, and SSL_CTX_set_tlsext_servername_callback
Diffstat (limited to 'OpenSSL/test')
-rw-r--r-- | OpenSSL/test/test_ssl.py | 127 |
1 files changed, 124 insertions, 3 deletions
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 24a08b0..52ff818 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -5,12 +5,14 @@ Unit tests for L{OpenSSL.SSL}. """ +from gc import collect from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK -from sys import platform +from sys import platform, version_info from socket import error, socket from os import makedirs from os.path import join from unittest import main +from weakref import ref from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM from OpenSSL.crypto import PKey, X509, X509Extension @@ -873,6 +875,106 @@ class ContextTests(TestCase, _LoopbackMixin): +class ServerNameCallbackTests(TestCase, _LoopbackMixin): + """ + Tests for L{Context.set_tlsext_servername_callback} and its interaction with + L{Connection}. + """ + def test_wrong_args(self): + """ + L{Context.set_tlsext_servername_callback} raises L{TypeError} if called + with other than one argument. + """ + context = Context(TLSv1_METHOD) + self.assertRaises(TypeError, context.set_tlsext_servername_callback) + self.assertRaises( + TypeError, context.set_tlsext_servername_callback, 1, 2) + + def test_old_callback_forgotten(self): + """ + If L{Context.set_tlsext_servername_callback} is used to specify a new + callback, the one it replaces is dereferenced. + """ + def callback(connection): + pass + + def replacement(connection): + pass + + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(callback) + + tracker = ref(callback) + del callback + + context.set_tlsext_servername_callback(replacement) + collect() + self.assertIdentical(None, tracker()) + + + def test_no_servername(self): + """ + When a client specifies no server name, the callback passed to + L{Context.set_tlsext_servername_callback} is invoked and the result of + L{Connection.get_servername} is C{None}. + """ + args = [] + def servername(conn): + args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(servername) + + # Lose our reference to it. The Context is responsible for keeping it + # alive now. + del servername + collect() + + # Necessary to actually accept the connection + context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(context, None) + server.set_accept_state() + + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, None)], args) + + + def test_servername(self): + """ + When a client specifies a server name in its hello message, the callback + passed to L{Contexts.set_tlsext_servername_callback} is invoked and the + result of L{Connection.get_servername} is that server name. + """ + args = [] + def servername(conn): + args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(servername) + + # Necessary to actually accept the connection + context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(context, None) + server.set_accept_state() + + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + client.set_tlsext_host_name(b("foo1.example.com")) + + self._interactInMemory(server, client) + + self.assertEqual([(server, b("foo1.example.com"))], args) + + + class ConnectionTests(TestCase, _LoopbackMixin): """ Unit tests for L{OpenSSL.SSL.Connection}. @@ -955,8 +1057,27 @@ class ConnectionTests(TestCase, _LoopbackMixin): # Lose our references to the contexts, just in case the Connection isn't # properly managing its own contributions to their reference counts. del original, replacement - import gc - gc.collect() + collect() + + + def test_set_tlsext_host_name_wrong_args(self): + """ + If L{Connection.set_tlsext_host_name} is called with a non-byte string + argument or a byte string with an embedded NUL or other than one + argument, L{TypeError} is raised. + """ + conn = Connection(Context(TLSv1_METHOD), None) + self.assertRaises(TypeError, conn.set_tlsext_host_name) + self.assertRaises(TypeError, conn.set_tlsext_host_name, object()) + self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456) + self.assertRaises( + TypeError, conn.set_tlsext_host_name, b("with\0null")) + + if version_info >= (3,): + # On Python 3.x, don't accidentally implicitly convert from text. + self.assertRaises( + TypeError, + conn.set_tlsext_host_name, b("example.com").decode("ascii")) def test_pending(self): |