summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHynek Schlawack <hs@ox.cx>2016-02-13 09:10:04 +0100
committerHynek Schlawack <hs@ox.cx>2016-03-16 13:31:51 +0100
commitb1f3ca86c0bb78a524bb9c677427bb8f1dd35323 (patch)
tree3545716b851252978f01198f935df87d37686138
parentb98d56999b10296fc5ba8b783ca427450cf5e0b4 (diff)
downloadpyopenssl-b1f3ca86c0bb78a524bb9c677427bb8f1dd35323.tar.gz
Implement Context.set_session_id
-rw-r--r--CHANGELOG.rst3
-rw-r--r--doc/api/ssl.rst18
-rw-r--r--src/OpenSSL/SSL.py44
-rw-r--r--tests/test_ssl.py154
4 files changed, 175 insertions, 44 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index d2f8f85..8939c94 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -41,6 +41,9 @@ Deprecations:
Changes:
^^^^^^^^
+- Fixed :meth:`OpenSSL.SSL.Context.set_session_id`, :meth:`OpenSSL.SSL.Connection.renegotiate`, :meth:`OpenSSL.SSL.Connection.renegotiate_pending`, and :meth:`OpenSSL.SSL.Context.load_client_ca`.
+ They were lacking an implementation since 0.14.
+ [`#422 <https://github.com/pyca/pyopenssl/pull/422>`_]
- Fixed segmentation fault when using keys larger than 4096-bit to sign data.
`#428 <https://github.com/pyca/pyopenssl/pull/428>`_
- Fixed ``AttributeError`` when ``OpenSSL.SSL.Connection.get_app_data()`` was called before setting any app data.
diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst
index d144cfd..0edf1ab 100644
--- a/doc/api/ssl.rst
+++ b/doc/api/ssl.rst
@@ -274,10 +274,7 @@ Context objects have the following methods:
Retrieve the Context object's verify mode, as set by :py:meth:`set_verify`.
-.. py:method:: Context.load_client_ca(pemfile)
-
- Read a file with PEM-formatted certificates that will be sent to the client
- when requesting a client certificate.
+.. automethod:: Context.load_client_ca
.. py:method:: Context.set_client_ca_list(certificate_authorities)
@@ -387,12 +384,7 @@ Context objects have the following methods:
.. versionadded:: 0.14
-.. py:method:: Context.set_session_id(name)
-
- Set the context *name* within which a session can be reused for this
- Context object. This is needed when doing session resumption, because there is
- no way for a stored session to know which Context object it is associated with.
- *name* may be any binary data.
+.. automethod:: Context.set_session_id
.. py:method:: Context.set_timeout(timeout)
@@ -684,11 +676,11 @@ Connection objects have the following methods:
bytes (for example, in response to a call to :py:meth:`recv`).
-.. py:method:: Connection.renegotiate()
+.. automethod:: Connection.renegotiate
- Renegotiate the SSL session. Call this if you wish to change cipher suites or
- anything like that.
+.. automethod:: Connection.renegotiate_pending
+.. automethod:: Connection.total_renegotiations
.. py:method:: Connection.send(string)
diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py
index 160cfd8..800ae1e 100644
--- a/src/OpenSSL/SSL.py
+++ b/src/OpenSSL/SSL.py
@@ -689,23 +689,39 @@ class Context(object):
def load_client_ca(self, cafile):
"""
- Load the trusted certificates that will be sent to the client
- (basically telling the client "These are the guys I trust"). Does not
- actually imply any of the certificates are trusted; that must be
+ Load the trusted certificates that will be sent to the client. Does
+ not actually imply any of the certificates are trusted; that must be
configured separately.
- :param cafile: The name of the certificates file
+ :param bytes cafile: The path to a certificates file in PEM format.
:return: None
"""
+ ca_list = _lib.SSL_load_client_CA_file(
+ _text_to_bytes_and_warn("cafile", cafile)
+ )
+ _openssl_assert(ca_list != _ffi.NULL)
+ # SSL_CTX_set_client_CA_list doesn't return anything.
+ _lib.SSL_CTX_set_client_CA_list(self._context, ca_list)
def set_session_id(self, buf):
"""
- Set the session identifier. This is needed if you want to do session
- resumption.
+ Set the session id to *buf* within which a session can be reused for
+ this Context object. This is needed when doing session resumption,
+ because there is no way for a stored session to know which Context
+ object it is associated with.
+
+ :param bytes buf: The session id.
- :param buf: A Python object that can be safely converted to a string
:returns: None
"""
+ buf = _text_to_bytes_and_warn("buf", buf)
+ _openssl_assert(
+ _lib.SSL_CTX_set_session_id_context(
+ self._context,
+ buf,
+ len(buf),
+ ) == 1
+ )
def set_session_cache_mode(self, mode):
"""
@@ -1406,10 +1422,15 @@ class Connection(object):
def renegotiate(self):
"""
- Renegotiate the session
+ Renegotiate the session.
- :return: True if the renegotiation can be started, false otherwise
+ :return: True if the renegotiation can be started, False otherwise
+ :rtype: bool
"""
+ if not self.renegotiate_pending():
+ _openssl_assert(_lib.SSL_renegotiate(self._ssl) == 1)
+ return True
+ return False
def do_handshake(self):
"""
@@ -1423,17 +1444,20 @@ class Connection(object):
def renegotiate_pending(self):
"""
- Check if there's a renegotiation in progress, it will return false once
+ Check if there's a renegotiation in progress, it will return False once
a renegotiation is finished.
:return: Whether there's a renegotiation in progress
+ :rtype: bool
"""
+ return _lib.SSL_renegotiate_pending(self._ssl) == 1
def total_renegotiations(self):
"""
Find out the total number of renegotiations.
:return: The number of renegotiations.
+ :rtype: int
"""
return _lib.SSL_total_renegotiations(self._ssl)
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index 8040c97..ab316fc 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -5,6 +5,9 @@
Unit tests for :mod:`OpenSSL.SSL`.
"""
+import datetime
+import uuid
+
from gc import collect, get_referrers
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
from sys import platform, getfilesystemencoding, version_info
@@ -19,6 +22,14 @@ import pytest
from six import PY3, text_type
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.x509.oid import NameOID
+
+
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
from OpenSSL.crypto import PKey, X509, X509Extension, X509Store
from OpenSSL.crypto import dump_privatekey, load_privatekey
@@ -348,6 +359,49 @@ class VersionTests(TestCase):
@pytest.fixture
+def ca_file(tmpdir):
+ """
+ Create a valid PEM file with CA certificates and return the path.
+ """
+ key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048,
+ backend=default_backend()
+ )
+ public_key = key.public_key()
+
+ builder = x509.CertificateBuilder()
+ builder = builder.subject_name(x509.Name([
+ x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"),
+ ]))
+ builder = builder.issuer_name(x509.Name([
+ x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"),
+ ]))
+ one_day = datetime.timedelta(1, 0, 0)
+ builder = builder.not_valid_before(datetime.datetime.today() - one_day)
+ builder = builder.not_valid_after(datetime.datetime.today() + one_day)
+ builder = builder.serial_number(int(uuid.uuid4()))
+ builder = builder.public_key(public_key)
+ builder = builder.add_extension(
+ x509.BasicConstraints(ca=True, path_length=None), critical=True,
+ )
+
+ certificate = builder.sign(
+ private_key=key, algorithm=hashes.SHA256(),
+ backend=default_backend()
+ )
+
+ ca_file = tmpdir.join("test.pem")
+ ca_file.write_binary(
+ certificate.public_bytes(
+ encoding=serialization.Encoding.PEM,
+ )
+ )
+
+ return str(ca_file).encode("ascii")
+
+
+@pytest.fixture
def context():
"""
A simple TLS 1.0 context.
@@ -389,6 +443,59 @@ class TestContext(object):
with pytest.raises(error):
context.set_cipher_list(cipher_list)
+ def test_load_client_ca(self, context, ca_file):
+ """
+ :meth:`Context.load_client_ca` works as far as we can tell.
+ """
+ context.load_client_ca(ca_file)
+
+ def test_load_client_ca_invalid(self, context, tmpdir):
+ """
+ :meth:`Context.load_client_ca` raises an Error if the ca file is
+ invalid.
+ """
+ ca_file = tmpdir.join("test.pem")
+ ca_file.write("")
+
+ with pytest.raises(Error) as e:
+ context.load_client_ca(str(ca_file).encode("ascii"))
+
+ assert "PEM routines" == e.value.args[0][0][0]
+
+ def test_load_client_ca_unicode(self, context, ca_file):
+ """
+ Passing the path as unicode raises a warning but works.
+ """
+ pytest.deprecated_call(
+ context.load_client_ca, ca_file.decode("ascii")
+ )
+
+ def test_set_session_id(self, context):
+ """
+ :meth:`Context.set_session_id` works as far as we can tell.
+ """
+ context.set_session_id(b"abc")
+
+ def test_set_session_id_fail(self, context):
+ """
+ :meth:`Context.set_session_id` errors are propagated.
+ """
+ with pytest.raises(Error) as e:
+ context.set_session_id(b"abc" * 1000)
+
+ assert [
+ ("SSL routines",
+ "SSL_CTX_set_session_id_context",
+ "ssl session id context too long")
+ ] == e.value.args[0]
+
+ def test_set_session_id_unicode(self, context):
+ """
+ :meth:`Context.set_session_id` raises a warning if a unicode string is
+ passed.
+ """
+ pytest.deprecated_call(context.set_session_id, u"abc")
+
class ContextTests(TestCase, _LoopbackMixin):
"""
@@ -1210,9 +1317,10 @@ class ContextTests(TestCase, _LoopbackMixin):
raise Exception("silly verify failure")
clientContext.set_verify(VERIFY_PEER, verify_callback)
- exc = self.assertRaises(
- Exception, self._handshake_test, serverContext, clientContext)
- self.assertEqual("silly verify failure", str(exc))
+ with pytest.raises(Exception) as exc:
+ self._handshake_test(serverContext, clientContext)
+
+ self.assertEqual("silly verify failure", str(exc.value))
def test_add_extra_chain_cert(self):
"""
@@ -1338,9 +1446,6 @@ class ContextTests(TestCase, _LoopbackMixin):
Error, context.use_certificate_chain_file, self.mktemp()
)
- # XXX load_client_ca
- # XXX set_session_id
-
def test_get_verify_mode_wrong_args(self):
"""
:py:obj:`Context.get_verify_mode` raises :py:obj:`TypeError` if called
@@ -2033,7 +2138,6 @@ class ConnectionTests(TestCase, _LoopbackMixin):
# XXX connect_ex -> TypeError
# XXX set_connect_state -> TypeError
# XXX set_accept_state -> TypeError
- # XXX renegotiate_pending
# XXX do_handshake -> TypeError
# XXX bio_read -> TypeError
# XXX recv -> TypeError
@@ -3136,24 +3240,32 @@ class ConnectionRenegotiateTests(TestCase, _LoopbackMixin):
connection = Connection(Context(TLSv1_METHOD), None)
self.assertEquals(connection.total_renegotiations(), 0)
-# def test_renegotiate(self):
-# """
-# """
-# server, client = self._loopback()
+ def test_renegotiate(self):
+ """
+ Go through a complete renegotiation cycle.
+ """
+ server, client = self._loopback()
+
+ server.send(b"hello world")
+
+ assert b"hello world" == client.recv(len(b"hello world"))
-# server.send("hello world")
-# self.assertEquals(client.recv(len("hello world")), "hello world")
+ assert 0 == server.total_renegotiations()
+ assert False is server.renegotiate_pending()
-# self.assertEquals(server.total_renegotiations(), 0)
-# self.assertTrue(server.renegotiate())
+ assert True is server.renegotiate()
-# server.setblocking(False)
-# client.setblocking(False)
-# while server.renegotiate_pending():
-# client.do_handshake()
-# server.do_handshake()
+ assert True is server.renegotiate_pending()
-# self.assertEquals(server.total_renegotiations(), 1)
+ server.setblocking(False)
+ client.setblocking(False)
+
+ client.do_handshake()
+ server.do_handshake()
+
+ assert 1 == server.total_renegotiations()
+ while False is server.renegotiate_pending():
+ pass
class ErrorTests(TestCase):