summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCory Benfield <lukasaoz@gmail.com>2015-04-12 09:11:49 -0400
committerCory Benfield <lukasaoz@gmail.com>2015-04-13 16:11:48 -0400
commitf1177e782fb6a0e9b4b62dae61fbe8c54e01170e (patch)
treec2bd442b060935287429b58f47450ed1e37f9039
parent63759dc59052fc53b159bc139ab9fa9ecfc5ae85 (diff)
downloadpyopenssl-f1177e782fb6a0e9b4b62dae61fbe8c54e01170e.tar.gz
Handle exceptions in ALPN select callback.
-rw-r--r--OpenSSL/SSL.py90
-rw-r--r--OpenSSL/test/test_ssl.py33
2 files changed, 88 insertions, 35 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index cc2472f..5d41aa8 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -319,6 +319,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper):
)
+class _AlpnSelectHelper(_CallbackExceptionHelper):
+ """
+ Wrap a callback such that it can be used as an ALPN selection callback.
+ """
+ def __init__(self, callback):
+ _CallbackExceptionHelper.__init__(self)
+
+ @wraps(callback)
+ def wrapper(ssl, out, outlen, in_, inlen, arg):
+ try:
+ conn = Connection._reverse_mapping[ssl]
+
+ # The string passed to us is made up of multiple
+ # length-prefixed bytestrings. We need to split that into a
+ # list.
+ instr = _ffi.buffer(in_, inlen)[:]
+ protolist = []
+ while instr:
+ l = indexbytes(instr, 0)
+ proto = instr[1:l+1]
+ protolist.append(proto)
+ instr = instr[l+1:]
+
+ # Call the callback
+ outstr = callback(conn, protolist)
+
+ if not isinstance(outstr, _binary_type):
+ raise TypeError("ALPN callback must return a bytestring.")
+
+ # Save our callback arguments on the connection object to make
+ # sure that they don't get freed before OpenSSL can use them.
+ # Then, return them in the appropriate output parameters.
+ conn._alpn_select_callback_args = [
+ _ffi.new("unsigned char *", len(outstr)),
+ _ffi.new("unsigned char[]", outstr),
+ ]
+ outlen[0] = conn._alpn_select_callback_args[0][0]
+ out[0] = conn._alpn_select_callback_args[1]
+ return 0
+ except Exception as e:
+ self._problems.append(e)
+ return 2 # SSL_TLSEXT_ERR_ALERT_FATAL
+
+ self.callback = _ffi.callback(
+ "int (*)(SSL *, unsigned char **, unsigned char *, "
+ "const unsigned char *, unsigned int, void *)",
+ wrapper
+ )
+
+
def _asFileDescriptor(obj):
fd = None
if not isinstance(obj, integer_types):
@@ -410,6 +460,7 @@ class Context(object):
self._npn_advertise_callback = None
self._npn_select_helper = None
self._npn_select_callback = None
+ self._alpn_select_helper = None
self._alpn_select_callback = None
# SSL_CTX_set_app_data(self->ctx, self);
@@ -988,41 +1039,8 @@ class Context(object):
bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return
one of those bytestrings, the chosen protocol.
"""
- @wraps(callback)
- def wrapper(ssl, out, outlen, in_, inlen, arg):
- conn = Connection._reverse_mapping[ssl]
-
- # The string passed to us is made up of multiple length-prefixed
- # bytestrings. We need to split that into a list.
- instr = _ffi.buffer(in_, inlen)[:]
- protolist = []
- while instr:
- l = indexbytes(instr, 0)
- proto = instr[1:l+1]
- protolist.append(proto)
- instr = instr[l+1:]
-
- # Call the callback
- outstr = callback(conn, protolist)
-
- if not isinstance(outstr, _binary_type):
- raise TypeError("ALPN callback must return a bytestring.")
-
- # Save our callback arguments on the connection object to make sure
- # that they don't get freed before OpenSSL can use them. Then,
- # return them in the appropriate output parameters.
- conn._alpn_select_callback_args = [
- _ffi.new("unsigned char *", len(outstr)),
- _ffi.new("unsigned char[]", outstr),
- ]
- outlen[0] = conn._alpn_select_callback_args[0][0]
- out[0] = conn._alpn_select_callback_args[1]
- return 0
-
- self._alpn_select_callback = _ffi.callback(
- "int (*)(SSL *, unsigned char **, unsigned char *, "
- "const unsigned char *, unsigned int, void *)",
- wrapper)
+ self._alpn_select_helper = _AlpnSelectHelper(callback)
+ self._alpn_select_callback = self._alpn_select_helper.callback
_lib.SSL_CTX_set_alpn_select_cb(
self._context, self._alpn_select_callback, _ffi.NULL)
@@ -1101,6 +1119,8 @@ class Connection(object):
self._context._npn_advertise_helper.raise_if_problem()
if self._context._npn_select_helper is not None:
self._context._npn_select_helper.raise_if_problem()
+ if self._context._alpn_select_helper is not None:
+ self._context._alpn_select_helper.raise_if_problem()
error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ:
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index abc9dfd..68289a2 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -1927,6 +1927,39 @@ class ApplicationLayerProtoNegotiationTests(TestCase, _LoopbackMixin):
self.assertEqual(client.get_alpn_proto_negotiated(), b'')
+ def test_alpn_callback_exception(self):
+ """
+ Test that we can handle exceptions in the ALPN select callback.
+ """
+ select_args = []
+ def select(conn, options):
+ select_args.append((conn, options))
+ raise TypeError
+
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])
+
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_alpn_select_callback(select)
+
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
+
+ client = Connection(client_context, None)
+ client.set_connect_state()
+
+ # If the client doesn't return anything, the connection will fail.
+ self.assertRaises(TypeError, self._interactInMemory, server, client)
+ self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args)
+
+
class SessionTests(TestCase):
"""