diff options
author | Paul Kehrer <paul.l.kehrer@gmail.com> | 2015-04-13 23:32:30 -0400 |
---|---|---|
committer | Paul Kehrer <paul.l.kehrer@gmail.com> | 2015-04-13 23:32:30 -0400 |
commit | df220ad9d87115efa01571e29aee78e29de79191 (patch) | |
tree | d9975249b15c55d7e2e0038dcec739e9696ad083 | |
parent | 10a9d3b0035bf6c1f6a0da4f8c1e31260f4b91ab (diff) | |
parent | 218a0146a82b1e392ef7bbde6fd63f58f2d35ad7 (diff) | |
download | pyopenssl-df220ad9d87115efa01571e29aee78e29de79191.tar.gz |
Merge pull request #218 from exarkun/npn_feature_detect
NPN feature detect
-rw-r--r-- | OpenSSL/SSL.py | 20 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 268 |
2 files changed, 168 insertions, 120 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index cfe705a..9296444 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -398,6 +398,21 @@ def SSLeay_version(type): return _ffi.string(_lib.SSLeay_version(type)) +def _requires_npn(func): + """ + Wraps any function that requires NPN support in OpenSSL, ensuring that + NotImplementedError is raised if NPN is not present. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if not _lib.Cryptography_HAS_NEXTPROTONEG: + raise NotImplementedError("NPN not available.") + + return func(*args, **kwargs) + + return wrapper + + def _requires_alpn(func): """ @@ -992,6 +1007,8 @@ class Context(object): _lib.SSL_CTX_set_tlsext_servername_callback( self._context, self._tlsext_servername_callback) + + @_requires_npn def set_npn_advertise_callback(self, callback): """ Specify a callback function that will be called when offering `Next @@ -1009,6 +1026,7 @@ class Context(object): self._context, self._npn_advertise_callback, _ffi.NULL) + @_requires_npn def set_npn_select_callback(self, callback): """ Specify a callback function that will be called when a server offers @@ -1867,6 +1885,8 @@ class Connection(object): version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher)) return version.decode("utf-8") + + @_requires_npn def get_next_proto_negotiated(self): """ Get the protocol that was negotiated by NPN. diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 4dedb6b..2055b6d 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -44,6 +44,7 @@ from OpenSSL.SSL import ( Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) from OpenSSL._util import lib as _lib + from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b from OpenSSL.test.test_crypto import ( cleartextCertificatePEM, cleartextPrivateKeyPEM, @@ -1627,159 +1628,186 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin): """ Test for Next Protocol Negotiation in PyOpenSSL. """ - def test_npn_success(self): - """ - Tests that clients and servers that agree on the negotiated next - protocol can correct establish a connection, and that the agreed - protocol is reported by the connections. - """ - advertise_args = [] - select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - select_args.append((conn, options)) - return b'spdy/2' + if _lib.Cryptography_HAS_NEXTPROTONEG: + def test_npn_success(self): + """ + Tests that clients and servers that agree on the negotiated next + protocol can correct establish a connection, and that the agreed + protocol is reported by the connections. + """ + advertise_args = [] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - self._interactInMemory(server, client) + self._interactInMemory(server, client) - self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') - self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') + self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') - def test_npn_client_fail(self): - """ - Tests that when clients and servers cannot agree on what protocol to - use next that the TLS connection does not get established. - """ - advertise_args = [] - select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - select_args.append((conn, options)) - return b'' + def test_npn_client_fail(self): + """ + Tests that when clients and servers cannot agree on what protocol + to use next that the TLS connection does not get established. + """ + advertise_args = [] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + select_args.append((conn, options)) + return b'' - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - # If the client doesn't return anything, the connection will fail. - self.assertRaises(Error, self._interactInMemory, server, client) + # If the client doesn't return anything, the connection will fail. + self.assertRaises(Error, self._interactInMemory, server, client) - self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - def test_npn_select_error(self): - """ - Test that we can handle exceptions in the select callback. If select - fails it should be fatal to the connection. - """ - advertise_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - raise TypeError + def test_npn_select_error(self): + """ + Test that we can handle exceptions in the select callback. If + select fails it should be fatal to the connection. + """ + advertise_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + raise TypeError - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - # If the callback throws an exception it should be raised here. - self.assertRaises(TypeError, self._interactInMemory, server, client) - self.assertEqual([(server,)], advertise_args) + # If the callback throws an exception it should be raised here. + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([(server,)], advertise_args) - def test_npn_advertise_error(self): - """ - Test that we can handle exceptions in the advertise callback. If - advertise fails no NPN is advertised to the client. - """ - select_args = [] - def advertise(conn): - raise TypeError - def select(conn, options): - select_args.append((conn, options)) - return b'' + def test_npn_advertise_error(self): + """ + Test that we can handle exceptions in the advertise callback. If + advertise fails no NPN is advertised to the client. + """ + select_args = [] + def advertise(conn): + raise TypeError + def select(conn, options): + select_args.append((conn, options)) + return b'' - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([], select_args) - # If the client doesn't return anything, the connection will fail. - self.assertRaises(TypeError, self._interactInMemory, server, client) - self.assertEqual([], select_args) + else: + # No NPN. + def test_npn_not_implemented(self): + # Test the context methods first. + context = Context(TLSv1_METHOD) + fail_methods = [ + context.set_npn_advertise_callback, + context.set_npn_select_callback, + ] + for method in fail_methods: + self.assertRaises( + NotImplementedError, method, None + ) + + # Now test a connection. + conn = Connection(context) + fail_methods = [ + conn.get_next_proto_negotiated, + ] + for method in fail_methods: + self.assertRaises(NotImplementedError, method) |