summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Kehrer <paul.l.kehrer@gmail.com>2015-04-13 23:32:30 -0400
committerPaul Kehrer <paul.l.kehrer@gmail.com>2015-04-13 23:32:30 -0400
commitdf220ad9d87115efa01571e29aee78e29de79191 (patch)
treed9975249b15c55d7e2e0038dcec739e9696ad083
parent10a9d3b0035bf6c1f6a0da4f8c1e31260f4b91ab (diff)
parent218a0146a82b1e392ef7bbde6fd63f58f2d35ad7 (diff)
downloadpyopenssl-df220ad9d87115efa01571e29aee78e29de79191.tar.gz
Merge pull request #218 from exarkun/npn_feature_detect
NPN feature detect
-rw-r--r--OpenSSL/SSL.py20
-rw-r--r--OpenSSL/test/test_ssl.py268
2 files changed, 168 insertions, 120 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index cfe705a..9296444 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -398,6 +398,21 @@ def SSLeay_version(type):
return _ffi.string(_lib.SSLeay_version(type))
+def _requires_npn(func):
+ """
+ Wraps any function that requires NPN support in OpenSSL, ensuring that
+ NotImplementedError is raised if NPN is not present.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not _lib.Cryptography_HAS_NEXTPROTONEG:
+ raise NotImplementedError("NPN not available.")
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
def _requires_alpn(func):
"""
@@ -992,6 +1007,8 @@ class Context(object):
_lib.SSL_CTX_set_tlsext_servername_callback(
self._context, self._tlsext_servername_callback)
+
+ @_requires_npn
def set_npn_advertise_callback(self, callback):
"""
Specify a callback function that will be called when offering `Next
@@ -1009,6 +1026,7 @@ class Context(object):
self._context, self._npn_advertise_callback, _ffi.NULL)
+ @_requires_npn
def set_npn_select_callback(self, callback):
"""
Specify a callback function that will be called when a server offers
@@ -1867,6 +1885,8 @@ class Connection(object):
version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher))
return version.decode("utf-8")
+
+ @_requires_npn
def get_next_proto_negotiated(self):
"""
Get the protocol that was negotiated by NPN.
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 4dedb6b..2055b6d 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -44,6 +44,7 @@ from OpenSSL.SSL import (
Context, ContextType, Session, Connection, ConnectionType, SSLeay_version)
from OpenSSL._util import lib as _lib
+
from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b
from OpenSSL.test.test_crypto import (
cleartextCertificatePEM, cleartextPrivateKeyPEM,
@@ -1627,159 +1628,186 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
"""
Test for Next Protocol Negotiation in PyOpenSSL.
"""
- def test_npn_success(self):
- """
- Tests that clients and servers that agree on the negotiated next
- protocol can correct establish a connection, and that the agreed
- protocol is reported by the connections.
- """
- advertise_args = []
- select_args = []
- def advertise(conn):
- advertise_args.append((conn,))
- return [b'http/1.1', b'spdy/2']
- def select(conn, options):
- select_args.append((conn, options))
- return b'spdy/2'
+ if _lib.Cryptography_HAS_NEXTPROTONEG:
+ def test_npn_success(self):
+ """
+ Tests that clients and servers that agree on the negotiated next
+ protocol can correct establish a connection, and that the agreed
+ protocol is reported by the connections.
+ """
+ advertise_args = []
+ select_args = []
+ def advertise(conn):
+ advertise_args.append((conn,))
+ return [b'http/1.1', b'spdy/2']
+ def select(conn, options):
+ select_args.append((conn, options))
+ return b'spdy/2'
- server_context = Context(TLSv1_METHOD)
- server_context.set_npn_advertise_callback(advertise)
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_npn_advertise_callback(advertise)
- client_context = Context(TLSv1_METHOD)
- client_context.set_npn_select_callback(select)
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_npn_select_callback(select)
- # Necessary to actually accept the connection
- server_context.use_privatekey(
- load_privatekey(FILETYPE_PEM, server_key_pem))
- server_context.use_certificate(
- load_certificate(FILETYPE_PEM, server_cert_pem))
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
- # Do a little connection to trigger the logic
- server = Connection(server_context, None)
- server.set_accept_state()
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
- client = Connection(client_context, None)
- client.set_connect_state()
+ client = Connection(client_context, None)
+ client.set_connect_state()
- self._interactInMemory(server, client)
+ self._interactInMemory(server, client)
- self.assertEqual([(server,)], advertise_args)
- self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
+ self.assertEqual([(server,)], advertise_args)
+ self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
- self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
- self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')
+ self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
+ self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')
- def test_npn_client_fail(self):
- """
- Tests that when clients and servers cannot agree on what protocol to
- use next that the TLS connection does not get established.
- """
- advertise_args = []
- select_args = []
- def advertise(conn):
- advertise_args.append((conn,))
- return [b'http/1.1', b'spdy/2']
- def select(conn, options):
- select_args.append((conn, options))
- return b''
+ def test_npn_client_fail(self):
+ """
+ Tests that when clients and servers cannot agree on what protocol
+ to use next that the TLS connection does not get established.
+ """
+ advertise_args = []
+ select_args = []
+ def advertise(conn):
+ advertise_args.append((conn,))
+ return [b'http/1.1', b'spdy/2']
+ def select(conn, options):
+ select_args.append((conn, options))
+ return b''
- server_context = Context(TLSv1_METHOD)
- server_context.set_npn_advertise_callback(advertise)
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_npn_advertise_callback(advertise)
- client_context = Context(TLSv1_METHOD)
- client_context.set_npn_select_callback(select)
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_npn_select_callback(select)
- # Necessary to actually accept the connection
- server_context.use_privatekey(
- load_privatekey(FILETYPE_PEM, server_key_pem))
- server_context.use_certificate(
- load_certificate(FILETYPE_PEM, server_cert_pem))
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
- # Do a little connection to trigger the logic
- server = Connection(server_context, None)
- server.set_accept_state()
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
- client = Connection(client_context, None)
- client.set_connect_state()
+ client = Connection(client_context, None)
+ client.set_connect_state()
- # If the client doesn't return anything, the connection will fail.
- self.assertRaises(Error, self._interactInMemory, server, client)
+ # If the client doesn't return anything, the connection will fail.
+ self.assertRaises(Error, self._interactInMemory, server, client)
- self.assertEqual([(server,)], advertise_args)
- self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
+ self.assertEqual([(server,)], advertise_args)
+ self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
- def test_npn_select_error(self):
- """
- Test that we can handle exceptions in the select callback. If select
- fails it should be fatal to the connection.
- """
- advertise_args = []
- def advertise(conn):
- advertise_args.append((conn,))
- return [b'http/1.1', b'spdy/2']
- def select(conn, options):
- raise TypeError
+ def test_npn_select_error(self):
+ """
+ Test that we can handle exceptions in the select callback. If
+ select fails it should be fatal to the connection.
+ """
+ advertise_args = []
+ def advertise(conn):
+ advertise_args.append((conn,))
+ return [b'http/1.1', b'spdy/2']
+ def select(conn, options):
+ raise TypeError
- server_context = Context(TLSv1_METHOD)
- server_context.set_npn_advertise_callback(advertise)
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_npn_advertise_callback(advertise)
- client_context = Context(TLSv1_METHOD)
- client_context.set_npn_select_callback(select)
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_npn_select_callback(select)
- # Necessary to actually accept the connection
- server_context.use_privatekey(
- load_privatekey(FILETYPE_PEM, server_key_pem))
- server_context.use_certificate(
- load_certificate(FILETYPE_PEM, server_cert_pem))
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
- # Do a little connection to trigger the logic
- server = Connection(server_context, None)
- server.set_accept_state()
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
- client = Connection(client_context, None)
- client.set_connect_state()
+ client = Connection(client_context, None)
+ client.set_connect_state()
- # If the callback throws an exception it should be raised here.
- self.assertRaises(TypeError, self._interactInMemory, server, client)
- self.assertEqual([(server,)], advertise_args)
+ # If the callback throws an exception it should be raised here.
+ self.assertRaises(
+ TypeError, self._interactInMemory, server, client
+ )
+ self.assertEqual([(server,)], advertise_args)
- def test_npn_advertise_error(self):
- """
- Test that we can handle exceptions in the advertise callback. If
- advertise fails no NPN is advertised to the client.
- """
- select_args = []
- def advertise(conn):
- raise TypeError
- def select(conn, options):
- select_args.append((conn, options))
- return b''
+ def test_npn_advertise_error(self):
+ """
+ Test that we can handle exceptions in the advertise callback. If
+ advertise fails no NPN is advertised to the client.
+ """
+ select_args = []
+ def advertise(conn):
+ raise TypeError
+ def select(conn, options):
+ select_args.append((conn, options))
+ return b''
- server_context = Context(TLSv1_METHOD)
- server_context.set_npn_advertise_callback(advertise)
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_npn_advertise_callback(advertise)
- client_context = Context(TLSv1_METHOD)
- client_context.set_npn_select_callback(select)
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_npn_select_callback(select)
- # Necessary to actually accept the connection
- server_context.use_privatekey(
- load_privatekey(FILETYPE_PEM, server_key_pem))
- server_context.use_certificate(
- load_certificate(FILETYPE_PEM, server_cert_pem))
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
- # Do a little connection to trigger the logic
- server = Connection(server_context, None)
- server.set_accept_state()
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
- client = Connection(client_context, None)
- client.set_connect_state()
+ client = Connection(client_context, None)
+ client.set_connect_state()
+
+ # If the client doesn't return anything, the connection will fail.
+ self.assertRaises(
+ TypeError, self._interactInMemory, server, client
+ )
+ self.assertEqual([], select_args)
- # If the client doesn't return anything, the connection will fail.
- self.assertRaises(TypeError, self._interactInMemory, server, client)
- self.assertEqual([], select_args)
+ else:
+ # No NPN.
+ def test_npn_not_implemented(self):
+ # Test the context methods first.
+ context = Context(TLSv1_METHOD)
+ fail_methods = [
+ context.set_npn_advertise_callback,
+ context.set_npn_select_callback,
+ ]
+ for method in fail_methods:
+ self.assertRaises(
+ NotImplementedError, method, None
+ )
+
+ # Now test a connection.
+ conn = Connection(context)
+ fail_methods = [
+ conn.get_next_proto_negotiated,
+ ]
+ for method in fail_methods:
+ self.assertRaises(NotImplementedError, method)