diff options
author | Jean-Paul Calderone <exarkun@twistedmatrix.com> | 2015-04-13 21:43:33 -0400 |
---|---|---|
committer | Jean-Paul Calderone <exarkun@twistedmatrix.com> | 2015-04-13 21:43:33 -0400 |
commit | f0e74563a5c4309e6c27e46bf29ca333a04d5210 (patch) | |
tree | 02bdce954b5c2a00b65d831c47d4974ff36b5782 | |
parent | 9be8519b76c27ec26159dfc8e7747367fe609ff1 (diff) | |
parent | 5c3b748846ad1f9597d51b24d04ac394980c2480 (diff) | |
download | pyopenssl-f0e74563a5c4309e6c27e46bf29ca333a04d5210.tar.gz |
merge master
-rw-r--r-- | .travis.yml | 78 | ||||
-rw-r--r-- | OpenSSL/SSL.py | 166 | ||||
-rw-r--r-- | OpenSSL/_util.py | 27 | ||||
-rw-r--r-- | OpenSSL/crypto.py | 9 | ||||
-rw-r--r-- | OpenSSL/test/test_crypto.py | 94 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 254 | ||||
-rw-r--r-- | OpenSSL/test/util.py | 9 | ||||
-rw-r--r-- | doc/api/ssl.rst | 36 |
8 files changed, 637 insertions, 36 deletions
diff --git a/.travis.yml b/.travis.yml index f359de1..eccc720 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,18 +1,23 @@ +sudo: false language: python -os: - - linux - -python: - - "pypy" - - "2.6" - - "2.7" - - "3.2" - - "3.3" - - "3.4" - matrix: include: + - language: generic + os: osx + env: TOXENV=py27 + - python: "2.6" # these are just to make travis's UI a bit prettier + env: TOXENV=py26 + - python: "2.7" + env: TOXENV=py27 + - python: "3.2" + env: TOXENV=py32 + - python: "3.3" + env: TOXENV=py33 + - python: "3.4" + env: TOXENV=py34 + - python: "pypy" + env: TOXENV=pypy # Also run the tests against cryptography master. - python: "2.6" env: @@ -37,10 +42,21 @@ matrix: - python: "2.7" env: OPENSSL=0.9.8 + addons: + apt: + sources: + - lucid + packages: + - libssl-dev/lucid # Let the cryptography master builds fail because they might be triggered by # cryptography changes beyond our control. + # Also allow OS X and 0.9.8 to fail at the moment while we fix these new + # build configurations. allow_failures: + - language: generic + os: osx + env: TOXENV=py27 - env: CRYPTOGRAPHY_GIT_MASTER=true - env: @@ -50,34 +66,40 @@ before_install: - if [ -n "$CRYPTOGRAPHY_GIT_MASTER" ]; then pip install git+https://github.com/pyca/cryptography.git;fi install: - # Install the wheel library explicitly here. It is not really a setup - # dependency. It is not an install dependency. It is only a dependency for - # the script directive below - because we want to exercise wheel building on - # travis. - - pip install wheel - - # Also install some tools for measuring code coverage and sending the results - # to coveralls. - - pip install coveralls coverage + - | + if [[ "$(uname -s)" == 'Darwin' ]]; then + brew update + brew upgrade openssl + curl -O https://bootstrap.pypa.io/get-pip.py + python get-pip.py --user + pip install --user virtualenv + else + pip install virtualenv + fi + python -m virtualenv ~/.venv + source ~/.venv/bin/activate + pip install coveralls coverage wheel script: - | - if [[ "${OPENSSL}" == "0.9.8" ]]; then - sudo add-apt-repository "deb http://archive.ubuntu.com/ubuntu/ lucid main" - sudo apt-get -y update - sudo apt-get install -y --force-yes libssl-dev/lucid + if [[ "$(uname -s)" == "Darwin" ]]; then + # set our flags to use homebrew openssl + export ARCHFLAGS="-arch x86_64" + export LDFLAGS="-L/usr/local/opt/openssl/lib" + export CFLAGS="-I/usr/local/opt/openssl/include" fi - - | + source ~/.venv/bin/activate pip install -e . - | + source ~/.venv/bin/activate coverage run --branch --source=OpenSSL setup.py bdist_wheel test - | + source ~/.venv/bin/activate coverage report -m - - | python -c "import OpenSSL.SSL; print(OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION))" -after_success: - - coveralls +after_script: + - source ~/.venv/bin/activate && coveralls notifications: email: false diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 6dba10c..cfe705a 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -5,6 +5,7 @@ from weakref import WeakValueDictionary from errno import errorcode from six import text_type as _text_type +from six import binary_type as _binary_type from six import integer_types as integer_types from six import int2byte, indexbytes @@ -13,6 +14,7 @@ from OpenSSL._util import ( lib as _lib, exception_from_error_queue as _exception_from_error_queue, native as _native, + text_to_bytes_and_warn as _text_to_bytes_and_warn, path_string as _path_string, UNSPECIFIED as _UNSPECIFIED, ) @@ -317,6 +319,56 @@ class _NpnSelectHelper(_CallbackExceptionHelper): ) +class _ALPNSelectHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an ALPN selection callback. + """ + def __init__(self, callback): + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + encoded_len = indexbytes(instr, 0) + proto = instr[1:encoded_len + 1] + protolist.append(proto) + instr = instr[encoded_len + 1:] + + # Call the callback + outstr = callback(conn, protolist) + + if not isinstance(outstr, _binary_type): + raise TypeError("ALPN callback must return a bytestring.") + + # Save our callback arguments on the connection object to make + # sure that they don't get freed before OpenSSL can use them. + # Then, return them in the appropriate output parameters. + conn._alpn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._alpn_select_callback_args[0][0] + out[0] = conn._alpn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + + def _asFileDescriptor(obj): fd = None if not isinstance(obj, integer_types): @@ -347,6 +399,22 @@ def SSLeay_version(type): +def _requires_alpn(func): + """ + Wraps any function that requires ALPN support in OpenSSL, ensuring that + NotImplementedError is raised if ALPN support is not present. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if not _lib.Cryptography_HAS_ALPN: + raise NotImplementedError("ALPN not available.") + + return func(*args, **kwargs) + + return wrapper + + + class Session(object): pass @@ -408,6 +476,8 @@ class Context(object): self._npn_advertise_callback = None self._npn_select_helper = None self._npn_select_callback = None + self._alpn_select_helper = None + self._alpn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); # SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | @@ -922,7 +992,6 @@ class Context(object): _lib.SSL_CTX_set_tlsext_servername_callback( self._context, self._tlsext_servername_callback) - def set_npn_advertise_callback(self, callback): """ Specify a callback function that will be called when offering `Next @@ -955,6 +1024,44 @@ class Context(object): _lib.SSL_CTX_set_next_proto_select_cb( self._context, self._npn_select_callback, _ffi.NULL) + @_requires_alpn + def set_alpn_protos(self, protos): + """ + Specify the clients ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + input_str_len = _ffi.cast("unsigned", len(protostr)) + _lib.SSL_CTX_set_alpn_protos(self._context, input_str, input_str_len) + + @_requires_alpn + def set_alpn_select_callback(self, callback): + """ + Set the callback to handle ALPN protocol choice. + + :param callback: The callback function. It will be invoked with two + arguments: the Connection, and a list of offered protocols as + bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return + one of those bytestrings, the chosen protocol. + """ + self._alpn_select_helper = _ALPNSelectHelper(callback) + self._alpn_select_callback = self._alpn_select_helper.callback + _lib.SSL_CTX_set_alpn_select_cb( + self._context, self._alpn_select_callback, _ffi.NULL) + ContextType = Context @@ -986,6 +1093,12 @@ class Connection(object): self._npn_advertise_callback_args = None self._npn_select_callback_args = None + # References to strings used for Application Layer Protocol + # Negotiation. These strings get copied at some point but it's well + # after the callback returns, so we have to hang them somewhere to + # avoid them getting freed. + self._alpn_select_callback_args = None + self._reverse_mapping[self._ssl] = self if socket is None: @@ -1024,6 +1137,8 @@ class Connection(object): self._context._npn_advertise_helper.raise_if_problem() if self._context._npn_select_helper is not None: self._context._npn_select_helper.raise_if_problem() + if self._context._alpn_select_helper is not None: + self._context._alpn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: @@ -1124,6 +1239,9 @@ class Connection(object): API, the value is ignored :return: The number of bytes written """ + # Backward compatibility + buf = _text_to_bytes_and_warn("buf", buf) + if isinstance(buf, _memoryview): buf = buf.tobytes() if isinstance(buf, _buffer): @@ -1148,6 +1266,8 @@ class Connection(object): API, the value is ignored :return: The number of bytes written """ + buf = _text_to_bytes_and_warn("buf", buf) + if isinstance(buf, _memoryview): buf = buf.tobytes() if isinstance(buf, _buffer): @@ -1272,6 +1392,8 @@ class Connection(object): :param buf: The string to put into the memory BIO. :return: The number of bytes written """ + buf = _text_to_bytes_and_warn("buf", buf) + if self._into_ssl is None: raise TypeError("Connection sock was not None") @@ -1756,6 +1878,48 @@ class Connection(object): return _ffi.buffer(data[0], data_len[0])[:] + @_requires_alpn + def set_alpn_protos(self, protos): + """ + Specify the client's ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + input_str_len = _ffi.cast("unsigned", len(protostr)) + _lib.SSL_set_alpn_protos(self._ssl, input_str, input_str_len) + + + def get_alpn_proto_negotiated(self): + """ + Get the protocol that was negotiated by ALPN. + """ + if not _lib.Cryptography_HAS_ALPN: + raise NotImplementedError("ALPN not available") + + data = _ffi.new("unsigned char **") + data_len = _ffi.new("unsigned int *") + + _lib.SSL_get0_alpn_selected(self._ssl, data, data_len) + + if not data_len: + return b'' + + return _ffi.buffer(data[0], data_len[0])[:] + + ConnectionType = Connection diff --git a/OpenSSL/_util.py b/OpenSSL/_util.py index a4fe63a..0cc34d8 100644 --- a/OpenSSL/_util.py +++ b/OpenSSL/_util.py @@ -1,3 +1,4 @@ +from warnings import warn import sys from six import PY3, binary_type, text_type @@ -98,3 +99,29 @@ else: # A marker object to observe whether some optional arguments are passed any # value or not. UNSPECIFIED = object() + +_TEXT_WARNING = ( + text_type.__name__ + " for {0} is no longer accepted, use bytes" +) + +def text_to_bytes_and_warn(label, obj): + """ + If ``obj`` is text, emit a warning that it should be bytes instead and try + to convert it to bytes automatically. + + :param str label: The name of the parameter from which ``obj`` was taken + (so a developer can easily find the source of the problem and correct + it). + + :return: If ``obj`` is the text string type, a ``bytes`` object giving the + UTF-8 encoding of that text is returned. Otherwise, ``obj`` itself is + returned. + """ + if isinstance(obj, text_type): + warn( + _TEXT_WARNING.format(label), + category=DeprecationWarning, + stacklevel=3 + ) + return obj.encode('utf-8') + return obj diff --git a/OpenSSL/crypto.py b/OpenSSL/crypto.py index ba83efe..0350537 100644 --- a/OpenSSL/crypto.py +++ b/OpenSSL/crypto.py @@ -16,6 +16,7 @@ from OpenSSL._util import ( byte_string as _byte_string, native as _native, UNSPECIFIED as _UNSPECIFIED, + text_to_bytes_and_warn as _text_to_bytes_and_warn, ) FILETYPE_PEM = _lib.SSL_FILETYPE_PEM @@ -2097,6 +2098,8 @@ class PKCS12(object): :return: The string containing the PKCS12 """ + passphrase = _text_to_bytes_and_warn("passphrase", passphrase) + if self._cacerts is None: cacerts = _ffi.NULL else: @@ -2394,6 +2397,8 @@ def sign(pkey, data, digest): :param digest: message digest to use :return: signature """ + data = _text_to_bytes_and_warn("data", data) + digest_obj = _lib.EVP_get_digestbyname(_byte_string(digest)) if digest_obj == _ffi.NULL: raise ValueError("No such digest method") @@ -2428,6 +2433,8 @@ def verify(cert, signature, data, digest): :param digest: message digest to use :return: None if the signature is correct, raise exception otherwise """ + data = _text_to_bytes_and_warn("data", data) + digest_obj = _lib.EVP_get_digestbyname(_byte_string(digest)) if digest_obj == _ffi.NULL: raise ValueError("No such digest method") @@ -2518,6 +2525,8 @@ def load_pkcs12(buffer, passphrase=None): :param passphrase: (Optional) The password to decrypt the PKCS12 lump :returns: The PKCS12 object """ + passphrase = _text_to_bytes_and_warn("passphrase", passphrase) + if isinstance(buffer, _text_type): buffer = buffer.encode("ascii") diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py index ae973ab..f6f0751 100644 --- a/OpenSSL/test/test_crypto.py +++ b/OpenSSL/test/test_crypto.py @@ -14,7 +14,9 @@ import re from subprocess import PIPE, Popen from datetime import datetime, timedelta -from six import u, b, binary_type +from six import u, b, binary_type, PY3 +from warnings import simplefilter +from warnings import catch_warnings from OpenSSL.crypto import TYPE_RSA, TYPE_DSA, Error, PKey, PKeyType from OpenSSL.crypto import X509, X509Type, X509Name, X509NameType @@ -31,7 +33,9 @@ from OpenSSL.crypto import CRL, Revoked, load_crl from OpenSSL.crypto import NetscapeSPKI, NetscapeSPKIType from OpenSSL.crypto import ( sign, verify, get_elliptic_curve, get_elliptic_curves) -from OpenSSL.test.util import EqualityTestsMixin, TestCase +from OpenSSL.test.util import ( + EqualityTestsMixin, TestCase, WARNING_TYPE_EXPECTED +) from OpenSSL._util import native, lib def normalize_certificate_pem(pem): @@ -2062,6 +2066,31 @@ class PKCS12Tests(TestCase): self.verify_pkcs12_container(p12) + def test_load_pkcs12_text_passphrase(self): + """ + A PKCS12 string generated using the openssl command line can be loaded + with :py:obj:`load_pkcs12` and its components extracted and examined. + Using text as passphrase instead of bytes. DeprecationWarning expected. + """ + pem = client_key_pem + client_cert_pem + passwd = b"whatever" + p12_str = _runopenssl(pem, b"pkcs12", b"-export", b"-clcerts", + b"-passout", b"pass:" + passwd) + with catch_warnings(record=True) as w: + simplefilter("always") + p12 = load_pkcs12(p12_str, passphrase=b"whatever".decode("ascii")) + + self.assertEqual( + "{0} for passphrase is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ), + str(w[-1].message) + ) + self.assertIs(w[-1].category, DeprecationWarning) + + self.verify_pkcs12_container(p12) + + def test_load_pkcs12_no_passphrase(self): """ A PKCS12 string generated using openssl command line can be loaded with @@ -2263,6 +2292,26 @@ class PKCS12Tests(TestCase): dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=b"") + def test_export_without_bytes(self): + """ + Test :py:obj:`PKCS12.export` with text not bytes as passphrase + """ + p12 = self.gen_pkcs12(server_cert_pem, server_key_pem, root_cert_pem) + + with catch_warnings(record=True) as w: + simplefilter("always") + dumped_p12 = p12.export(passphrase=b"randomtext".decode("ascii")) + self.assertEqual( + "{0} for passphrase is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ), + str(w[-1].message) + ) + self.assertIs(w[-1].category, DeprecationWarning) + self.check_recovery( + dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=b"randomtext") + + def test_key_cert_mismatch(self): """ :py:obj:`PKCS12.export` raises an exception when a key and certificate @@ -3417,6 +3466,47 @@ class SignVerifyTests(TestCase): ValueError, verify, good_cert, sig, content, "strange-digest") + def test_sign_verify_with_text(self): + """ + :py:obj:`sign` generates a cryptographic signature which :py:obj:`verify` can check. + Deprecation warnings raised because using text instead of bytes as content + """ + content = ( + b"It was a bright cold day in April, and the clocks were striking " + b"thirteen. Winston Smith, his chin nuzzled into his breast in an " + b"effort to escape the vile wind, slipped quickly through the " + b"glass doors of Victory Mansions, though not quickly enough to " + b"prevent a swirl of gritty dust from entering along with him." + ).decode("ascii") + + priv_key = load_privatekey(FILETYPE_PEM, root_key_pem) + cert = load_certificate(FILETYPE_PEM, root_cert_pem) + for digest in ['md5', 'sha1']: + with catch_warnings(record=True) as w: + simplefilter("always") + sig = sign(priv_key, content, digest) + + self.assertEqual( + "{0} for data is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ), + str(w[-1].message) + ) + self.assertIs(w[-1].category, DeprecationWarning) + + with catch_warnings(record=True) as w: + simplefilter("always") + verify(cert, sig, content, digest) + + self.assertEqual( + "{0} for data is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ), + str(w[-1].message) + ) + self.assertIs(w[-1].category, DeprecationWarning) + + def test_sign_nulls(self): """ :py:obj:`sign` produces a signature for a string with embedded nulls. diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index c82dea6..4dedb6b 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -7,12 +7,13 @@ Unit tests for :py:obj:`OpenSSL.SSL`. from gc import collect, get_referrers from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN -from sys import platform, version_info, getfilesystemencoding +from sys import platform, getfilesystemencoding from socket import SHUT_RDWR, error, socket from os import makedirs from os.path import join from unittest import main from weakref import ref +from warnings import catch_warnings, simplefilter from six import PY3, text_type, u @@ -42,10 +43,10 @@ from OpenSSL.SSL import ( from OpenSSL.SSL import ( Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) -from OpenSSL.test.util import NON_ASCII, TestCase, b -from OpenSSL.test.test_crypto import ( - cleartextCertificatePEM, cleartextPrivateKeyPEM) +from OpenSSL._util import lib as _lib +from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b from OpenSSL.test.test_crypto import ( + cleartextCertificatePEM, cleartextPrivateKeyPEM, client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, root_cert_pem) @@ -1782,6 +1783,210 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin): +class ApplicationLayerProtoNegotiationTests(TestCase, _LoopbackMixin): + """ + Tests for ALPN in PyOpenSSL. + """ + # Skip tests on versions that don't support ALPN. + if _lib.Cryptography_HAS_ALPN: + + def test_alpn_success(self): + """ + Clients and servers that agree on the negotiated ALPN protocol can + correct establish a connection, and the agreed protocol is reported + by the connections. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2') + + + def test_alpn_set_on_connection(self): + """ + The same as test_alpn_success, but setting the ALPN protocols on + the connection rather than the context. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' + + # Setup the client context but don't set any ALPN protocols. + client_context = Context(TLSv1_METHOD) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + # Set the ALPN protocols on the client connection. + client = Connection(client_context, None) + client.set_alpn_protos([b'http/1.1', b'spdy/2']) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2') + + + def test_alpn_server_fail(self): + """ + When clients and servers cannot agree on what protocol to use next + the TLS connection does not get established. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + return b'' + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises(Error, self._interactInMemory, server, client) + + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + + def test_alpn_no_server(self): + """ + When clients and servers cannot agree on what protocol to use next + because the server doesn't offer ALPN, no protocol is negotiated. + """ + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # Do the dance. + self._interactInMemory(server, client) + + self.assertEqual(client.get_alpn_proto_negotiated(), b'') + + + def test_alpn_callback_exception(self): + """ + We can handle exceptions in the ALPN select callback. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + raise TypeError() + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + else: + # No ALPN. + def test_alpn_not_implemented(self): + """ + If ALPN is not in OpenSSL, we should raise NotImplementedError. + """ + # Test the context methods first. + context = Context(TLSv1_METHOD) + self.assertRaises( + NotImplementedError, context.set_alpn_protos, None + ) + self.assertRaises( + NotImplementedError, context.set_alpn_select_callback, None + ) + + # Now test a connection. + conn = Connection(context) + self.assertRaises( + NotImplementedError, context.set_alpn_protos, None + ) + + + class SessionTests(TestCase): """ Unit tests for :py:obj:`OpenSSL.SSL.Session`. @@ -1902,7 +2107,7 @@ class ConnectionTests(TestCase, _LoopbackMixin): self.assertRaises( TypeError, conn.set_tlsext_host_name, b("with\0null")) - if version_info >= (3,): + if PY3: # On Python 3.x, don't accidentally implicitly convert from text. self.assertRaises( TypeError, @@ -2543,6 +2748,26 @@ class ConnectionSendTests(TestCase, _LoopbackMixin): self.assertEquals(count, 2) self.assertEquals(client.recv(2), b('xy')) + + def test_text(self): + """ + When passed a text, :py:obj:`Connection.send` transmits all of it and + returns the number of bytes sent. It also raises a DeprecationWarning. + """ + server, client = self._loopback() + with catch_warnings(record=True) as w: + simplefilter("always") + count = server.send(b"xy".decode("ascii")) + self.assertEqual( + "{0} for buf is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ), + str(w[-1].message) + ) + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEquals(count, 2) + self.assertEquals(client.recv(2), b"xy") + try: memoryview except NameError: @@ -2763,6 +2988,25 @@ class ConnectionSendallTests(TestCase, _LoopbackMixin): self.assertEquals(client.recv(1), b('x')) + def test_text(self): + """ + :py:obj:`Connection.sendall` transmits all the content in the string + passed to it raising a DeprecationWarning in case of this being a text. + """ + server, client = self._loopback() + with catch_warnings(record=True) as w: + simplefilter("always") + server.sendall(b"x".decode("ascii")) + self.assertEqual( + "{0} for buf is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ), + str(w[-1].message) + ) + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEquals(client.recv(1), b"x") + + try: memoryview except NameError: diff --git a/OpenSSL/test/util.py b/OpenSSL/test/util.py index b24a73b..b8be91d 100644 --- a/OpenSSL/test/util.py +++ b/OpenSSL/test/util.py @@ -14,6 +14,8 @@ from tempfile import mktemp from unittest import TestCase import sys +from six import PY3 + from OpenSSL._util import exception_from_error_queue from OpenSSL.crypto import Error @@ -452,3 +454,10 @@ class EqualityTestsMixin(object): a = self.anInstance() b = Delegate() self.assertEqual(a != b, [b]) + + +# The type name expected in warnings about using the wrong string type. +if PY3: + WARNING_TYPE_EXPECTED = "str" +else: + WARNING_TYPE_EXPECTED = "unicode" diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index e6a0775..2929305 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -498,6 +498,26 @@ Context objects have the following methods: .. versionadded:: 0.15 +.. py:method:: Context.set_alpn_protos(protos) + + Specify the protocols that the client is prepared to speak after the TLS + connection has been negotiated using Application Layer Protocol + Negotiation. + + *protos* should be a list of protocols that the client is offering, each + as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``. + + +.. py:method:: Context.set_alpn_select_callback(callback) + + Specify a callback function that will be called on the server when a client + offers protocols using Application Layer Protocol Negotiation. + + *callback* should be the callback function. It will be invoked with two + arguments: the :py:class:`Connection` and a list of offered protocols as + bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of + these bytestrings, the chosen protocol. + .. _openssl-session: @@ -849,6 +869,22 @@ Connection objects have the following methods: .. versionadded:: 0.15 +.. py:method:: Connection.set_alpn_protos(protos) + + Specify the protocols that the client is prepared to speak after the TLS + connection has been negotiated using Application Layer Protocol + Negotiation. + + *protos* should be a list of protocols that the client is offering, each + as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``. + + +.. py:method:: Connection.get_alpn_proto_negotiated() + + Get the protocol that was negotiated by Application Layer Protocol + Negotiation. Returns a bytestring of the protocol name. If no protocol has + been negotiated yet, returns an empty string. + .. Rubric:: Footnotes |