diff options
author | Jean-Paul Calderone <exarkun@twistedmatrix.com> | 2015-04-13 20:52:29 -0400 |
---|---|---|
committer | Jean-Paul Calderone <exarkun@twistedmatrix.com> | 2015-04-13 20:52:29 -0400 |
commit | da6399a3c9e5b6ed19d4ec68ba8980e83a3b7ad6 (patch) | |
tree | ef81bd7773f156da88489c68ff4799467ff7f18d | |
parent | 39a8d590ffb7c8adaa641f182ac7571ba30936a9 (diff) | |
parent | 21873744bc048e5a08130b27a3e347ba7d0fc0d3 (diff) | |
download | pyopenssl-da6399a3c9e5b6ed19d4ec68ba8980e83a3b7ad6.tar.gz |
Merge master again!
-rw-r--r-- | .travis.yml | 2 | ||||
-rw-r--r-- | OpenSSL/SSL.py | 158 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 208 | ||||
-rw-r--r-- | doc/api/ssl.rst | 36 |
4 files changed, 400 insertions, 4 deletions
diff --git a/.travis.yml b/.travis.yml index eda1242..eccc720 100644 --- a/.travis.yml +++ b/.travis.yml @@ -98,7 +98,7 @@ script: coverage report -m python -c "import OpenSSL.SSL; print(OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION))" -after_success: +after_script: - source ~/.venv/bin/activate && coveralls notifications: diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 54092c5..39a2bca 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -5,6 +5,7 @@ from weakref import WeakValueDictionary from errno import errorcode from six import text_type as _text_type +from six import binary_type as _binary_type from six import integer_types as integer_types from six import int2byte, indexbytes @@ -319,6 +320,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper): ) +class _ALPNSelectHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an ALPN selection callback. + """ + def __init__(self, callback): + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + encoded_len = indexbytes(instr, 0) + proto = instr[1:encoded_len + 1] + protolist.append(proto) + instr = instr[encoded_len + 1:] + + # Call the callback + outstr = callback(conn, protolist) + + if not isinstance(outstr, _binary_type): + raise TypeError("ALPN callback must return a bytestring.") + + # Save our callback arguments on the connection object to make + # sure that they don't get freed before OpenSSL can use them. + # Then, return them in the appropriate output parameters. + conn._alpn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._alpn_select_callback_args[0][0] + out[0] = conn._alpn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + + def _asFileDescriptor(obj): fd = None if not isinstance(obj, integer_types): @@ -349,6 +400,22 @@ def SSLeay_version(type): +def _requires_alpn(func): + """ + Wraps any function that requires ALPN support in OpenSSL, ensuring that + NotImplementedError is raised if ALPN support is not present. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if not _lib.Cryptography_HAS_ALPN: + raise NotImplementedError("ALPN not available.") + + return func(*args, **kwargs) + + return wrapper + + + class Session(object): pass @@ -410,6 +477,8 @@ class Context(object): self._npn_advertise_callback = None self._npn_select_helper = None self._npn_select_callback = None + self._alpn_select_helper = None + self._alpn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); # SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | @@ -924,7 +993,6 @@ class Context(object): _lib.SSL_CTX_set_tlsext_servername_callback( self._context, self._tlsext_servername_callback) - def set_npn_advertise_callback(self, callback): """ Specify a callback function that will be called when offering `Next @@ -957,6 +1025,44 @@ class Context(object): _lib.SSL_CTX_set_next_proto_select_cb( self._context, self._npn_select_callback, _ffi.NULL) + @_requires_alpn + def set_alpn_protos(self, protos): + """ + Specify the clients ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + input_str_len = _ffi.cast("unsigned", len(protostr)) + _lib.SSL_CTX_set_alpn_protos(self._context, input_str, input_str_len) + + @_requires_alpn + def set_alpn_select_callback(self, callback): + """ + Set the callback to handle ALPN protocol choice. + + :param callback: The callback function. It will be invoked with two + arguments: the Connection, and a list of offered protocols as + bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return + one of those bytestrings, the chosen protocol. + """ + self._alpn_select_helper = _ALPNSelectHelper(callback) + self._alpn_select_callback = self._alpn_select_helper.callback + _lib.SSL_CTX_set_alpn_select_cb( + self._context, self._alpn_select_callback, _ffi.NULL) + ContextType = Context @@ -988,6 +1094,12 @@ class Connection(object): self._npn_advertise_callback_args = None self._npn_select_callback_args = None + # References to strings used for Application Layer Protocol + # Negotiation. These strings get copied at some point but it's well + # after the callback returns, so we have to hang them somewhere to + # avoid them getting freed. + self._alpn_select_callback_args = None + self._reverse_mapping[self._ssl] = self if socket is None: @@ -1026,6 +1138,8 @@ class Connection(object): self._context._npn_advertise_helper.raise_if_problem() if self._context._npn_select_helper is not None: self._context._npn_select_helper.raise_if_problem() + if self._context._alpn_select_helper is not None: + self._context._alpn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: @@ -1765,6 +1879,48 @@ class Connection(object): return _ffi.buffer(data[0], data_len[0])[:] + @_requires_alpn + def set_alpn_protos(self, protos): + """ + Specify the client's ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + input_str_len = _ffi.cast("unsigned", len(protostr)) + _lib.SSL_set_alpn_protos(self._ssl, input_str, input_str_len) + + + def get_alpn_proto_negotiated(self): + """ + Get the protocol that was negotiated by ALPN. + """ + if not _lib.Cryptography_HAS_ALPN: + raise NotImplementedError("ALPN not available") + + data = _ffi.new("unsigned char **") + data_len = _ffi.new("unsigned int *") + + _lib.SSL_get0_alpn_selected(self._ssl, data, data_len) + + if not data_len: + return b'' + + return _ffi.buffer(data[0], data_len[0])[:] + + ConnectionType = Connection diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index bd00142..4dedb6b 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -43,10 +43,10 @@ from OpenSSL.SSL import ( from OpenSSL.SSL import ( Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) +from OpenSSL._util import lib as _lib from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b from OpenSSL.test.test_crypto import ( - cleartextCertificatePEM, cleartextPrivateKeyPEM) -from OpenSSL.test.test_crypto import ( + cleartextCertificatePEM, cleartextPrivateKeyPEM, client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, root_cert_pem) @@ -1783,6 +1783,210 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin): +class ApplicationLayerProtoNegotiationTests(TestCase, _LoopbackMixin): + """ + Tests for ALPN in PyOpenSSL. + """ + # Skip tests on versions that don't support ALPN. + if _lib.Cryptography_HAS_ALPN: + + def test_alpn_success(self): + """ + Clients and servers that agree on the negotiated ALPN protocol can + correct establish a connection, and the agreed protocol is reported + by the connections. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2') + + + def test_alpn_set_on_connection(self): + """ + The same as test_alpn_success, but setting the ALPN protocols on + the connection rather than the context. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' + + # Setup the client context but don't set any ALPN protocols. + client_context = Context(TLSv1_METHOD) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + # Set the ALPN protocols on the client connection. + client = Connection(client_context, None) + client.set_alpn_protos([b'http/1.1', b'spdy/2']) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2') + + + def test_alpn_server_fail(self): + """ + When clients and servers cannot agree on what protocol to use next + the TLS connection does not get established. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + return b'' + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises(Error, self._interactInMemory, server, client) + + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + + def test_alpn_no_server(self): + """ + When clients and servers cannot agree on what protocol to use next + because the server doesn't offer ALPN, no protocol is negotiated. + """ + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # Do the dance. + self._interactInMemory(server, client) + + self.assertEqual(client.get_alpn_proto_negotiated(), b'') + + + def test_alpn_callback_exception(self): + """ + We can handle exceptions in the ALPN select callback. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + raise TypeError() + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + else: + # No ALPN. + def test_alpn_not_implemented(self): + """ + If ALPN is not in OpenSSL, we should raise NotImplementedError. + """ + # Test the context methods first. + context = Context(TLSv1_METHOD) + self.assertRaises( + NotImplementedError, context.set_alpn_protos, None + ) + self.assertRaises( + NotImplementedError, context.set_alpn_select_callback, None + ) + + # Now test a connection. + conn = Connection(context) + self.assertRaises( + NotImplementedError, context.set_alpn_protos, None + ) + + + class SessionTests(TestCase): """ Unit tests for :py:obj:`OpenSSL.SSL.Session`. diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index e6a0775..2929305 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -498,6 +498,26 @@ Context objects have the following methods: .. versionadded:: 0.15 +.. py:method:: Context.set_alpn_protos(protos) + + Specify the protocols that the client is prepared to speak after the TLS + connection has been negotiated using Application Layer Protocol + Negotiation. + + *protos* should be a list of protocols that the client is offering, each + as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``. + + +.. py:method:: Context.set_alpn_select_callback(callback) + + Specify a callback function that will be called on the server when a client + offers protocols using Application Layer Protocol Negotiation. + + *callback* should be the callback function. It will be invoked with two + arguments: the :py:class:`Connection` and a list of offered protocols as + bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of + these bytestrings, the chosen protocol. + .. _openssl-session: @@ -849,6 +869,22 @@ Connection objects have the following methods: .. versionadded:: 0.15 +.. py:method:: Connection.set_alpn_protos(protos) + + Specify the protocols that the client is prepared to speak after the TLS + connection has been negotiated using Application Layer Protocol + Negotiation. + + *protos* should be a list of protocols that the client is offering, each + as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``. + + +.. py:method:: Connection.get_alpn_proto_negotiated() + + Get the protocol that was negotiated by Application Layer Protocol + Negotiation. Returns a bytestring of the protocol name. If no protocol has + been negotiated yet, returns an empty string. + .. Rubric:: Footnotes |