summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Paul Calderone <exarkun@twistedmatrix.com>2015-04-13 22:14:07 -0400
committerJean-Paul Calderone <exarkun@twistedmatrix.com>2015-04-13 22:14:07 -0400
commit218a0146a82b1e392ef7bbde6fd63f58f2d35ad7 (patch)
treea412d198374e74f8ab4a4662642aa913a54d900c
parentba1820dfb1b03c849b272410d2c955478e633fbc (diff)
parent5c3b748846ad1f9597d51b24d04ac394980c2480 (diff)
downloadpyopenssl-218a0146a82b1e392ef7bbde6fd63f58f2d35ad7.tar.gz
merge master
-rw-r--r--.travis.yml2
-rw-r--r--OpenSSL/SSL.py165
-rw-r--r--OpenSSL/_util.py27
-rw-r--r--OpenSSL/crypto.py12
-rw-r--r--OpenSSL/test/test_crypto.py94
-rw-r--r--OpenSSL/test/test_ssl.py253
-rw-r--r--OpenSSL/test/util.py9
-rw-r--r--doc/api/ssl.rst36
8 files changed, 589 insertions, 9 deletions
diff --git a/.travis.yml b/.travis.yml
index eda1242..eccc720 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -98,7 +98,7 @@ script:
coverage report -m
python -c "import OpenSSL.SSL; print(OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION))"
-after_success:
+after_script:
- source ~/.venv/bin/activate && coveralls
notifications:
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 87492af..bba59ac 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -5,6 +5,7 @@ from weakref import WeakValueDictionary
from errno import errorcode
from six import text_type as _text_type
+from six import binary_type as _binary_type
from six import integer_types as integer_types
from six import int2byte, indexbytes
@@ -13,6 +14,7 @@ from OpenSSL._util import (
lib as _lib,
exception_from_error_queue as _exception_from_error_queue,
native as _native,
+ text_to_bytes_and_warn as _text_to_bytes_and_warn,
path_string as _path_string,
)
@@ -318,6 +320,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper):
)
+class _ALPNSelectHelper(_CallbackExceptionHelper):
+ """
+ Wrap a callback such that it can be used as an ALPN selection callback.
+ """
+ def __init__(self, callback):
+ _CallbackExceptionHelper.__init__(self)
+
+ @wraps(callback)
+ def wrapper(ssl, out, outlen, in_, inlen, arg):
+ try:
+ conn = Connection._reverse_mapping[ssl]
+
+ # The string passed to us is made up of multiple
+ # length-prefixed bytestrings. We need to split that into a
+ # list.
+ instr = _ffi.buffer(in_, inlen)[:]
+ protolist = []
+ while instr:
+ encoded_len = indexbytes(instr, 0)
+ proto = instr[1:encoded_len + 1]
+ protolist.append(proto)
+ instr = instr[encoded_len + 1:]
+
+ # Call the callback
+ outstr = callback(conn, protolist)
+
+ if not isinstance(outstr, _binary_type):
+ raise TypeError("ALPN callback must return a bytestring.")
+
+ # Save our callback arguments on the connection object to make
+ # sure that they don't get freed before OpenSSL can use them.
+ # Then, return them in the appropriate output parameters.
+ conn._alpn_select_callback_args = [
+ _ffi.new("unsigned char *", len(outstr)),
+ _ffi.new("unsigned char[]", outstr),
+ ]
+ outlen[0] = conn._alpn_select_callback_args[0][0]
+ out[0] = conn._alpn_select_callback_args[1]
+ return 0
+ except Exception as e:
+ self._problems.append(e)
+ return 2 # SSL_TLSEXT_ERR_ALERT_FATAL
+
+ self.callback = _ffi.callback(
+ "int (*)(SSL *, unsigned char **, unsigned char *, "
+ "const unsigned char *, unsigned int, void *)",
+ wrapper
+ )
+
+
def _asFileDescriptor(obj):
fd = None
if not isinstance(obj, integer_types):
@@ -363,6 +415,22 @@ def _requires_npn(func):
+def _requires_alpn(func):
+ """
+ Wraps any function that requires ALPN support in OpenSSL, ensuring that
+ NotImplementedError is raised if ALPN support is not present.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not _lib.Cryptography_HAS_ALPN:
+ raise NotImplementedError("ALPN not available.")
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+
class Session(object):
pass
@@ -424,6 +492,8 @@ class Context(object):
self._npn_advertise_callback = None
self._npn_select_helper = None
self._npn_select_callback = None
+ self._alpn_select_helper = None
+ self._alpn_select_callback = None
# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
@@ -973,6 +1043,44 @@ class Context(object):
_lib.SSL_CTX_set_next_proto_select_cb(
self._context, self._npn_select_callback, _ffi.NULL)
+ @_requires_alpn
+ def set_alpn_protos(self, protos):
+ """
+ Specify the clients ALPN protocol list.
+
+ These protocols are offered to the server during protocol negotiation.
+
+ :param protos: A list of the protocols to be offered to the server.
+ This list should be a Python list of bytestrings representing the
+ protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
+ """
+ # Take the list of protocols and join them together, prefixing them
+ # with their lengths.
+ protostr = b''.join(
+ chain.from_iterable((int2byte(len(p)), p) for p in protos)
+ )
+
+ # Build a C string from the list. We don't need to save this off
+ # because OpenSSL immediately copies the data out.
+ input_str = _ffi.new("unsigned char[]", protostr)
+ input_str_len = _ffi.cast("unsigned", len(protostr))
+ _lib.SSL_CTX_set_alpn_protos(self._context, input_str, input_str_len)
+
+ @_requires_alpn
+ def set_alpn_select_callback(self, callback):
+ """
+ Set the callback to handle ALPN protocol choice.
+
+ :param callback: The callback function. It will be invoked with two
+ arguments: the Connection, and a list of offered protocols as
+ bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return
+ one of those bytestrings, the chosen protocol.
+ """
+ self._alpn_select_helper = _ALPNSelectHelper(callback)
+ self._alpn_select_callback = self._alpn_select_helper.callback
+ _lib.SSL_CTX_set_alpn_select_cb(
+ self._context, self._alpn_select_callback, _ffi.NULL)
+
ContextType = Context
@@ -1004,6 +1112,12 @@ class Connection(object):
self._npn_advertise_callback_args = None
self._npn_select_callback_args = None
+ # References to strings used for Application Layer Protocol
+ # Negotiation. These strings get copied at some point but it's well
+ # after the callback returns, so we have to hang them somewhere to
+ # avoid them getting freed.
+ self._alpn_select_callback_args = None
+
self._reverse_mapping[self._ssl] = self
if socket is None:
@@ -1042,6 +1156,8 @@ class Connection(object):
self._context._npn_advertise_helper.raise_if_problem()
if self._context._npn_select_helper is not None:
self._context._npn_select_helper.raise_if_problem()
+ if self._context._alpn_select_helper is not None:
+ self._context._alpn_select_helper.raise_if_problem()
error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ:
@@ -1142,6 +1258,9 @@ class Connection(object):
API, the value is ignored
:return: The number of bytes written
"""
+ # Backward compatibility
+ buf = _text_to_bytes_and_warn("buf", buf)
+
if isinstance(buf, _memoryview):
buf = buf.tobytes()
if isinstance(buf, _buffer):
@@ -1166,6 +1285,8 @@ class Connection(object):
API, the value is ignored
:return: The number of bytes written
"""
+ buf = _text_to_bytes_and_warn("buf", buf)
+
if isinstance(buf, _memoryview):
buf = buf.tobytes()
if isinstance(buf, _buffer):
@@ -1290,6 +1411,8 @@ class Connection(object):
:param buf: The string to put into the memory BIO.
:return: The number of bytes written
"""
+ buf = _text_to_bytes_and_warn("buf", buf)
+
if self._into_ssl is None:
raise TypeError("Connection sock was not None")
@@ -1776,6 +1899,48 @@ class Connection(object):
return _ffi.buffer(data[0], data_len[0])[:]
+ @_requires_alpn
+ def set_alpn_protos(self, protos):
+ """
+ Specify the client's ALPN protocol list.
+
+ These protocols are offered to the server during protocol negotiation.
+
+ :param protos: A list of the protocols to be offered to the server.
+ This list should be a Python list of bytestrings representing the
+ protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
+ """
+ # Take the list of protocols and join them together, prefixing them
+ # with their lengths.
+ protostr = b''.join(
+ chain.from_iterable((int2byte(len(p)), p) for p in protos)
+ )
+
+ # Build a C string from the list. We don't need to save this off
+ # because OpenSSL immediately copies the data out.
+ input_str = _ffi.new("unsigned char[]", protostr)
+ input_str_len = _ffi.cast("unsigned", len(protostr))
+ _lib.SSL_set_alpn_protos(self._ssl, input_str, input_str_len)
+
+
+ def get_alpn_proto_negotiated(self):
+ """
+ Get the protocol that was negotiated by ALPN.
+ """
+ if not _lib.Cryptography_HAS_ALPN:
+ raise NotImplementedError("ALPN not available")
+
+ data = _ffi.new("unsigned char **")
+ data_len = _ffi.new("unsigned int *")
+
+ _lib.SSL_get0_alpn_selected(self._ssl, data, data_len)
+
+ if not data_len:
+ return b''
+
+ return _ffi.buffer(data[0], data_len[0])[:]
+
+
ConnectionType = Connection
diff --git a/OpenSSL/_util.py b/OpenSSL/_util.py
index 1eec17e..8d59252 100644
--- a/OpenSSL/_util.py
+++ b/OpenSSL/_util.py
@@ -1,3 +1,4 @@
+from warnings import warn
import sys
from six import PY3, binary_type, text_type
@@ -93,3 +94,29 @@ if PY3:
else:
def byte_string(s):
return s
+
+_TEXT_WARNING = (
+ text_type.__name__ + " for {0} is no longer accepted, use bytes"
+)
+
+def text_to_bytes_and_warn(label, obj):
+ """
+ If ``obj`` is text, emit a warning that it should be bytes instead and try
+ to convert it to bytes automatically.
+
+ :param str label: The name of the parameter from which ``obj`` was taken
+ (so a developer can easily find the source of the problem and correct
+ it).
+
+ :return: If ``obj`` is the text string type, a ``bytes`` object giving the
+ UTF-8 encoding of that text is returned. Otherwise, ``obj`` itself is
+ returned.
+ """
+ if isinstance(obj, text_type):
+ warn(
+ _TEXT_WARNING.format(label),
+ category=DeprecationWarning,
+ stacklevel=3
+ )
+ return obj.encode('utf-8')
+ return obj
diff --git a/OpenSSL/crypto.py b/OpenSSL/crypto.py
index 0e99576..f9e189c 100644
--- a/OpenSSL/crypto.py
+++ b/OpenSSL/crypto.py
@@ -13,7 +13,9 @@ from OpenSSL._util import (
lib as _lib,
exception_from_error_queue as _exception_from_error_queue,
byte_string as _byte_string,
- native as _native)
+ native as _native,
+ text_to_bytes_and_warn as _text_to_bytes_and_warn,
+)
FILETYPE_PEM = _lib.SSL_FILETYPE_PEM
FILETYPE_ASN1 = _lib.SSL_FILETYPE_ASN1
@@ -2074,6 +2076,8 @@ class PKCS12(object):
:return: The string containing the PKCS12
"""
+ passphrase = _text_to_bytes_and_warn("passphrase", passphrase)
+
if self._cacerts is None:
cacerts = _ffi.NULL
else:
@@ -2371,6 +2375,8 @@ def sign(pkey, data, digest):
:param digest: message digest to use
:return: signature
"""
+ data = _text_to_bytes_and_warn("data", data)
+
digest_obj = _lib.EVP_get_digestbyname(_byte_string(digest))
if digest_obj == _ffi.NULL:
raise ValueError("No such digest method")
@@ -2405,6 +2411,8 @@ def verify(cert, signature, data, digest):
:param digest: message digest to use
:return: None if the signature is correct, raise exception otherwise
"""
+ data = _text_to_bytes_and_warn("data", data)
+
digest_obj = _lib.EVP_get_digestbyname(_byte_string(digest))
if digest_obj == _ffi.NULL:
raise ValueError("No such digest method")
@@ -2495,6 +2503,8 @@ def load_pkcs12(buffer, passphrase=None):
:param passphrase: (Optional) The password to decrypt the PKCS12 lump
:returns: The PKCS12 object
"""
+ passphrase = _text_to_bytes_and_warn("passphrase", passphrase)
+
if isinstance(buffer, _text_type):
buffer = buffer.encode("ascii")
diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py
index ca54176..dea5858 100644
--- a/OpenSSL/test/test_crypto.py
+++ b/OpenSSL/test/test_crypto.py
@@ -13,7 +13,9 @@ import re
from subprocess import PIPE, Popen
from datetime import datetime, timedelta
-from six import u, b, binary_type
+from six import u, b, binary_type, PY3
+from warnings import simplefilter
+from warnings import catch_warnings
from OpenSSL.crypto import TYPE_RSA, TYPE_DSA, Error, PKey, PKeyType
from OpenSSL.crypto import X509, X509Type, X509Name, X509NameType
@@ -30,7 +32,9 @@ from OpenSSL.crypto import CRL, Revoked, load_crl
from OpenSSL.crypto import NetscapeSPKI, NetscapeSPKIType
from OpenSSL.crypto import (
sign, verify, get_elliptic_curve, get_elliptic_curves)
-from OpenSSL.test.util import EqualityTestsMixin, TestCase
+from OpenSSL.test.util import (
+ EqualityTestsMixin, TestCase, WARNING_TYPE_EXPECTED
+)
from OpenSSL._util import native, lib
def normalize_certificate_pem(pem):
@@ -2061,6 +2065,31 @@ class PKCS12Tests(TestCase):
self.verify_pkcs12_container(p12)
+ def test_load_pkcs12_text_passphrase(self):
+ """
+ A PKCS12 string generated using the openssl command line can be loaded
+ with :py:obj:`load_pkcs12` and its components extracted and examined.
+ Using text as passphrase instead of bytes. DeprecationWarning expected.
+ """
+ pem = client_key_pem + client_cert_pem
+ passwd = b"whatever"
+ p12_str = _runopenssl(pem, b"pkcs12", b"-export", b"-clcerts",
+ b"-passout", b"pass:" + passwd)
+ with catch_warnings(record=True) as w:
+ simplefilter("always")
+ p12 = load_pkcs12(p12_str, passphrase=b"whatever".decode("ascii"))
+
+ self.assertEqual(
+ "{0} for passphrase is no longer accepted, use bytes".format(
+ WARNING_TYPE_EXPECTED
+ ),
+ str(w[-1].message)
+ )
+ self.assertIs(w[-1].category, DeprecationWarning)
+
+ self.verify_pkcs12_container(p12)
+
+
def test_load_pkcs12_no_passphrase(self):
"""
A PKCS12 string generated using openssl command line can be loaded with
@@ -2262,6 +2291,26 @@ class PKCS12Tests(TestCase):
dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=b"")
+ def test_export_without_bytes(self):
+ """
+ Test :py:obj:`PKCS12.export` with text not bytes as passphrase
+ """
+ p12 = self.gen_pkcs12(server_cert_pem, server_key_pem, root_cert_pem)
+
+ with catch_warnings(record=True) as w:
+ simplefilter("always")
+ dumped_p12 = p12.export(passphrase=b"randomtext".decode("ascii"))
+ self.assertEqual(
+ "{0} for passphrase is no longer accepted, use bytes".format(
+ WARNING_TYPE_EXPECTED
+ ),
+ str(w[-1].message)
+ )
+ self.assertIs(w[-1].category, DeprecationWarning)
+ self.check_recovery(
+ dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=b"randomtext")
+
+
def test_key_cert_mismatch(self):
"""
:py:obj:`PKCS12.export` raises an exception when a key and certificate
@@ -3321,6 +3370,47 @@ class SignVerifyTests(TestCase):
ValueError, verify, good_cert, sig, content, "strange-digest")
+ def test_sign_verify_with_text(self):
+ """
+ :py:obj:`sign` generates a cryptographic signature which :py:obj:`verify` can check.
+ Deprecation warnings raised because using text instead of bytes as content
+ """
+ content = (
+ b"It was a bright cold day in April, and the clocks were striking "
+ b"thirteen. Winston Smith, his chin nuzzled into his breast in an "
+ b"effort to escape the vile wind, slipped quickly through the "
+ b"glass doors of Victory Mansions, though not quickly enough to "
+ b"prevent a swirl of gritty dust from entering along with him."
+ ).decode("ascii")
+
+ priv_key = load_privatekey(FILETYPE_PEM, root_key_pem)
+ cert = load_certificate(FILETYPE_PEM, root_cert_pem)
+ for digest in ['md5', 'sha1']:
+ with catch_warnings(record=True) as w:
+ simplefilter("always")
+ sig = sign(priv_key, content, digest)
+
+ self.assertEqual(
+ "{0} for data is no longer accepted, use bytes".format(
+ WARNING_TYPE_EXPECTED
+ ),
+ str(w[-1].message)
+ )
+ self.assertIs(w[-1].category, DeprecationWarning)
+
+ with catch_warnings(record=True) as w:
+ simplefilter("always")
+ verify(cert, sig, content, digest)
+
+ self.assertEqual(
+ "{0} for data is no longer accepted, use bytes".format(
+ WARNING_TYPE_EXPECTED
+ ),
+ str(w[-1].message)
+ )
+ self.assertIs(w[-1].category, DeprecationWarning)
+
+
def test_sign_nulls(self):
"""
:py:obj:`sign` produces a signature for a string with embedded nulls.
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index caa9d02..2055b6d 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -7,12 +7,13 @@ Unit tests for :py:obj:`OpenSSL.SSL`.
from gc import collect, get_referrers
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
-from sys import platform, version_info, getfilesystemencoding
+from sys import platform, getfilesystemencoding
from socket import SHUT_RDWR, error, socket
from os import makedirs
from os.path import join
from unittest import main
from weakref import ref
+from warnings import catch_warnings, simplefilter
from six import PY3, text_type, u
@@ -44,10 +45,9 @@ from OpenSSL.SSL import (
from OpenSSL._util import lib as _lib
-from OpenSSL.test.util import NON_ASCII, TestCase, b
-from OpenSSL.test.test_crypto import (
- cleartextCertificatePEM, cleartextPrivateKeyPEM)
+from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b
from OpenSSL.test.test_crypto import (
+ cleartextCertificatePEM, cleartextPrivateKeyPEM,
client_cert_pem, client_key_pem, server_cert_pem, server_key_pem,
root_cert_pem)
@@ -1811,6 +1811,210 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
+class ApplicationLayerProtoNegotiationTests(TestCase, _LoopbackMixin):
+ """
+ Tests for ALPN in PyOpenSSL.
+ """
+ # Skip tests on versions that don't support ALPN.
+ if _lib.Cryptography_HAS_ALPN:
+
+ def test_alpn_success(self):
+ """
+ Clients and servers that agree on the negotiated ALPN protocol can
+ correct establish a connection, and the agreed protocol is reported
+ by the connections.
+ """
+ select_args = []
+ def select(conn, options):
+ select_args.append((conn, options))
+ return b'spdy/2'
+
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])
+
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_alpn_select_callback(select)
+
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
+
+ client = Connection(client_context, None)
+ client.set_connect_state()
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args)
+
+ self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2')
+ self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2')
+
+
+ def test_alpn_set_on_connection(self):
+ """
+ The same as test_alpn_success, but setting the ALPN protocols on
+ the connection rather than the context.
+ """
+ select_args = []
+ def select(conn, options):
+ select_args.append((conn, options))
+ return b'spdy/2'
+
+ # Setup the client context but don't set any ALPN protocols.
+ client_context = Context(TLSv1_METHOD)
+
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_alpn_select_callback(select)
+
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
+
+ # Set the ALPN protocols on the client connection.
+ client = Connection(client_context, None)
+ client.set_alpn_protos([b'http/1.1', b'spdy/2'])
+ client.set_connect_state()
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args)
+
+ self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2')
+ self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2')
+
+
+ def test_alpn_server_fail(self):
+ """
+ When clients and servers cannot agree on what protocol to use next
+ the TLS connection does not get established.
+ """
+ select_args = []
+ def select(conn, options):
+ select_args.append((conn, options))
+ return b''
+
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])
+
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_alpn_select_callback(select)
+
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
+
+ client = Connection(client_context, None)
+ client.set_connect_state()
+
+ # If the client doesn't return anything, the connection will fail.
+ self.assertRaises(Error, self._interactInMemory, server, client)
+
+ self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args)
+
+
+ def test_alpn_no_server(self):
+ """
+ When clients and servers cannot agree on what protocol to use next
+ because the server doesn't offer ALPN, no protocol is negotiated.
+ """
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])
+
+ server_context = Context(TLSv1_METHOD)
+
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
+
+ client = Connection(client_context, None)
+ client.set_connect_state()
+
+ # Do the dance.
+ self._interactInMemory(server, client)
+
+ self.assertEqual(client.get_alpn_proto_negotiated(), b'')
+
+
+ def test_alpn_callback_exception(self):
+ """
+ We can handle exceptions in the ALPN select callback.
+ """
+ select_args = []
+ def select(conn, options):
+ select_args.append((conn, options))
+ raise TypeError()
+
+ client_context = Context(TLSv1_METHOD)
+ client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])
+
+ server_context = Context(TLSv1_METHOD)
+ server_context.set_alpn_select_callback(select)
+
+ # Necessary to actually accept the connection
+ server_context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_context.use_certificate(
+ load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(server_context, None)
+ server.set_accept_state()
+
+ client = Connection(client_context, None)
+ client.set_connect_state()
+
+ self.assertRaises(
+ TypeError, self._interactInMemory, server, client
+ )
+ self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args)
+
+ else:
+ # No ALPN.
+ def test_alpn_not_implemented(self):
+ """
+ If ALPN is not in OpenSSL, we should raise NotImplementedError.
+ """
+ # Test the context methods first.
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(
+ NotImplementedError, context.set_alpn_protos, None
+ )
+ self.assertRaises(
+ NotImplementedError, context.set_alpn_select_callback, None
+ )
+
+ # Now test a connection.
+ conn = Connection(context)
+ self.assertRaises(
+ NotImplementedError, context.set_alpn_protos, None
+ )
+
+
+
class SessionTests(TestCase):
"""
Unit tests for :py:obj:`OpenSSL.SSL.Session`.
@@ -1931,7 +2135,7 @@ class ConnectionTests(TestCase, _LoopbackMixin):
self.assertRaises(
TypeError, conn.set_tlsext_host_name, b("with\0null"))
- if version_info >= (3,):
+ if PY3:
# On Python 3.x, don't accidentally implicitly convert from text.
self.assertRaises(
TypeError,
@@ -2572,6 +2776,26 @@ class ConnectionSendTests(TestCase, _LoopbackMixin):
self.assertEquals(count, 2)
self.assertEquals(client.recv(2), b('xy'))
+
+ def test_text(self):
+ """
+ When passed a text, :py:obj:`Connection.send` transmits all of it and
+ returns the number of bytes sent. It also raises a DeprecationWarning.
+ """
+ server, client = self._loopback()
+ with catch_warnings(record=True) as w:
+ simplefilter("always")
+ count = server.send(b"xy".decode("ascii"))
+ self.assertEqual(
+ "{0} for buf is no longer accepted, use bytes".format(
+ WARNING_TYPE_EXPECTED
+ ),
+ str(w[-1].message)
+ )
+ self.assertIs(w[-1].category, DeprecationWarning)
+ self.assertEquals(count, 2)
+ self.assertEquals(client.recv(2), b"xy")
+
try:
memoryview
except NameError:
@@ -2792,6 +3016,25 @@ class ConnectionSendallTests(TestCase, _LoopbackMixin):
self.assertEquals(client.recv(1), b('x'))
+ def test_text(self):
+ """
+ :py:obj:`Connection.sendall` transmits all the content in the string
+ passed to it raising a DeprecationWarning in case of this being a text.
+ """
+ server, client = self._loopback()
+ with catch_warnings(record=True) as w:
+ simplefilter("always")
+ server.sendall(b"x".decode("ascii"))
+ self.assertEqual(
+ "{0} for buf is no longer accepted, use bytes".format(
+ WARNING_TYPE_EXPECTED
+ ),
+ str(w[-1].message)
+ )
+ self.assertIs(w[-1].category, DeprecationWarning)
+ self.assertEquals(client.recv(1), b"x")
+
+
try:
memoryview
except NameError:
diff --git a/OpenSSL/test/util.py b/OpenSSL/test/util.py
index b24a73b..b8be91d 100644
--- a/OpenSSL/test/util.py
+++ b/OpenSSL/test/util.py
@@ -14,6 +14,8 @@ from tempfile import mktemp
from unittest import TestCase
import sys
+from six import PY3
+
from OpenSSL._util import exception_from_error_queue
from OpenSSL.crypto import Error
@@ -452,3 +454,10 @@ class EqualityTestsMixin(object):
a = self.anInstance()
b = Delegate()
self.assertEqual(a != b, [b])
+
+
+# The type name expected in warnings about using the wrong string type.
+if PY3:
+ WARNING_TYPE_EXPECTED = "str"
+else:
+ WARNING_TYPE_EXPECTED = "unicode"
diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst
index e6a0775..2929305 100644
--- a/doc/api/ssl.rst
+++ b/doc/api/ssl.rst
@@ -498,6 +498,26 @@ Context objects have the following methods:
.. versionadded:: 0.15
+.. py:method:: Context.set_alpn_protos(protos)
+
+ Specify the protocols that the client is prepared to speak after the TLS
+ connection has been negotiated using Application Layer Protocol
+ Negotiation.
+
+ *protos* should be a list of protocols that the client is offering, each
+ as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``.
+
+
+.. py:method:: Context.set_alpn_select_callback(callback)
+
+ Specify a callback function that will be called on the server when a client
+ offers protocols using Application Layer Protocol Negotiation.
+
+ *callback* should be the callback function. It will be invoked with two
+ arguments: the :py:class:`Connection` and a list of offered protocols as
+ bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of
+ these bytestrings, the chosen protocol.
+
.. _openssl-session:
@@ -849,6 +869,22 @@ Connection objects have the following methods:
.. versionadded:: 0.15
+.. py:method:: Connection.set_alpn_protos(protos)
+
+ Specify the protocols that the client is prepared to speak after the TLS
+ connection has been negotiated using Application Layer Protocol
+ Negotiation.
+
+ *protos* should be a list of protocols that the client is offering, each
+ as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``.
+
+
+.. py:method:: Connection.get_alpn_proto_negotiated()
+
+ Get the protocol that was negotiated by Application Layer Protocol
+ Negotiation. Returns a bytestring of the protocol name. If no protocol has
+ been negotiated yet, returns an empty string.
+
.. Rubric:: Footnotes