summaryrefslogtreecommitdiff
path: root/OpenSSL/test
diff options
context:
space:
mode:
authorJean-Paul Calderone <exarkun@divmod.com>2011-05-26 18:47:00 -0400
committerJean-Paul Calderone <exarkun@divmod.com>2011-05-26 18:47:00 -0400
commitc4cb658c516e20cad3e8707ba66dab78ce0bd1e8 (patch)
tree4495cb872f64d0d667d5f571a6f3efc23e3fe319 /OpenSSL/test
parent95613b7e9461a34db446501ce565c41feb7f5c6d (diff)
downloadpyopenssl-c4cb658c516e20cad3e8707ba66dab78ce0bd1e8.tar.gz
And SSL_get_servername, SSL_set_tlsext_host_name, and SSL_CTX_set_tlsext_servername_callback
Diffstat (limited to 'OpenSSL/test')
-rw-r--r--OpenSSL/test/test_ssl.py127
1 files changed, 124 insertions, 3 deletions
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 24a08b0..52ff818 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -5,12 +5,14 @@
Unit tests for L{OpenSSL.SSL}.
"""
+from gc import collect
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK
-from sys import platform
+from sys import platform, version_info
from socket import error, socket
from os import makedirs
from os.path import join
from unittest import main
+from weakref import ref
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
from OpenSSL.crypto import PKey, X509, X509Extension
@@ -873,6 +875,106 @@ class ContextTests(TestCase, _LoopbackMixin):
+class ServerNameCallbackTests(TestCase, _LoopbackMixin):
+ """
+ Tests for L{Context.set_tlsext_servername_callback} and its interaction with
+ L{Connection}.
+ """
+ def test_wrong_args(self):
+ """
+ L{Context.set_tlsext_servername_callback} raises L{TypeError} if called
+ with other than one argument.
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.set_tlsext_servername_callback)
+ self.assertRaises(
+ TypeError, context.set_tlsext_servername_callback, 1, 2)
+
+ def test_old_callback_forgotten(self):
+ """
+ If L{Context.set_tlsext_servername_callback} is used to specify a new
+ callback, the one it replaces is dereferenced.
+ """
+ def callback(connection):
+ pass
+
+ def replacement(connection):
+ pass
+
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(callback)
+
+ tracker = ref(callback)
+ del callback
+
+ context.set_tlsext_servername_callback(replacement)
+ collect()
+ self.assertIdentical(None, tracker())
+
+
+ def test_no_servername(self):
+ """
+ When a client specifies no server name, the callback passed to
+ L{Context.set_tlsext_servername_callback} is invoked and the result of
+ L{Connection.get_servername} is C{None}.
+ """
+ args = []
+ def servername(conn):
+ args.append((conn, conn.get_servername()))
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(servername)
+
+ # Lose our reference to it. The Context is responsible for keeping it
+ # alive now.
+ del servername
+ collect()
+
+ # Necessary to actually accept the connection
+ context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(context, None)
+ server.set_accept_state()
+
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, None)], args)
+
+
+ def test_servername(self):
+ """
+ When a client specifies a server name in its hello message, the callback
+ passed to L{Contexts.set_tlsext_servername_callback} is invoked and the
+ result of L{Connection.get_servername} is that server name.
+ """
+ args = []
+ def servername(conn):
+ args.append((conn, conn.get_servername()))
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(servername)
+
+ # Necessary to actually accept the connection
+ context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(context, None)
+ server.set_accept_state()
+
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+ client.set_tlsext_host_name(b("foo1.example.com"))
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, b("foo1.example.com"))], args)
+
+
+
class ConnectionTests(TestCase, _LoopbackMixin):
"""
Unit tests for L{OpenSSL.SSL.Connection}.
@@ -955,8 +1057,27 @@ class ConnectionTests(TestCase, _LoopbackMixin):
# Lose our references to the contexts, just in case the Connection isn't
# properly managing its own contributions to their reference counts.
del original, replacement
- import gc
- gc.collect()
+ collect()
+
+
+ def test_set_tlsext_host_name_wrong_args(self):
+ """
+ If L{Connection.set_tlsext_host_name} is called with a non-byte string
+ argument or a byte string with an embedded NUL or other than one
+ argument, L{TypeError} is raised.
+ """
+ conn = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, conn.set_tlsext_host_name)
+ self.assertRaises(TypeError, conn.set_tlsext_host_name, object())
+ self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456)
+ self.assertRaises(
+ TypeError, conn.set_tlsext_host_name, b("with\0null"))
+
+ if version_info >= (3,):
+ # On Python 3.x, don't accidentally implicitly convert from text.
+ self.assertRaises(
+ TypeError,
+ conn.set_tlsext_host_name, b("example.com").decode("ascii"))
def test_pending(self):