From 532b79e7b5dc159ce2d52a002e18b1e10b71a00c Mon Sep 17 00:00:00 2001 From: Alex Chan Date: Tue, 24 Jan 2017 15:14:52 +0000 Subject: Convert TestContext to be pytest-style (#589) --- tests/test_ssl.py | 679 +++++++++++++++++++++++------------------------------- 1 file changed, 284 insertions(+), 395 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 8c090aa..e0a720b 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -81,7 +81,8 @@ try: except ImportError: SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None -from .util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase +from .util import ( + WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, is_consistent_type) from .test_crypto import ( cleartextCertificatePEM, cleartextPrivateKeyPEM, client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, @@ -260,20 +261,7 @@ class _LoopbackMixin(object): return interact_in_memory(client_conn, server_conn) def _handshakeInMemory(self, client_conn, server_conn): - """ - Perform the TLS handshake between two :py:class:`Connection` instances - connected to each other via memory BIOs. - """ - client_conn.set_connect_state() - server_conn.set_accept_state() - - for conn in [client_conn, server_conn]: - try: - conn.do_handshake() - except WantReadError: - pass - - self._interactInMemory(client_conn, server_conn) + return handshake_in_memory(client_conn, server_conn) def interact_in_memory(client_conn, server_conn): @@ -321,6 +309,23 @@ def interact_in_memory(client_conn, server_conn): write.bio_write(dirty) +def handshake_in_memory(client_conn, server_conn): + """ + Perform the TLS handshake between two `Connection` instances connected to + each other via memory BIOs. + """ + client_conn.set_connect_state() + server_conn.set_accept_state() + + for conn in [client_conn, server_conn]: + try: + conn.do_handshake() + except WantReadError: + pass + + interact_in_memory(client_conn, server_conn) + + class VersionTests(TestCase): """ Tests for version information exposed by @@ -402,9 +407,7 @@ def context(): class TestContext(object): """ - py.test-based tests for :class:`OpenSSL.SSL.Context`. - - If possible, add new tests here. + Unit tests for `OpenSSL.SSL.Context`. """ @pytest.mark.parametrize("cipher_string", [ b"hello world:AES128-SHA", @@ -412,7 +415,7 @@ class TestContext(object): ]) def test_set_cipher_list(self, context, cipher_string): """ - :meth:`Context.set_cipher_list` accepts both byte and unicode strings + `Context.set_cipher_list` accepts both byte and unicode strings for naming the ciphers which connections created with the context object will be able to choose from. """ @@ -427,23 +430,22 @@ class TestContext(object): ]) def test_set_cipher_list_wrong_args(self, context, cipher_list, error): """ - :meth:`Context.set_cipher_list` raises :exc:`TypeError` when passed a - non-string argument and raises :exc:`OpenSSL.SSL.Error` when passed an - incorrect cipher list string. + `Context.set_cipher_list` raises `TypeError` when passed a non-string + argument and raises `OpenSSL.SSL.Error` when passed an incorrect cipher + list string. """ with pytest.raises(error): context.set_cipher_list(cipher_list) def test_load_client_ca(self, context, ca_file): """ - :meth:`Context.load_client_ca` works as far as we can tell. + `Context.load_client_ca` works as far as we can tell. """ context.load_client_ca(ca_file) def test_load_client_ca_invalid(self, context, tmpdir): """ - :meth:`Context.load_client_ca` raises an Error if the ca file is - invalid. + `Context.load_client_ca` raises an Error if the ca file is invalid. """ ca_file = tmpdir.join("test.pem") ca_file.write("") @@ -463,13 +465,13 @@ class TestContext(object): def test_set_session_id(self, context): """ - :meth:`Context.set_session_id` works as far as we can tell. + `Context.set_session_id` works as far as we can tell. """ context.set_session_id(b"abc") def test_set_session_id_fail(self, context): """ - :meth:`Context.set_session_id` errors are propagated. + `Context.set_session_id` errors are propagated. """ with pytest.raises(Error) as e: context.set_session_id(b"abc" * 1000) @@ -482,24 +484,16 @@ class TestContext(object): def test_set_session_id_unicode(self, context): """ - :meth:`Context.set_session_id` raises a warning if a unicode string is + `Context.set_session_id` raises a warning if a unicode string is passed. """ pytest.deprecated_call(context.set_session_id, u"abc") - -class ContextTests(TestCase, _LoopbackMixin): - """ - Unit tests for :class:`OpenSSL.SSL.Context`. - - If possible, add new tests to :class:`TestContext` above. - """ def test_method(self): """ - :py:obj:`Context` can be instantiated with one of - :py:obj:`SSLv2_METHOD`, :py:obj:`SSLv3_METHOD`, - :py:obj:`SSLv23_METHOD`, :py:obj:`TLSv1_METHOD`, - :py:obj:`TLSv1_1_METHOD`, or :py:obj:`TLSv1_2_METHOD`. + `Context` can be instantiated with one of `SSLv2_METHOD`, + `SSLv3_METHOD`, `SSLv23_METHOD`, `TLSv1_METHOD`, `TLSv1_1_METHOD`, + or `TLSv1_2_METHOD`. """ methods = [SSLv23_METHOD, TLSv1_METHOD] for meth in methods: @@ -514,44 +508,45 @@ class ContextTests(TestCase, _LoopbackMixin): # don't. Difficult to say in advance. pass - self.assertRaises(TypeError, Context, "") - self.assertRaises(ValueError, Context, 10) + with pytest.raises(TypeError): + Context("") + with pytest.raises(ValueError): + Context(10) @skip_if_py3 def test_method_long(self): """ - On Python 2 :py:class:`Context` accepts values of type - :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context` accepts values of type `long` as well as `int`. """ Context(long(TLSv1_METHOD)) def test_type(self): """ - :py:obj:`Context` and :py:obj:`ContextType` refer to the same type - object and can be used to create instances of that type. + `Context` and `ContextType` refer to the same type object and can + be used to create instances of that type. """ - self.assertIdentical(Context, ContextType) - self.assertConsistentType(Context, 'Context', TLSv1_METHOD) + assert Context is ContextType + assert is_consistent_type(Context, 'Context', TLSv1_METHOD) def test_use_privatekey(self): """ - :py:obj:`Context.use_privatekey` takes an :py:obj:`OpenSSL.crypto.PKey` - instance. + `Context.use_privatekey` takes an `OpenSSL.crypto.PKey` instance. """ key = PKey() key.generate_key(TYPE_RSA, 128) ctx = Context(TLSv1_METHOD) ctx.use_privatekey(key) - self.assertRaises(TypeError, ctx.use_privatekey, "") + with pytest.raises(TypeError): + ctx.use_privatekey("") - def test_use_privatekey_file_missing(self): + def test_use_privatekey_file_missing(self, tmpfile): """ - :py:obj:`Context.use_privatekey_file` raises - :py:obj:`OpenSSL.SSL.Error` when passed the name of a file which does - not exist. + `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` when passed + the name of a file which does not exist. """ ctx = Context(TLSv1_METHOD) - self.assertRaises(Error, ctx.use_privatekey_file, self.mktemp()) + with pytest.raises(Error): + ctx.use_privatekey_file(tmpfile) def _use_privatekey_file_test(self, pemfile, filetype): """ @@ -569,59 +564,56 @@ class ContextTests(TestCase, _LoopbackMixin): ctx = Context(TLSv1_METHOD) ctx.use_privatekey_file(pemfile, filetype) - def test_use_privatekey_file_bytes(self): + def test_use_privatekey_file_bytes(self, tmpfile): """ A private key can be specified from a file by passing a ``bytes`` instance giving the file name to ``Context.use_privatekey_file``. """ self._use_privatekey_file_test( - self.mktemp() + NON_ASCII.encode(getfilesystemencoding()), + tmpfile + NON_ASCII.encode(getfilesystemencoding()), FILETYPE_PEM, ) - def test_use_privatekey_file_unicode(self): + def test_use_privatekey_file_unicode(self, tmpfile): """ A private key can be specified from a file by passing a ``unicode`` instance giving the file name to ``Context.use_privatekey_file``. """ self._use_privatekey_file_test( - self.mktemp().decode(getfilesystemencoding()) + NON_ASCII, + tmpfile.decode(getfilesystemencoding()) + NON_ASCII, FILETYPE_PEM, ) @skip_if_py3 - def test_use_privatekey_file_long(self): + def test_use_privatekey_file_long(self, tmpfile): """ - On Python 2 :py:obj:`Context.use_privatekey_file` accepts a - filetype of type :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context.use_privatekey_file` accepts a filetype of + type `long` as well as `int`. """ - self._use_privatekey_file_test(self.mktemp(), long(FILETYPE_PEM)) + self._use_privatekey_file_test(tmpfile, long(FILETYPE_PEM)) def test_use_certificate_wrong_args(self): """ - :py:obj:`Context.use_certificate_wrong_args` raises :py:obj:`TypeError` - when not passed exactly one :py:obj:`OpenSSL.crypto.X509` instance as - an argument. + `Context.use_certificate_wrong_args` raises `TypeError` when not passed + exactly one `OpenSSL.crypto.X509` instance as an argument. """ ctx = Context(TLSv1_METHOD) - self.assertRaises(TypeError, ctx.use_certificate) - self.assertRaises(TypeError, ctx.use_certificate, "hello, world") - self.assertRaises( - TypeError, ctx.use_certificate, X509(), "hello, world" - ) + with pytest.raises(TypeError): + ctx.use_certificate("hello, world") def test_use_certificate_uninitialized(self): """ - :py:obj:`Context.use_certificate` raises :py:obj:`OpenSSL.SSL.Error` - when passed a :py:obj:`OpenSSL.crypto.X509` instance which has not been - initialized (ie, which does not actually have any certificate data). + `Context.use_certificate` raises `OpenSSL.SSL.Error` when passed a + `OpenSSL.crypto.X509` instance which has not been initialized + (ie, which does not actually have any certificate data). """ ctx = Context(TLSv1_METHOD) - self.assertRaises(Error, ctx.use_certificate, X509()) + with pytest.raises(Error): + ctx.use_certificate(X509()) def test_use_certificate(self): """ - :py:obj:`Context.use_certificate` sets the certificate which will be + `Context.use_certificate` sets the certificate which will be used to identify connections created using the context. """ # TODO @@ -635,31 +627,25 @@ class ContextTests(TestCase, _LoopbackMixin): def test_use_certificate_file_wrong_args(self): """ - :py:obj:`Context.use_certificate_file` raises :py:obj:`TypeError` if - called with zero arguments or more than two arguments, or if the first - argument is not a byte string or the second argumnent is not an - integer. + `Context.use_certificate_file` raises `TypeError` if the first + argument is not a byte string or the second argument is not an integer. """ ctx = Context(TLSv1_METHOD) - self.assertRaises(TypeError, ctx.use_certificate_file) - self.assertRaises(TypeError, ctx.use_certificate_file, b"somefile", - object()) - self.assertRaises( - TypeError, ctx.use_certificate_file, b"somefile", FILETYPE_PEM, - object()) - self.assertRaises( - TypeError, ctx.use_certificate_file, object(), FILETYPE_PEM) - self.assertRaises( - TypeError, ctx.use_certificate_file, b"somefile", object()) + with pytest.raises(TypeError): + ctx.use_certificate_file(object(), FILETYPE_PEM) + with pytest.raises(TypeError): + ctx.use_certificate_file(b"somefile", object()) + with pytest.raises(TypeError): + ctx.use_certificate_file(object(), FILETYPE_PEM) - def test_use_certificate_file_missing(self): + def test_use_certificate_file_missing(self, tmpfile): """ - :py:obj:`Context.use_certificate_file` raises - `:py:obj:`OpenSSL.SSL.Error` if passed the name of a file which does - not exist. + `Context.use_certificate_file` raises `OpenSSL.SSL.Error` if passed + the name of a file which does not exist. """ ctx = Context(TLSv1_METHOD) - self.assertRaises(Error, ctx.use_certificate_file, self.mktemp()) + with pytest.raises(Error): + ctx.use_certificate_file(tmpfile) def _use_certificate_file_test(self, certificate_file): """ @@ -676,31 +662,31 @@ class ContextTests(TestCase, _LoopbackMixin): ctx = Context(TLSv1_METHOD) ctx.use_certificate_file(certificate_file) - def test_use_certificate_file_bytes(self): + def test_use_certificate_file_bytes(self, tmpfile): """ - :py:obj:`Context.use_certificate_file` sets the certificate (given as a - ``bytes`` filename) which will be used to identify connections created + `Context.use_certificate_file` sets the certificate (given as a + `bytes` filename) which will be used to identify connections created using the context. """ - filename = self.mktemp() + NON_ASCII.encode(getfilesystemencoding()) + filename = tmpfile + NON_ASCII.encode(getfilesystemencoding()) self._use_certificate_file_test(filename) - def test_use_certificate_file_unicode(self): + def test_use_certificate_file_unicode(self, tmpfile): """ - :py:obj:`Context.use_certificate_file` sets the certificate (given as a - ``bytes`` filename) which will be used to identify connections created + `Context.use_certificate_file` sets the certificate (given as a + `bytes` filename) which will be used to identify connections created using the context. """ - filename = self.mktemp().decode(getfilesystemencoding()) + NON_ASCII + filename = tmpfile.decode(getfilesystemencoding()) + NON_ASCII self._use_certificate_file_test(filename) @skip_if_py3 - def test_use_certificate_file_long(self): + def test_use_certificate_file_long(self, tmpfile): """ - On Python 2 :py:obj:`Context.use_certificate_file` accepts a - filetype of type :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context.use_certificate_file` accepts a + filetype of type `long` as well as `int`. """ - pem_filename = self.mktemp() + pem_filename = tmpfile with open(pem_filename, "wb") as pem_file: pem_file.write(cleartextCertificatePEM) @@ -709,78 +695,52 @@ class ContextTests(TestCase, _LoopbackMixin): def test_check_privatekey_valid(self): """ - :py:obj:`Context.check_privatekey` returns :py:obj:`None` if the - :py:obj:`Context` instance has been configured to use a matched key and - certificate pair. + `Context.check_privatekey` returns `None` if the `Context` instance + has been configured to use a matched key and certificate pair. """ key = load_privatekey(FILETYPE_PEM, client_key_pem) cert = load_certificate(FILETYPE_PEM, client_cert_pem) context = Context(TLSv1_METHOD) context.use_privatekey(key) context.use_certificate(cert) - self.assertIs(None, context.check_privatekey()) + assert None is context.check_privatekey() def test_check_privatekey_invalid(self): """ - :py:obj:`Context.check_privatekey` raises :py:obj:`Error` if the - :py:obj:`Context` instance has been configured to use a key and - certificate pair which don't relate to each other. + `Context.check_privatekey` raises `Error` if the `Context` instance + has been configured to use a key and certificate pair which don't + relate to each other. """ key = load_privatekey(FILETYPE_PEM, client_key_pem) cert = load_certificate(FILETYPE_PEM, server_cert_pem) context = Context(TLSv1_METHOD) context.use_privatekey(key) context.use_certificate(cert) - self.assertRaises(Error, context.check_privatekey) - - def test_check_privatekey_wrong_args(self): - """ - :py:obj:`Context.check_privatekey` raises :py:obj:`TypeError` if called - with other than no arguments. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.check_privatekey, object()) - - def test_set_app_data_wrong_args(self): - """ - :py:obj:`Context.set_app_data` raises :py:obj:`TypeError` if called - with other than one argument. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_app_data) - self.assertRaises(TypeError, context.set_app_data, None, None) - - def test_get_app_data_wrong_args(self): - """ - :py:obj:`Context.get_app_data` raises :py:obj:`TypeError` if called - with any arguments. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.get_app_data, None) + with pytest.raises(Error): + context.check_privatekey() def test_app_data(self): """ - :py:obj:`Context.set_app_data` stores an object for later retrieval - using :py:obj:`Context.get_app_data`. + `Context.set_app_data` stores an object for later retrieval + using `Context.get_app_data`. """ app_data = object() context = Context(TLSv1_METHOD) context.set_app_data(app_data) - self.assertIdentical(context.get_app_data(), app_data) + assert context.get_app_data() is app_data def test_set_options_wrong_args(self): """ - :py:obj:`Context.set_options` raises :py:obj:`TypeError` if called with - the wrong number of arguments or a non-:py:obj:`int` argument. + `Context.set_options` raises `TypeError` if called with + a non-`int` argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_options) - self.assertRaises(TypeError, context.set_options, None) - self.assertRaises(TypeError, context.set_options, 1, None) + with pytest.raises(TypeError): + context.set_options(None) def test_set_options(self): """ - :py:obj:`Context.set_options` returns the new options value. + `Context.set_options` returns the new options value. """ context = Context(TLSv1_METHOD) options = context.set_options(OP_NO_SSLv2) @@ -789,8 +749,8 @@ class ContextTests(TestCase, _LoopbackMixin): @skip_if_py3 def test_set_options_long(self): """ - On Python 2 :py:obj:`Context.set_options` accepts values of type - :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context.set_options` accepts values of type + `long` as well as `int`. """ context = Context(TLSv1_METHOD) options = context.set_options(long(OP_NO_SSLv2)) @@ -798,142 +758,117 @@ class ContextTests(TestCase, _LoopbackMixin): def test_set_mode_wrong_args(self): """ - :py:obj:`Context.set`mode} raises :py:obj:`TypeError` if called with - the wrong number of arguments or a non-:py:obj:`int` argument. + `Context.set_mode` raises `TypeError` if called with + a non-`int` argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_mode) - self.assertRaises(TypeError, context.set_mode, None) - self.assertRaises(TypeError, context.set_mode, 1, None) + with pytest.raises(TypeError): + context.set_mode(None) def test_set_mode(self): """ - :py:obj:`Context.set_mode` accepts a mode bitvector and returns the + `Context.set_mode` accepts a mode bitvector and returns the newly set mode. """ context = Context(TLSv1_METHOD) - self.assertTrue( - MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS)) + assert MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS) @skip_if_py3 def test_set_mode_long(self): """ - On Python 2 :py:obj:`Context.set_mode` accepts values of type - :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context.set_mode` accepts values of type `long` as well + as `int`. """ context = Context(TLSv1_METHOD) mode = context.set_mode(long(MODE_RELEASE_BUFFERS)) - self.assertTrue(MODE_RELEASE_BUFFERS & mode) + assert MODE_RELEASE_BUFFERS & mode def test_set_timeout_wrong_args(self): """ - :py:obj:`Context.set_timeout` raises :py:obj:`TypeError` if called with - the wrong number of arguments or a non-:py:obj:`int` argument. + `Context.set_timeout` raises `TypeError` if called with + a non-`int` argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_timeout) - self.assertRaises(TypeError, context.set_timeout, None) - self.assertRaises(TypeError, context.set_timeout, 1, None) - - def test_get_timeout_wrong_args(self): - """ - :py:obj:`Context.get_timeout` raises :py:obj:`TypeError` if called with - any arguments. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.get_timeout, None) + with pytest.raises(TypeError): + context.set_timeout(None) def test_timeout(self): """ - :py:obj:`Context.set_timeout` sets the session timeout for all - connections created using the context object. - :py:obj:`Context.get_timeout` retrieves this value. + `Context.set_timeout` sets the session timeout for all connections + created using the context object. `Context.get_timeout` retrieves + this value. """ context = Context(TLSv1_METHOD) context.set_timeout(1234) - self.assertEquals(context.get_timeout(), 1234) + assert context.get_timeout() == 1234 @skip_if_py3 def test_timeout_long(self): """ - On Python 2 :py:obj:`Context.set_timeout` accepts values of type - `long` as well as int. + On Python 2 `Context.set_timeout` accepts values of type `long` as + well as int. """ context = Context(TLSv1_METHOD) context.set_timeout(long(1234)) - self.assertEquals(context.get_timeout(), 1234) + assert context.get_timeout() == 1234 def test_set_verify_depth_wrong_args(self): """ - :py:obj:`Context.set_verify_depth` raises :py:obj:`TypeError` if called - with the wrong number of arguments or a non-:py:obj:`int` argument. + `Context.set_verify_depth` raises `TypeError` if called with a + non-`int` argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_verify_depth) - self.assertRaises(TypeError, context.set_verify_depth, None) - self.assertRaises(TypeError, context.set_verify_depth, 1, None) - - def test_get_verify_depth_wrong_args(self): - """ - :py:obj:`Context.get_verify_depth` raises :py:obj:`TypeError` if called - with any arguments. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.get_verify_depth, None) + with pytest.raises(TypeError): + context.set_verify_depth(None) def test_verify_depth(self): """ - :py:obj:`Context.set_verify_depth` sets the number of certificates in + `Context.set_verify_depth` sets the number of certificates in a chain to follow before giving up. The value can be retrieved with - :py:obj:`Context.get_verify_depth`. + `Context.get_verify_depth`. """ context = Context(TLSv1_METHOD) context.set_verify_depth(11) - self.assertEquals(context.get_verify_depth(), 11) + assert context.get_verify_depth() == 11 @skip_if_py3 def test_verify_depth_long(self): """ - On Python 2 :py:obj:`Context.set_verify_depth` accepts values of - type `long` as well as int. + On Python 2 `Context.set_verify_depth` accepts values of type `long` + as well as int. """ context = Context(TLSv1_METHOD) context.set_verify_depth(long(11)) - self.assertEquals(context.get_verify_depth(), 11) + assert context.get_verify_depth() == 11 - def _write_encrypted_pem(self, passphrase): + def _write_encrypted_pem(self, passphrase, tmpfile): """ Write a new private key out to a new file, encrypted using the given passphrase. Return the path to the new file. """ key = PKey() key.generate_key(TYPE_RSA, 128) - pemFile = self.mktemp() - fObj = open(pemFile, 'w') pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase) - fObj.write(pem.decode('ascii')) - fObj.close() - return pemFile + with open(tmpfile, 'w') as fObj: + fObj.write(pem.decode('ascii')) + return tmpfile def test_set_passwd_cb_wrong_args(self): """ - :py:obj:`Context.set_passwd_cb` raises :py:obj:`TypeError` if called - with the wrong arguments or with a non-callable first argument. + `Context.set_passwd_cb` raises `TypeError` if called with a + non-callable first argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_passwd_cb) - self.assertRaises(TypeError, context.set_passwd_cb, None) - self.assertRaises( - TypeError, context.set_passwd_cb, lambda: None, None, None - ) # pragma: nocover + with pytest.raises(TypeError): + context.set_passwd_cb(None) - def test_set_passwd_cb(self): + def test_set_passwd_cb(self, tmpfile): """ - :py:obj:`Context.set_passwd_cb` accepts a callable which will be - invoked when a private key is loaded from an encrypted PEM. + `Context.set_passwd_cb` accepts a callable which will be invoked when + a private key is loaded from an encrypted PEM. """ passphrase = b"foobar" - pemFile = self._write_encrypted_pem(passphrase) + pemFile = self._write_encrypted_pem(passphrase, tmpfile) calledWith = [] def passphraseCallback(maxlen, verify, extra): @@ -942,63 +877,65 @@ class ContextTests(TestCase, _LoopbackMixin): context = Context(TLSv1_METHOD) context.set_passwd_cb(passphraseCallback) context.use_privatekey_file(pemFile) - self.assertTrue(len(calledWith), 1) - self.assertTrue(isinstance(calledWith[0][0], int)) - self.assertTrue(isinstance(calledWith[0][1], int)) - self.assertEqual(calledWith[0][2], None) + assert len(calledWith) == 1 + assert isinstance(calledWith[0][0], int) + assert isinstance(calledWith[0][1], int) + assert calledWith[0][2] is None - def test_passwd_callback_exception(self): + def test_passwd_callback_exception(self, tmpfile): """ - :py:obj:`Context.use_privatekey_file` propagates any exception raised + `Context.use_privatekey_file` propagates any exception raised by the passphrase callback. """ - pemFile = self._write_encrypted_pem(b"monkeys are nice") + pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) def passphraseCallback(maxlen, verify, extra): raise RuntimeError("Sorry, I am a fail.") context = Context(TLSv1_METHOD) context.set_passwd_cb(passphraseCallback) - self.assertRaises(RuntimeError, context.use_privatekey_file, pemFile) + with pytest.raises(RuntimeError): + context.use_privatekey_file(pemFile) - def test_passwd_callback_false(self): + def test_passwd_callback_false(self, tmpfile): """ - :py:obj:`Context.use_privatekey_file` raises - :py:obj:`OpenSSL.SSL.Error` if the passphrase callback returns a false - value. + `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the + passphrase callback returns a false value. """ - pemFile = self._write_encrypted_pem(b"monkeys are nice") + pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) def passphraseCallback(maxlen, verify, extra): return b"" context = Context(TLSv1_METHOD) context.set_passwd_cb(passphraseCallback) - self.assertRaises(Error, context.use_privatekey_file, pemFile) + with pytest.raises(Error): + context.use_privatekey_file(pemFile) - def test_passwd_callback_non_string(self): + def test_passwd_callback_non_string(self, tmpfile): """ - :py:obj:`Context.use_privatekey_file` raises - :py:obj:`OpenSSL.SSL.Error` if the passphrase callback returns a true - non-string value. + `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the + passphrase callback returns a true non-string value. """ - pemFile = self._write_encrypted_pem(b"monkeys are nice") + pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) def passphraseCallback(maxlen, verify, extra): return 10 context = Context(TLSv1_METHOD) context.set_passwd_cb(passphraseCallback) - self.assertRaises(ValueError, context.use_privatekey_file, pemFile) + # TODO: Surely this is the wrong error? + with pytest.raises(ValueError): + context.use_privatekey_file(pemFile) - def test_passwd_callback_too_long(self): + def test_passwd_callback_too_long(self, tmpfile): """ If the passphrase returned by the passphrase callback returns a string longer than the indicated maximum length, it is truncated. """ # A priori knowledge! passphrase = b"x" * 1024 - pemFile = self._write_encrypted_pem(passphrase) + pemFile = self._write_encrypted_pem(passphrase, tmpfile) def passphraseCallback(maxlen, verify, extra): assert maxlen == 1024 @@ -1012,7 +949,7 @@ class ContextTests(TestCase, _LoopbackMixin): def test_set_info_callback(self): """ - :py:obj:`Context.set_info_callback` accepts a callable which will be + `Context.set_info_callback` accepts a callable which will be invoked when certain information about an SSL connection is available. """ (server, client) = socket_pair() @@ -1044,14 +981,13 @@ class ContextTests(TestCase, _LoopbackMixin): notConnections = [ conn for (conn, where, ret) in called if not isinstance(conn, Connection)] - self.assertEqual( - [], notConnections, - "Some info callback arguments were not Connection instaces.") + assert [] == notConnections, ( + "Some info callback arguments were not Connection instances.") def _load_verify_locations_test(self, *args): """ Create a client context which will verify the peer certificate and call - its :py:obj:`load_verify_locations` method with the given arguments. + its `load_verify_locations` method with the given arguments. Then connect it to a server and ensure that the handshake succeeds. """ (server, client) = socket_pair() @@ -1083,48 +1019,45 @@ class ContextTests(TestCase, _LoopbackMixin): handshake(clientSSL, serverSSL) cert = clientSSL.get_peer_certificate() - self.assertEqual(cert.get_subject().CN, 'Testing Root CA') + assert cert.get_subject().CN == 'Testing Root CA' def _load_verify_cafile(self, cafile): """ Verify that if path to a file containing a certificate is passed to - ``Context.load_verify_locations`` for the ``cafile`` parameter, that + `Context.load_verify_locations` for the ``cafile`` parameter, that certificate is used as a trust root for the purposes of verifying - connections created using that ``Context``. + connections created using that `Context`. """ - fObj = open(cafile, 'w') - fObj.write(cleartextCertificatePEM.decode('ascii')) - fObj.close() + with open(cafile, 'w') as fObj: + fObj.write(cleartextCertificatePEM.decode('ascii')) self._load_verify_locations_test(cafile) - def test_load_verify_bytes_cafile(self): + def test_load_verify_bytes_cafile(self, tmpfile): """ - :py:obj:`Context.load_verify_locations` accepts a file name as a - ``bytes`` instance and uses the certificates within for verification - purposes. + `Context.load_verify_locations` accepts a file name as a `bytes` + instance and uses the certificates within for verification purposes. """ - cafile = self.mktemp() + NON_ASCII.encode(getfilesystemencoding()) + cafile = tmpfile + NON_ASCII.encode(getfilesystemencoding()) self._load_verify_cafile(cafile) - def test_load_verify_unicode_cafile(self): + def test_load_verify_unicode_cafile(self, tmpfile): """ - :py:obj:`Context.load_verify_locations` accepts a file name as a - ``unicode`` instance and uses the certificates within for verification - purposes. + `Context.load_verify_locations` accepts a file name as a `unicode` + instance and uses the certificates within for verification purposes. """ self._load_verify_cafile( - self.mktemp().decode(getfilesystemencoding()) + NON_ASCII + tmpfile.decode(getfilesystemencoding()) + NON_ASCII ) - def test_load_verify_invalid_file(self): + def test_load_verify_invalid_file(self, tmpfile): """ - :py:obj:`Context.load_verify_locations` raises :py:obj:`Error` when - passed a non-existent cafile. + `Context.load_verify_locations` raises `Error` when passed a + non-existent cafile. """ clientContext = Context(TLSv1_METHOD) - self.assertRaises( - Error, clientContext.load_verify_locations, self.mktemp()) + with pytest.raises(Error): + clientContext.load_verify_locations(tmpfile) def _load_verify_directory_locations_capath(self, capath): """ @@ -1144,41 +1077,34 @@ class ContextTests(TestCase, _LoopbackMixin): self._load_verify_locations_test(None, capath) - def test_load_verify_directory_bytes_capath(self): + def test_load_verify_directory_bytes_capath(self, tmpfile): """ - :py:obj:`Context.load_verify_locations` accepts a directory name as a - ``bytes`` instance and uses the certificates within for verification - purposes. + `Context.load_verify_locations` accepts a directory name as a `bytes` + instance and uses the certificates within for verification purposes. """ self._load_verify_directory_locations_capath( - self.mktemp() + NON_ASCII.encode(getfilesystemencoding()) + tmpfile + NON_ASCII.encode(getfilesystemencoding()) ) - def test_load_verify_directory_unicode_capath(self): + def test_load_verify_directory_unicode_capath(self, tmpfile): """ - :py:obj:`Context.load_verify_locations` accepts a directory name as a - ``unicode`` instance and uses the certificates within for verification - purposes. + `Context.load_verify_locations` accepts a directory name as a `unicode` + instance and uses the certificates within for verification purposes. """ self._load_verify_directory_locations_capath( - self.mktemp().decode(getfilesystemencoding()) + NON_ASCII + tmpfile.decode(getfilesystemencoding()) + NON_ASCII ) def test_load_verify_locations_wrong_args(self): """ - :py:obj:`Context.load_verify_locations` raises :py:obj:`TypeError` if - called with the wrong number of arguments or with non-:py:obj:`str` + `Context.load_verify_locations` raises `TypeError` if with non-`str` arguments. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.load_verify_locations) - self.assertRaises(TypeError, context.load_verify_locations, object()) - self.assertRaises( - TypeError, context.load_verify_locations, object(), object() - ) - self.assertRaises( - TypeError, context.load_verify_locations, None, None, None - ) + with pytest.raises(TypeError): + context.load_verify_locations(object()) + with pytest.raises(TypeError): + context.load_verify_locations(object(), object()) @pytest.mark.skipif( platform == "win32", @@ -1187,9 +1113,8 @@ class ContextTests(TestCase, _LoopbackMixin): ) def test_set_default_verify_paths(self): """ - :py:obj:`Context.set_default_verify_paths` causes the - platform-specific CA certificate locations to be used for - verification purposes. + `Context.set_default_verify_paths` causes the platform-specific CA + certificate locations to be used for verification purposes. """ # Testing this requires a server with a certificate signed by one # of the CAs in the platform CA location. Getting one of those @@ -1212,30 +1137,16 @@ class ContextTests(TestCase, _LoopbackMixin): clientSSL.set_connect_state() clientSSL.do_handshake() clientSSL.send(b"GET / HTTP/1.0\r\n\r\n") - self.assertTrue(clientSSL.recv(1024)) - - def test_set_default_verify_paths_signature(self): - """ - :py:obj:`Context.set_default_verify_paths` takes no arguments and - raises :py:obj:`TypeError` if given any. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_default_verify_paths, None) - self.assertRaises(TypeError, context.set_default_verify_paths, 1) - self.assertRaises(TypeError, context.set_default_verify_paths, "") + assert clientSSL.recv(1024) def test_add_extra_chain_cert_invalid_cert(self): """ - :py:obj:`Context.add_extra_chain_cert` raises :py:obj:`TypeError` if - called with other than one argument or if called with an object which - is not an instance of :py:obj:`X509`. + `Context.add_extra_chain_cert` raises `TypeError` if called with an + object which is not an instance of `X509`. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.add_extra_chain_cert) - self.assertRaises(TypeError, context.add_extra_chain_cert, object()) - self.assertRaises( - TypeError, context.add_extra_chain_cert, object(), object() - ) + with pytest.raises(TypeError): + context.add_extra_chain_cert(object()) def _handshake_test(self, serverContext, clientContext): """ @@ -1251,8 +1162,8 @@ class ContextTests(TestCase, _LoopbackMixin): client.set_connect_state() # Make them talk to each other. - # self._interactInMemory(client, server) - for i in range(3): + # interact_in_memory(client, server) + for _ in range(3): for s in [client, server]: try: s.do_handshake() @@ -1262,7 +1173,7 @@ class ContextTests(TestCase, _LoopbackMixin): def test_set_verify_callback_connection_argument(self): """ The first argument passed to the verify callback is the - :py:class:`Connection` instance for which verification is taking place. + `Connection` instance for which verification is taking place. """ serverContext = Context(TLSv1_METHOD) serverContext.use_privatekey( @@ -1282,15 +1193,15 @@ class ContextTests(TestCase, _LoopbackMixin): clientConnection = Connection(clientContext, None) clientConnection.set_connect_state() - self._handshakeInMemory(clientConnection, serverConnection) + handshake_in_memory(clientConnection, serverConnection) - self.assertIdentical(verify.connection, clientConnection) + assert verify.connection is clientConnection def test_set_verify_callback_exception(self): """ - If the verify callback passed to :py:obj:`Context.set_verify` raises an + If the verify callback passed to `Context.set_verify` raises an exception, verification fails and the exception is propagated to the - caller of :py:obj:`Connection.do_handshake`. + caller of `Connection.do_handshake`. """ serverContext = Context(TLSv1_METHOD) serverContext.use_privatekey( @@ -1307,14 +1218,14 @@ class ContextTests(TestCase, _LoopbackMixin): with pytest.raises(Exception) as exc: self._handshake_test(serverContext, clientContext) - self.assertEqual("silly verify failure", str(exc.value)) + assert "silly verify failure" == str(exc.value) - def test_add_extra_chain_cert(self): + def test_add_extra_chain_cert(self, tmpdir): """ - :py:obj:`Context.add_extra_chain_cert` accepts an :py:obj:`X509` + `Context.add_extra_chain_cert` accepts an `X509` instance to add to the certificate chain. - See :py:obj:`_create_certificate_chain` for the details of the + See `_create_certificate_chain` for the details of the certificate chain tested. The chain is tested by starting a server with scert and connecting @@ -1329,13 +1240,13 @@ class ContextTests(TestCase, _LoopbackMixin): for cert, name in [(cacert, 'ca.pem'), (icert, 'i.pem'), (scert, 's.pem')]: - with open(join(self.tmpdir, name), 'w') as f: + with tmpdir.join(name).open('w') as f: f.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii')) for key, name in [(cakey, 'ca.key'), (ikey, 'i.key'), (skey, 's.key')]: - with open(join(self.tmpdir, name), 'w') as f: + with tmpdir.join(name).open('w') as f: f.write(dump_privatekey(FILETYPE_PEM, key).decode('ascii')) # Create the server context @@ -1349,14 +1260,14 @@ class ContextTests(TestCase, _LoopbackMixin): clientContext = Context(TLSv1_METHOD) clientContext.set_verify( VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) - clientContext.load_verify_locations(join(self.tmpdir, "ca.pem")) + clientContext.load_verify_locations(str(tmpdir.join("ca.pem"))) # Try it out. self._handshake_test(serverContext, clientContext) def _use_certificate_chain_file_test(self, certdir): """ - Verify that :py:obj:`Context.use_certificate_chain_file` reads a + Verify that `Context.use_certificate_chain_file` reads a certificate chain from a specified file. The chain is tested by starting a server with scert and connecting to @@ -1392,98 +1303,86 @@ class ContextTests(TestCase, _LoopbackMixin): self._handshake_test(serverContext, clientContext) - def test_use_certificate_chain_file_bytes(self): + def test_use_certificate_chain_file_bytes(self, tmpfile): """ ``Context.use_certificate_chain_file`` accepts the name of a file (as an instance of ``bytes``) to specify additional certificates to use to construct and verify a trust chain. """ self._use_certificate_chain_file_test( - self.mktemp() + NON_ASCII.encode(getfilesystemencoding()) + tmpfile + NON_ASCII.encode(getfilesystemencoding()) ) - def test_use_certificate_chain_file_unicode(self): + def test_use_certificate_chain_file_unicode(self, tmpfile): """ ``Context.use_certificate_chain_file`` accepts the name of a file (as an instance of ``unicode``) to specify additional certificates to use to construct and verify a trust chain. """ self._use_certificate_chain_file_test( - self.mktemp().decode(getfilesystemencoding()) + NON_ASCII + tmpfile.decode(getfilesystemencoding()) + NON_ASCII ) def test_use_certificate_chain_file_wrong_args(self): """ - :py:obj:`Context.use_certificate_chain_file` raises :py:obj:`TypeError` - if passed zero or more than one argument or when passed a non-byte - string single argument. It also raises :py:obj:`OpenSSL.SSL.Error` - when passed a bad chain file name (for example, the name of a file - which does not exist). + `Context.use_certificate_chain_file` raises `TypeError` if passed a + non-byte string single argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.use_certificate_chain_file) - self.assertRaises( - TypeError, context.use_certificate_chain_file, object() - ) - self.assertRaises( - TypeError, context.use_certificate_chain_file, b"foo", object() - ) - - self.assertRaises( - Error, context.use_certificate_chain_file, self.mktemp() - ) + with pytest.raises(TypeError): + context.use_certificate_chain_file(object()) - def test_get_verify_mode_wrong_args(self): + def test_use_certificate_chain_file_missing_file(self, tmpfile): """ - :py:obj:`Context.get_verify_mode` raises :py:obj:`TypeError` if called - with any arguments. + `Context.use_certificate_chain_file` raises `OpenSSL.SSL.Error` when + passed a bad chain file name (for example, the name of a file which + does not exist). """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.get_verify_mode, None) + with pytest.raises(Error): + context.use_certificate_chain_file(tmpfile) def test_set_verify_mode(self): """ - :py:obj:`Context.get_verify_mode` returns the verify mode flags - previously passed to :py:obj:`Context.set_verify`. + `Context.get_verify_mode` returns the verify mode flags previously + passed to `Context.set_verify`. """ context = Context(TLSv1_METHOD) - self.assertEquals(context.get_verify_mode(), 0) + assert context.get_verify_mode() == 0 context.set_verify( VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None) - self.assertEquals( - context.get_verify_mode(), VERIFY_PEER | VERIFY_CLIENT_ONCE) + assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) @skip_if_py3 def test_set_verify_mode_long(self): """ - On Python 2 :py:obj:`Context.set_verify_mode` accepts values of - type :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context.set_verify_mode` accepts values of type `long` + as well as `int`. """ context = Context(TLSv1_METHOD) - self.assertEquals(context.get_verify_mode(), 0) + assert context.get_verify_mode() == 0 context.set_verify( long(VERIFY_PEER | VERIFY_CLIENT_ONCE), lambda *args: None ) # pragma: nocover - self.assertEquals( - context.get_verify_mode(), VERIFY_PEER | VERIFY_CLIENT_ONCE) + assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) def test_load_tmp_dh_wrong_args(self): """ - :py:obj:`Context.load_tmp_dh` raises :py:obj:`TypeError` if called with - the wrong number of arguments or with a non-:py:obj:`str` argument. + `Context.load_tmp_dh` raises `TypeError` if called with a + non-`str` argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.load_tmp_dh) - self.assertRaises(TypeError, context.load_tmp_dh, "foo", None) - self.assertRaises(TypeError, context.load_tmp_dh, object()) + with pytest.raises(TypeError): + context.load_tmp_dh(object()) def test_load_tmp_dh_missing_file(self): """ - :py:obj:`Context.load_tmp_dh` raises :py:obj:`OpenSSL.SSL.Error` if the + `Context.load_tmp_dh` raises `OpenSSL.SSL.Error` if the specified file does not exist. """ context = Context(TLSv1_METHOD) - self.assertRaises(Error, context.load_tmp_dh, b"hello") + with pytest.raises(Error): + context.load_tmp_dh(b"hello") def _load_tmp_dh_test(self, dhfilename): """ @@ -1497,28 +1396,28 @@ class ContextTests(TestCase, _LoopbackMixin): context.load_tmp_dh(dhfilename) # XXX What should I assert here? -exarkun - def test_load_tmp_dh_bytes(self): + def test_load_tmp_dh_bytes(self, tmpfile): """ - :py:obj:`Context.load_tmp_dh` loads Diffie-Hellman parameters from the + `Context.load_tmp_dh` loads Diffie-Hellman parameters from the specified file (given as ``bytes``). """ self._load_tmp_dh_test( - self.mktemp() + NON_ASCII.encode(getfilesystemencoding()), + tmpfile + NON_ASCII.encode(getfilesystemencoding()), ) - def test_load_tmp_dh_unicode(self): + def test_load_tmp_dh_unicode(self, tmpfile): """ - :py:obj:`Context.load_tmp_dh` loads Diffie-Hellman parameters from the + `Context.load_tmp_dh` loads Diffie-Hellman parameters from the specified file (given as ``unicode``). """ self._load_tmp_dh_test( - self.mktemp().decode(getfilesystemencoding()) + NON_ASCII, + tmpfile.decode(getfilesystemencoding()) + NON_ASCII, ) def test_set_tmp_ecdh(self): """ - :py:obj:`Context.set_tmp_ecdh` sets the elliptic curve for - Diffie-Hellman to the specified curve. + `Context.set_tmp_ecdh` sets the elliptic curve for Diffie-Hellman to + the specified curve. """ context = Context(TLSv1_METHOD) for curve in get_elliptic_curves(): @@ -1533,52 +1432,42 @@ class ContextTests(TestCase, _LoopbackMixin): def test_set_session_cache_mode_wrong_args(self): """ - :py:obj:`Context.set_session_cache_mode` raises :py:obj:`TypeError` if + `Context.set_session_cache_mode` raises `TypeError` if called with + a non-integer argument. called with other than one integer argument. """ context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_session_cache_mode) - self.assertRaises(TypeError, context.set_session_cache_mode, object()) - - def test_get_session_cache_mode_wrong_args(self): - """ - :py:obj:`Context.get_session_cache_mode` raises :py:obj:`TypeError` if - called with any arguments. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.get_session_cache_mode, 1) + with pytest.raises(TypeError): + context.set_session_cache_mode(object()) def test_session_cache_mode(self): """ - :py:obj:`Context.set_session_cache_mode` specifies how sessions are - cached. The setting can be retrieved via - :py:obj:`Context.get_session_cache_mode`. + `Context.set_session_cache_mode` specifies how sessions are cached. + The setting can be retrieved via `Context.get_session_cache_mode`. """ context = Context(TLSv1_METHOD) context.set_session_cache_mode(SESS_CACHE_OFF) off = context.set_session_cache_mode(SESS_CACHE_BOTH) - self.assertEqual(SESS_CACHE_OFF, off) - self.assertEqual(SESS_CACHE_BOTH, context.get_session_cache_mode()) + assert SESS_CACHE_OFF == off + assert SESS_CACHE_BOTH == context.get_session_cache_mode() @skip_if_py3 def test_session_cache_mode_long(self): """ - On Python 2 :py:obj:`Context.set_session_cache_mode` accepts values - of type :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Context.set_session_cache_mode` accepts values + of type `long` as well as `int`. """ context = Context(TLSv1_METHOD) context.set_session_cache_mode(long(SESS_CACHE_BOTH)) - self.assertEqual( - SESS_CACHE_BOTH, context.get_session_cache_mode()) + assert SESS_CACHE_BOTH == context.get_session_cache_mode() def test_get_cert_store(self): """ - :py:obj:`Context.get_cert_store` returns a :py:obj:`X509Store` - instance. + `Context.get_cert_store` returns a `X509Store` instance. """ context = Context(TLSv1_METHOD) store = context.get_cert_store() - self.assertIsInstance(store, X509Store) + assert isinstance(store, X509Store) class TestServerNameCallback(object): -- cgit v1.2.1