diff options
-rw-r--r-- | OpenSSL/SSL.py | 27 | ||||
-rw-r--r-- | OpenSSL/__init__.py | 3 | ||||
-rw-r--r-- | OpenSSL/_util.py | 9 | ||||
-rw-r--r-- | OpenSSL/crypto.py | 22 | ||||
-rw-r--r-- | OpenSSL/test/test_crypto.py | 2 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 37 |
6 files changed, 79 insertions, 21 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 8da25e2..ce2cc29 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -4,6 +4,9 @@ from itertools import count from weakref import WeakValueDictionary from errno import errorcode +from six import text_type as _text_type + + from OpenSSL._util import ( ffi as _ffi, lib as _lib, @@ -356,8 +359,12 @@ class Context(object): :param certfile: The name of the certificate chain file :return: None """ + if isinstance(certfile, _text_type): + # Perhaps sys.getfilesystemencoding() could be better? + certfile = certfile.encode("utf-8") + if not isinstance(certfile, bytes): - raise TypeError("certfile must be a byte string") + raise TypeError("certfile must be bytes or unicode") result = _lib.SSL_CTX_use_certificate_chain_file(self._context, certfile) if not result: @@ -372,8 +379,11 @@ class Context(object): :param filetype: (optional) The encoding of the file, default is PEM :return: None """ + if isinstance(certfile, _text_type): + # Perhaps sys.getfilesystemencoding() could be better? + certfile = certfile.encode("utf-8") if not isinstance(certfile, bytes): - raise TypeError("certfile must be a byte string") + raise TypeError("certfile must be bytes or unicode") if not isinstance(filetype, int): raise TypeError("filetype must be an integer") @@ -431,6 +441,10 @@ class Context(object): :param filetype: (optional) The encoding of the file, default is PEM :return: None """ + if isinstance(keyfile, _text_type): + # Perhaps sys.getfilesystemencoding() could be better? + keyfile = keyfile.encode("utf-8") + if not isinstance(keyfile, bytes): raise TypeError("keyfile must be a byte string") @@ -588,8 +602,11 @@ class Context(object): :param cipher_list: A cipher list, see ciphers(1) :return: None """ + if isinstance(cipher_list, _text_type): + cipher_list = cipher_list.encode("ascii") + if not isinstance(cipher_list, bytes): - raise TypeError("cipher_list must be a byte string") + raise TypeError("cipher_list must be bytes or unicode") result = _lib.SSL_CTX_set_cipher_list(self._context, cipher_list) if not result: @@ -1402,3 +1419,7 @@ class Connection(object): _raise_current_error() ConnectionType = Connection + +# This is similar to the initialization calls at the end of OpenSSL/crypto.py +# but is exercised mostly by the Context initializer. +_lib.SSL_library_init() diff --git a/OpenSSL/__init__.py b/OpenSSL/__init__.py index 396a97f..db96e1f 100644 --- a/OpenSSL/__init__.py +++ b/OpenSSL/__init__.py @@ -10,6 +10,3 @@ from OpenSSL.version import __version__ __all__ = [ 'rand', 'crypto', 'SSL', 'tsafe', '__version__'] - - -# ERR_load_crypto_strings, SSL_library_init, ERR_load_SSL_strings diff --git a/OpenSSL/_util.py b/OpenSSL/_util.py index 7c606b9..baeecc6 100644 --- a/OpenSSL/_util.py +++ b/OpenSSL/_util.py @@ -6,15 +6,18 @@ ffi = binding.ffi lib = binding.lib def exception_from_error_queue(exceptionType): + def text(charp): + return native(ffi.string(charp)) + errors = [] while True: error = lib.ERR_get_error() if error == 0: break errors.append(( - ffi.string(lib.ERR_lib_error_string(error)), - ffi.string(lib.ERR_func_error_string(error)), - ffi.string(lib.ERR_reason_error_string(error)))) + text(lib.ERR_lib_error_string(error)), + text(lib.ERR_func_error_string(error)), + text(lib.ERR_reason_error_string(error)))) raise exceptionType(errors) diff --git a/OpenSSL/crypto.py b/OpenSSL/crypto.py index 0a9e8a2..d0026bd 100644 --- a/OpenSSL/crypto.py +++ b/OpenSSL/crypto.py @@ -1202,6 +1202,9 @@ def load_certificate(type, buffer): :return: The X509 object """ + if isinstance(buffer, _text_type): + buffer = buffer.encode("ascii") + bio = _new_mem_buf(buffer) if type == FILETYPE_PEM: @@ -1988,6 +1991,9 @@ def load_privatekey(type, buffer, passphrase=None): :return: The PKey object """ + if isinstance(buffer, _text_type): + buffer = buffer.encode("ascii") + bio = _new_mem_buf(buffer) helper = _PassphraseHelper(type, passphrase) @@ -2044,6 +2050,9 @@ def load_certificate_request(type, buffer): :param buffer: The buffer the certificate request is stored in :return: The X509Req object """ + if isinstance(buffer, _text_type): + buffer = buffer.encode("ascii") + bio = _new_mem_buf(buffer) if type == FILETYPE_PEM: @@ -2137,6 +2146,9 @@ def load_crl(type, buffer): :return: The PKey object """ + if isinstance(buffer, _text_type): + buffer = buffer.encode("ascii") + bio = _new_mem_buf(buffer) if type == FILETYPE_PEM: @@ -2163,6 +2175,9 @@ def load_pkcs7_data(type, buffer): :param buffer: The buffer with the pkcs7 data. :return: The PKCS7 object """ + if isinstance(buffer, _text_type): + buffer = buffer.encode("ascii") + bio = _new_mem_buf(buffer) if type == FILETYPE_PEM: @@ -2191,6 +2206,9 @@ def load_pkcs12(buffer, passphrase): :param passphrase: (Optional) The password to decrypt the PKCS12 lump :returns: The PKCS12 object """ + if isinstance(buffer, _text_type): + buffer = buffer.encode("ascii") + bio = _new_mem_buf(buffer) p12 = _lib.d2i_PKCS12_bio(bio, _ffi.NULL) @@ -2290,3 +2308,7 @@ else: # OpenSSL library (and is linked against the same one that cryptography is # using)). _lib.OpenSSL_add_all_algorithms() + +# This is similar but exercised mainly by exception_from_error_queue. It calls +# both ERR_load_crypto_strings() and ERR_load_SSL_strings(). +_lib.SSL_load_error_strings() diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py index 51da99b..4e42f70 100644 --- a/OpenSSL/test/test_crypto.py +++ b/OpenSSL/test/test_crypto.py @@ -1941,7 +1941,7 @@ class PKCS12Tests(TestCase): """ passwd = 'whatever' e = self.assertRaises(Error, load_pkcs12, b'fruit loops', passwd) - self.assertEqual( e.args[0][0][0], b'asn1 encoding routines') + self.assertEqual( e.args[0][0][0], 'asn1 encoding routines') self.assertEqual( len(e.args[0][0]), 3) diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 95cb538..e366b47 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -14,6 +14,8 @@ from os.path import join from unittest import main from weakref import ref +from six import u + from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM from OpenSSL.crypto import PKey, X509, X509Extension, X509Store from OpenSSL.crypto import dump_privatekey, load_privatekey @@ -960,11 +962,11 @@ class ContextTests(TestCase, _LoopbackMixin): # Write out the chain file. chainFile = self.mktemp() - fObj = open(chainFile, 'w') + fObj = open(chainFile, 'wb') # Most specific to least general. - fObj.write(dump_certificate(FILETYPE_PEM, scert).decode('ascii')) - fObj.write(dump_certificate(FILETYPE_PEM, icert).decode('ascii')) - fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii')) + fObj.write(dump_certificate(FILETYPE_PEM, scert)) + fObj.write(dump_certificate(FILETYPE_PEM, icert)) + fObj.write(dump_certificate(FILETYPE_PEM, cacert)) fObj.close() serverContext = Context(TLSv1_METHOD) @@ -1057,10 +1059,11 @@ class ContextTests(TestCase, _LoopbackMixin): # XXX What should I assert here? -exarkun - def test_set_cipher_list(self): + def test_set_cipher_list_bytes(self): """ - :py:obj:`Context.set_cipher_list` accepts a :py:obj:`str` naming the ciphers which - connections created with the context object will be able to choose from. + :py:obj:`Context.set_cipher_list` accepts a :py:obj:`bytes` naming the + ciphers which connections created with the context object will be able + to choose from. """ context = Context(TLSv1_METHOD) context.set_cipher_list(b"hello world:EXP-RC4-MD5") @@ -1068,11 +1071,23 @@ class ContextTests(TestCase, _LoopbackMixin): self.assertEquals(conn.get_cipher_list(), ["EXP-RC4-MD5"]) + def test_set_cipher_list_text(self): + """ + :py:obj:`Context.set_cipher_list` accepts a :py:obj:`unicode` naming + the ciphers which connections created with the context object will be + able to choose from. + """ + context = Context(TLSv1_METHOD) + context.set_cipher_list(u("hello world:EXP-RC4-MD5")) + conn = Connection(context, None) + self.assertEquals(conn.get_cipher_list(), ["EXP-RC4-MD5"]) + + def test_set_cipher_list_wrong_args(self): """ - :py:obj:`Context.set_cipher_list` raises :py:obj:`TypeError` when passed - zero arguments or more than one argument or when passed a non-byte - string single argument and raises :py:obj:`OpenSSL.SSL.Error` when + :py:obj:`Context.set_cipher_list` raises :py:obj:`TypeError` when + passed zero arguments or more than one argument or when passed a + non-string single argument and raises :py:obj:`OpenSSL.SSL.Error` when passed an incorrect cipher list string. """ context = Context(TLSv1_METHOD) @@ -1080,7 +1095,7 @@ class ContextTests(TestCase, _LoopbackMixin): self.assertRaises(TypeError, context.set_cipher_list, object()) self.assertRaises(TypeError, context.set_cipher_list, b"EXP-RC4-MD5", object()) - self.assertRaises(Error, context.set_cipher_list, b"imaginary-cipher") + self.assertRaises(Error, context.set_cipher_list, "imaginary-cipher") def test_set_session_cache_mode_wrong_args(self): |