diff options
author | Cory Benfield <lukasaoz@gmail.com> | 2015-04-12 09:11:49 -0400 |
---|---|---|
committer | Cory Benfield <lukasaoz@gmail.com> | 2015-04-13 16:11:48 -0400 |
commit | f1177e782fb6a0e9b4b62dae61fbe8c54e01170e (patch) | |
tree | c2bd442b060935287429b58f47450ed1e37f9039 | |
parent | 63759dc59052fc53b159bc139ab9fa9ecfc5ae85 (diff) | |
download | pyopenssl-f1177e782fb6a0e9b4b62dae61fbe8c54e01170e.tar.gz |
Handle exceptions in ALPN select callback.
-rw-r--r-- | OpenSSL/SSL.py | 90 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 33 |
2 files changed, 88 insertions, 35 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index cc2472f..5d41aa8 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -319,6 +319,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper): ) +class _AlpnSelectHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an ALPN selection callback. + """ + def __init__(self, callback): + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + l = indexbytes(instr, 0) + proto = instr[1:l+1] + protolist.append(proto) + instr = instr[l+1:] + + # Call the callback + outstr = callback(conn, protolist) + + if not isinstance(outstr, _binary_type): + raise TypeError("ALPN callback must return a bytestring.") + + # Save our callback arguments on the connection object to make + # sure that they don't get freed before OpenSSL can use them. + # Then, return them in the appropriate output parameters. + conn._alpn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._alpn_select_callback_args[0][0] + out[0] = conn._alpn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + + def _asFileDescriptor(obj): fd = None if not isinstance(obj, integer_types): @@ -410,6 +460,7 @@ class Context(object): self._npn_advertise_callback = None self._npn_select_helper = None self._npn_select_callback = None + self._alpn_select_helper = None self._alpn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); @@ -988,41 +1039,8 @@ class Context(object): bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return one of those bytestrings, the chosen protocol. """ - @wraps(callback) - def wrapper(ssl, out, outlen, in_, inlen, arg): - conn = Connection._reverse_mapping[ssl] - - # The string passed to us is made up of multiple length-prefixed - # bytestrings. We need to split that into a list. - instr = _ffi.buffer(in_, inlen)[:] - protolist = [] - while instr: - l = indexbytes(instr, 0) - proto = instr[1:l+1] - protolist.append(proto) - instr = instr[l+1:] - - # Call the callback - outstr = callback(conn, protolist) - - if not isinstance(outstr, _binary_type): - raise TypeError("ALPN callback must return a bytestring.") - - # Save our callback arguments on the connection object to make sure - # that they don't get freed before OpenSSL can use them. Then, - # return them in the appropriate output parameters. - conn._alpn_select_callback_args = [ - _ffi.new("unsigned char *", len(outstr)), - _ffi.new("unsigned char[]", outstr), - ] - outlen[0] = conn._alpn_select_callback_args[0][0] - out[0] = conn._alpn_select_callback_args[1] - return 0 - - self._alpn_select_callback = _ffi.callback( - "int (*)(SSL *, unsigned char **, unsigned char *, " - "const unsigned char *, unsigned int, void *)", - wrapper) + self._alpn_select_helper = _AlpnSelectHelper(callback) + self._alpn_select_callback = self._alpn_select_helper.callback _lib.SSL_CTX_set_alpn_select_cb( self._context, self._alpn_select_callback, _ffi.NULL) @@ -1101,6 +1119,8 @@ class Connection(object): self._context._npn_advertise_helper.raise_if_problem() if self._context._npn_select_helper is not None: self._context._npn_select_helper.raise_if_problem() + if self._context._alpn_select_helper is not None: + self._context._alpn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index abc9dfd..68289a2 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1927,6 +1927,39 @@ class ApplicationLayerProtoNegotiationTests(TestCase, _LoopbackMixin): self.assertEqual(client.get_alpn_proto_negotiated(), b'') + def test_alpn_callback_exception(self): + """ + Test that we can handle exceptions in the ALPN select callback. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + raise TypeError + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises(TypeError, self._interactInMemory, server, client) + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + class SessionTests(TestCase): """ |