diff options
author | Alex Gaynor <alex.gaynor@gmail.com> | 2020-07-23 20:40:46 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-23 19:40:46 -0500 |
commit | 037371861693f26297320dcd5fd8c221b6d8df26 (patch) | |
tree | ab18ca46617b0036e137cd6a154726acbab36bdf | |
parent | 4ca4fb9e8ed3c45f09efab8269e4078d40f39d9b (diff) | |
download | pyopenssl-git-037371861693f26297320dcd5fd8c221b6d8df26.tar.gz |
Paint it Black by the Rolling Stones (#920)
-rw-r--r-- | .travis.yml | 2 | ||||
-rw-r--r-- | doc/conf.py | 110 | ||||
-rw-r--r-- | leakcheck/context-info-callback.py | 18 | ||||
-rw-r--r-- | leakcheck/context-passphrase-callback.py | 10 | ||||
-rw-r--r-- | leakcheck/context-verify-callback.py | 29 | ||||
-rw-r--r-- | leakcheck/crypto.py | 54 | ||||
-rw-r--r-- | leakcheck/thread-crash.py | 18 | ||||
-rw-r--r-- | leakcheck/thread-key-gen.py | 3 | ||||
-rw-r--r-- | pyproject.toml | 4 | ||||
-rwxr-xr-x | setup.py | 75 | ||||
-rw-r--r-- | src/OpenSSL/SSL.py | 357 | ||||
-rw-r--r-- | src/OpenSSL/__init__.py | 24 | ||||
-rw-r--r-- | src/OpenSSL/_util.py | 20 | ||||
-rw-r--r-- | src/OpenSSL/crypto.py | 248 | ||||
-rw-r--r-- | src/OpenSSL/version.py | 10 | ||||
-rw-r--r-- | tests/conftest.py | 2 | ||||
-rw-r--r-- | tests/memdbg.py | 18 | ||||
-rw-r--r-- | tests/test_crypto.py | 722 | ||||
-rw-r--r-- | tests/test_rand.py | 8 | ||||
-rw-r--r-- | tests/test_ssl.py | 735 | ||||
-rw-r--r-- | tests/test_util.py | 1 | ||||
-rw-r--r-- | tests/util.py | 10 | ||||
-rw-r--r-- | tox.ini | 11 |
23 files changed, 1466 insertions, 1023 deletions
diff --git a/.travis.yml b/.travis.yml index daed6e5..48a74cc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -83,7 +83,7 @@ jobs: - python: "2.7" env: TOXENV=pypi-readme - - python: "2.7" + - python: "3.8" env: TOXENV=flake8 - python: "2.7" diff --git a/doc/conf.py b/doc/conf.py index 3940dd2..cb699c8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -32,8 +32,9 @@ def read_file(*parts): def find_version(*file_paths): version_file = read_file(*file_paths) - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - version_file, re.M) + version_match = re.search( + r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M + ) if version_match: return version_match.group(1) raise RuntimeError("Unable to find version string.") @@ -45,34 +46,34 @@ sys.path.insert(0, os.path.abspath(os.path.join(DOC_DIR, ".."))) # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '1.0' +needs_sphinx = "1.0" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ "sphinx.ext.autodoc", - 'sphinx.ext.intersphinx', + "sphinx.ext.intersphinx", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'pyOpenSSL' +project = u"pyOpenSSL" authors = u"The pyOpenSSL developers" copyright = u"2001 " + authors @@ -87,73 +88,74 @@ release = version # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" if not on_rtd: # only import and set the theme if we're building docs locally import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' + + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # "<project> v<release> documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -162,96 +164,92 @@ if not on_rtd: # only import and set the theme if we're building docs locally # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a <link> tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'pyOpenSSLdoc' +htmlhelp_basename = "pyOpenSSLdoc" # -- Options for LaTeX output -------------------------------------------------- # The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' +# latex_paper_size = 'letter' # The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' +# latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'pyOpenSSL.tex', u'pyOpenSSL Documentation', - authors, 'manual'), + ("index", "pyOpenSSL.tex", u"pyOpenSSL Documentation", authors, "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Additional stuff for the LaTeX preamble. -#latex_preamble = '' +# latex_preamble = '' # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'pyopenssl', u'pyOpenSSL Documentation', - [authors], 1) -] +man_pages = [("index", "pyopenssl", u"pyOpenSSL Documentation", [authors], 1)] intersphinx_mapping = { "https://docs.python.org/3": None, diff --git a/leakcheck/context-info-callback.py b/leakcheck/context-info-callback.py index 6a3925c..b99adc2 100644 --- a/leakcheck/context-info-callback.py +++ b/leakcheck/context-info-callback.py @@ -29,7 +29,8 @@ cleartextPrivateKeyPEM = ( "0QwrX8nxFeTytr8pFGezj4a4KVCdb2B3CL+p3f70K7RIo9d/7b6frJI6ZL/LHQf2\n" "UP4pKRDkgKsVDx7MELECQGm072/Z7vmb03h/uE95IYJOgY4nfmYs0QKA9Is18wUz\n" "DpjfE33p0Ha6GO1VZRIQoqE24F8o5oimy3BEjryFuw4=\n" - "-----END RSA PRIVATE KEY-----\n") + "-----END RSA PRIVATE KEY-----\n" +) cleartextCertificatePEM = ( @@ -48,24 +49,31 @@ cleartextCertificatePEM = ( "q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5\n" "RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q\n" "UBj849/xpszEM7BhwKE0GiQ=\n" - "-----END CERTIFICATE-----\n") + "-----END CERTIFICATE-----\n" +) count = count() + + def go(): port = socket() - port.bind(('', 0)) + port.bind(("", 0)) port.listen(1) called = [] + def info(conn, where, ret): print count.next() called.append(None) + context = Context(TLSv1_METHOD) context.set_info_callback(info) context.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) context.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) while 1: client = socket() diff --git a/leakcheck/context-passphrase-callback.py b/leakcheck/context-passphrase-callback.py index ba71655..141ac8d 100644 --- a/leakcheck/context-passphrase-callback.py +++ b/leakcheck/context-passphrase-callback.py @@ -15,17 +15,23 @@ from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey k = PKey() k.generate_key(TYPE_RSA, 128) -file('pkey.pem', 'w').write(dump_privatekey(FILETYPE_PEM, k, "blowfish", "foobar")) +file("pkey.pem", "w").write( + dump_privatekey(FILETYPE_PEM, k, "blowfish", "foobar") +) count = count() + + def go(): def cb(a, b, c): print count.next() return "foobar" + c = Context(TLSv1_METHOD) c.set_passwd_cb(cb) while 1: - c.use_privatekey_file('pkey.pem') + c.use_privatekey_file("pkey.pem") + threads = [Thread(target=go, args=()) for i in xrange(2)] for th in threads: diff --git a/leakcheck/context-verify-callback.py b/leakcheck/context-verify-callback.py index 0ae586b..b9ce1d5 100644 --- a/leakcheck/context-verify-callback.py +++ b/leakcheck/context-verify-callback.py @@ -11,7 +11,13 @@ from itertools import count from threading import Thread from socket import socket -from OpenSSL.SSL import Context, TLSv1_METHOD, VERIFY_PEER, Connection, WantReadError +from OpenSSL.SSL import ( + Context, + TLSv1_METHOD, + VERIFY_PEER, + Connection, + WantReadError, +) from OpenSSL.crypto import FILETYPE_PEM, load_certificate, load_privatekey cleartextPrivateKeyPEM = ( @@ -29,7 +35,8 @@ cleartextPrivateKeyPEM = ( "0QwrX8nxFeTytr8pFGezj4a4KVCdb2B3CL+p3f70K7RIo9d/7b6frJI6ZL/LHQf2\n" "UP4pKRDkgKsVDx7MELECQGm072/Z7vmb03h/uE95IYJOgY4nfmYs0QKA9Is18wUz\n" "DpjfE33p0Ha6GO1VZRIQoqE24F8o5oimy3BEjryFuw4=\n" - "-----END RSA PRIVATE KEY-----\n") + "-----END RSA PRIVATE KEY-----\n" +) cleartextCertificatePEM = ( @@ -48,25 +55,32 @@ cleartextCertificatePEM = ( "q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5\n" "RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q\n" "UBj849/xpszEM7BhwKE0GiQ=\n" - "-----END CERTIFICATE-----\n") + "-----END CERTIFICATE-----\n" +) count = count() + + def go(): port = socket() - port.bind(('', 0)) + port.bind(("", 0)) port.listen(1) called = [] + def info(*args): print count.next() called.append(None) return 1 + context = Context(TLSv1_METHOD) context.set_verify(VERIFY_PEER, info) context.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) context.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) while 1: client = socket() @@ -86,7 +100,7 @@ def go(): while not called: for ssl in clientSSL, serverSSL: try: - ssl.send('foo') + ssl.send("foo") except WantReadError, e: pass @@ -96,4 +110,3 @@ for th in threads: th.start() for th in threads: th.join() - diff --git a/leakcheck/crypto.py b/leakcheck/crypto.py index ca79b7c..c995610 100644 --- a/leakcheck/crypto.py +++ b/leakcheck/crypto.py @@ -4,23 +4,31 @@ import sys from OpenSSL.crypto import ( - FILETYPE_PEM, TYPE_DSA, Error, PKey, X509, load_privatekey, CRL, Revoked, - get_elliptic_curves, _X509_REVOKED_dup) + FILETYPE_PEM, + TYPE_DSA, + Error, + PKey, + X509, + load_privatekey, + CRL, + Revoked, + get_elliptic_curves, + _X509_REVOKED_dup, +) from OpenSSL._util import lib as _lib - class BaseChecker(object): def __init__(self, iterations): self.iterations = iterations - class Checker_X509_get_pubkey(BaseChecker): """ Leak checks for L{X509.get_pubkey}. """ + def check_exception(self): """ Call the method repeatedly such that it will raise an exception. @@ -32,7 +40,6 @@ class Checker_X509_get_pubkey(BaseChecker): except Error: pass - def check_success(self): """ Call the method repeatedly such that it will return a PKey object. @@ -48,11 +55,11 @@ class Checker_X509_get_pubkey(BaseChecker): cert.get_pubkey() - class Checker_load_privatekey(BaseChecker): """ Leak checks for :py:obj:`load_privatekey`. """ + ENCRYPTED_PEM = """\ -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED @@ -67,14 +74,15 @@ FCB5K3c2kkTv2KjcCAimjxkE+SBKfHg35W0wB0AWkXpVFO5W/TbHg4tqtkpt/KMn /MPnSxvYr/vEqYMfW4Y83c45iqK0Cyr2pwY60lcn8Kk= -----END RSA PRIVATE KEY----- """ + def check_load_privatekey_callback(self): """ Call the function with an encrypted PEM and a passphrase callback. """ for i in xrange(self.iterations * 10): load_privatekey( - FILETYPE_PEM, self.ENCRYPTED_PEM, lambda *args: "hello, secret") - + FILETYPE_PEM, self.ENCRYPTED_PEM, lambda *args: "hello, secret" + ) def check_load_privatekey_callback_incorrect(self): """ @@ -84,12 +92,13 @@ FCB5K3c2kkTv2KjcCAimjxkE+SBKfHg35W0wB0AWkXpVFO5W/TbHg4tqtkpt/KMn for i in xrange(self.iterations * 10): try: load_privatekey( - FILETYPE_PEM, self.ENCRYPTED_PEM, - lambda *args: "hello, public") + FILETYPE_PEM, + self.ENCRYPTED_PEM, + lambda *args: "hello, public", + ) except Error: pass - def check_load_privatekey_callback_wrong_type(self): """ Call the function with an encrypted PEM and a passphrase callback which @@ -98,17 +107,17 @@ FCB5K3c2kkTv2KjcCAimjxkE+SBKfHg35W0wB0AWkXpVFO5W/TbHg4tqtkpt/KMn for i in xrange(self.iterations * 10): try: load_privatekey( - FILETYPE_PEM, self.ENCRYPTED_PEM, - lambda *args: {}) + FILETYPE_PEM, self.ENCRYPTED_PEM, lambda *args: {} + ) except ValueError: pass - class Checker_CRL(BaseChecker): """ Leak checks for L{CRL.add_revoked} and L{CRL.get_revoked}. """ + def check_add_revoked(self): """ Call the add_revoked method repeatedly on an empty CRL. @@ -116,7 +125,6 @@ class Checker_CRL(BaseChecker): for i in xrange(self.iterations * 200): CRL().add_revoked(Revoked()) - def check_get_revoked(self): """ Create a CRL object with 100 Revoked objects, then call the @@ -129,11 +137,11 @@ class Checker_CRL(BaseChecker): crl.get_revoked() - class Checker_X509_REVOKED_dup(BaseChecker): """ Leak checks for :py:obj:`_X509_REVOKED_dup`. """ + def check_X509_REVOKED_dup(self): """ Copy an empty Revoked object repeatedly. The copy is not garbage @@ -144,11 +152,11 @@ class Checker_X509_REVOKED_dup(BaseChecker): _lib.X509_REVOKED_free(revoked_copy) - class Checker_EllipticCurve(BaseChecker): """ Leak checks for :py:obj:`_EllipticCurve`. """ + def check_to_EC_KEY(self): """ Repeatedly create an EC_KEY* from an :py:obj:`_EllipticCurve`. The @@ -162,22 +170,22 @@ class Checker_EllipticCurve(BaseChecker): def vmsize(): - return [x for x in file('/proc/self/status').readlines() if 'VmSize' in x] + return [x for x in file("/proc/self/status").readlines() if "VmSize" in x] -def main(iterations='1000'): +def main(iterations="1000"): iterations = int(iterations) for klass in globals(): - if klass.startswith('Checker_'): + if klass.startswith("Checker_"): klass = globals()[klass] print klass checker = klass(iterations) for meth in dir(checker): - if meth.startswith('check_'): - print '\t', meth, vmsize(), '...', + if meth.startswith("check_"): + print "\t", meth, vmsize(), "...", getattr(checker, meth)() print vmsize() -if __name__ == '__main__': +if __name__ == "__main__": main(*sys.argv[1:]) diff --git a/leakcheck/thread-crash.py b/leakcheck/thread-crash.py index a1ebbdd..bd9426d 100644 --- a/leakcheck/thread-crash.py +++ b/leakcheck/thread-crash.py @@ -14,23 +14,24 @@ from threading import Thread from OpenSSL.SSL import Connection, Context, TLSv1_METHOD + def send(conn): while 1: for i in xrange(1024 * 32): - conn.send('x') - print 'Sent 32KB on', hex(id(conn)) + conn.send("x") + print "Sent 32KB on", hex(id(conn)) def recv(conn): while 1: for i in xrange(1024 * 64): conn.recv(1) - print 'Received 64KB on', hex(id(conn)) + print "Received 64KB on", hex(id(conn)) def main(): port = socket() - port.bind(('', 0)) + port.bind(("", 0)) port.listen(5) client = socket() @@ -41,15 +42,15 @@ def main(): server = port.accept()[0] clientCtx = Context(TLSv1_METHOD) - clientCtx.set_cipher_list('ALL:ADH') - clientCtx.load_tmp_dh('dhparam.pem') + clientCtx.set_cipher_list("ALL:ADH") + clientCtx.load_tmp_dh("dhparam.pem") sslClient = Connection(clientCtx, client) sslClient.set_connect_state() serverCtx = Context(TLSv1_METHOD) - serverCtx.set_cipher_list('ALL:ADH') - serverCtx.load_tmp_dh('dhparam.pem') + serverCtx.set_cipher_list("ALL:ADH") + serverCtx.load_tmp_dh("dhparam.pem") sslServer = Connection(serverCtx, server) sslServer.set_accept_state() @@ -68,4 +69,5 @@ def main(): t3.join() t4.join() + main() diff --git a/leakcheck/thread-key-gen.py b/leakcheck/thread-key-gen.py index 62e1a58..346ad7b 100644 --- a/leakcheck/thread-key-gen.py +++ b/leakcheck/thread-key-gen.py @@ -9,6 +9,7 @@ from threading import Thread from OpenSSL.crypto import TYPE_RSA, TYPE_DSA, PKey + def generate_rsa(): keys = [] for i in range(100): @@ -16,6 +17,7 @@ def generate_rsa(): key.generate_key(TYPE_RSA, 1024) keys.append(key) + def generate_dsa(): keys = [] for i in range(100): @@ -35,4 +37,5 @@ def main(): for t in threads: t.start() + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ff6e2bb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.black] +line-length = 79 +target-version = ["py27"] + @@ -36,8 +36,7 @@ def find_meta(meta): Extract __*meta*__ from META_FILE. """ meta_match = re.search( - r"^__{meta}__ = ['\"]([^'\"]*)['\"]".format(meta=meta), - META_FILE, re.M + r"^__{meta}__ = ['\"]([^'\"]*)['\"]".format(meta=meta), META_FILE, re.M ) if meta_match: return meta_match.group(1) @@ -46,13 +45,17 @@ def find_meta(meta): URI = find_meta("uri") LONG = ( - read_file("README.rst") + "\n\n" + - "Release Information\n" + - "===================\n\n" + - re.search(r"(\d{2}.\d.\d \(.*?\)\n.*?)\n\n\n----\n", - read_file("CHANGELOG.rst"), re.S).group(1) + - "\n\n`Full changelog " + - "<{uri}en/stable/changelog.html>`_.\n\n" + read_file("README.rst") + + "\n\n" + + "Release Information\n" + + "===================\n\n" + + re.search( + r"(\d{2}.\d.\d \(.*?\)\n.*?)\n\n\n----\n", + read_file("CHANGELOG.rst"), + re.S, + ).group(1) + + "\n\n`Full changelog " + + "<{uri}en/stable/changelog.html>`_.\n\n" ).format(uri=URI) @@ -67,45 +70,35 @@ if __name__ == "__main__": url=URI, license=find_meta("license"), classifiers=[ - 'Development Status :: 6 - Mature', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: POSIX', - - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Security :: Cryptography', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: System :: Networking', + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Security :: Cryptography", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: System :: Networking", ], - python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', - + python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*", packages=find_packages(where="src"), package_dir={"": "src"}, install_requires=[ # Fix cryptographyMinimum in tox.ini when changing this! "cryptography>=2.8", - "six>=1.5.2" + "six>=1.5.2", ], extras_require={ - "test": [ - "flaky", - "pretend", - "pytest>=3.0.1", - ], - "docs": [ - "sphinx", - "sphinx_rtd_theme", - ] + "test": ["flaky", "pretend", "pytest>=3.0.1"], + "docs": ["sphinx", "sphinx_rtd_theme"], }, ) diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 25308f1..b4b308f 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -23,100 +23,108 @@ from OpenSSL._util import ( ) from OpenSSL.crypto import ( - FILETYPE_PEM, _PassphraseHelper, PKey, X509Name, X509, X509Store) + FILETYPE_PEM, + _PassphraseHelper, + PKey, + X509Name, + X509, + X509Store, +) __all__ = [ - 'OPENSSL_VERSION_NUMBER', - 'SSLEAY_VERSION', - 'SSLEAY_CFLAGS', - 'SSLEAY_PLATFORM', - 'SSLEAY_DIR', - 'SSLEAY_BUILT_ON', - 'SENT_SHUTDOWN', - 'RECEIVED_SHUTDOWN', - 'SSLv2_METHOD', - 'SSLv3_METHOD', - 'SSLv23_METHOD', - 'TLSv1_METHOD', - 'TLSv1_1_METHOD', - 'TLSv1_2_METHOD', - 'OP_NO_SSLv2', - 'OP_NO_SSLv3', - 'OP_NO_TLSv1', - 'OP_NO_TLSv1_1', - 'OP_NO_TLSv1_2', - 'OP_NO_TLSv1_3', - 'MODE_RELEASE_BUFFERS', - 'OP_SINGLE_DH_USE', - 'OP_SINGLE_ECDH_USE', - 'OP_EPHEMERAL_RSA', - 'OP_MICROSOFT_SESS_ID_BUG', - 'OP_NETSCAPE_CHALLENGE_BUG', - 'OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG', - 'OP_SSLREF2_REUSE_CERT_TYPE_BUG', - 'OP_MICROSOFT_BIG_SSLV3_BUFFER', - 'OP_MSIE_SSLV2_RSA_PADDING', - 'OP_SSLEAY_080_CLIENT_DH_BUG', - 'OP_TLS_D5_BUG', - 'OP_TLS_BLOCK_PADDING_BUG', - 'OP_DONT_INSERT_EMPTY_FRAGMENTS', - 'OP_CIPHER_SERVER_PREFERENCE', - 'OP_TLS_ROLLBACK_BUG', - 'OP_PKCS1_CHECK_1', - 'OP_PKCS1_CHECK_2', - 'OP_NETSCAPE_CA_DN_BUG', - 'OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG', - 'OP_NO_COMPRESSION', - 'OP_NO_QUERY_MTU', - 'OP_COOKIE_EXCHANGE', - 'OP_NO_TICKET', - 'OP_ALL', - 'VERIFY_PEER', - 'VERIFY_FAIL_IF_NO_PEER_CERT', - 'VERIFY_CLIENT_ONCE', - 'VERIFY_NONE', - 'SESS_CACHE_OFF', - 'SESS_CACHE_CLIENT', - 'SESS_CACHE_SERVER', - 'SESS_CACHE_BOTH', - 'SESS_CACHE_NO_AUTO_CLEAR', - 'SESS_CACHE_NO_INTERNAL_LOOKUP', - 'SESS_CACHE_NO_INTERNAL_STORE', - 'SESS_CACHE_NO_INTERNAL', - 'SSL_ST_CONNECT', - 'SSL_ST_ACCEPT', - 'SSL_ST_MASK', - 'SSL_CB_LOOP', - 'SSL_CB_EXIT', - 'SSL_CB_READ', - 'SSL_CB_WRITE', - 'SSL_CB_ALERT', - 'SSL_CB_READ_ALERT', - 'SSL_CB_WRITE_ALERT', - 'SSL_CB_ACCEPT_LOOP', - 'SSL_CB_ACCEPT_EXIT', - 'SSL_CB_CONNECT_LOOP', - 'SSL_CB_CONNECT_EXIT', - 'SSL_CB_HANDSHAKE_START', - 'SSL_CB_HANDSHAKE_DONE', - 'Error', - 'WantReadError', - 'WantWriteError', - 'WantX509LookupError', - 'ZeroReturnError', - 'SysCallError', - 'SSLeay_version', - 'Session', - 'Context', - 'Connection' + "OPENSSL_VERSION_NUMBER", + "SSLEAY_VERSION", + "SSLEAY_CFLAGS", + "SSLEAY_PLATFORM", + "SSLEAY_DIR", + "SSLEAY_BUILT_ON", + "SENT_SHUTDOWN", + "RECEIVED_SHUTDOWN", + "SSLv2_METHOD", + "SSLv3_METHOD", + "SSLv23_METHOD", + "TLSv1_METHOD", + "TLSv1_1_METHOD", + "TLSv1_2_METHOD", + "OP_NO_SSLv2", + "OP_NO_SSLv3", + "OP_NO_TLSv1", + "OP_NO_TLSv1_1", + "OP_NO_TLSv1_2", + "OP_NO_TLSv1_3", + "MODE_RELEASE_BUFFERS", + "OP_SINGLE_DH_USE", + "OP_SINGLE_ECDH_USE", + "OP_EPHEMERAL_RSA", + "OP_MICROSOFT_SESS_ID_BUG", + "OP_NETSCAPE_CHALLENGE_BUG", + "OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG", + "OP_SSLREF2_REUSE_CERT_TYPE_BUG", + "OP_MICROSOFT_BIG_SSLV3_BUFFER", + "OP_MSIE_SSLV2_RSA_PADDING", + "OP_SSLEAY_080_CLIENT_DH_BUG", + "OP_TLS_D5_BUG", + "OP_TLS_BLOCK_PADDING_BUG", + "OP_DONT_INSERT_EMPTY_FRAGMENTS", + "OP_CIPHER_SERVER_PREFERENCE", + "OP_TLS_ROLLBACK_BUG", + "OP_PKCS1_CHECK_1", + "OP_PKCS1_CHECK_2", + "OP_NETSCAPE_CA_DN_BUG", + "OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG", + "OP_NO_COMPRESSION", + "OP_NO_QUERY_MTU", + "OP_COOKIE_EXCHANGE", + "OP_NO_TICKET", + "OP_ALL", + "VERIFY_PEER", + "VERIFY_FAIL_IF_NO_PEER_CERT", + "VERIFY_CLIENT_ONCE", + "VERIFY_NONE", + "SESS_CACHE_OFF", + "SESS_CACHE_CLIENT", + "SESS_CACHE_SERVER", + "SESS_CACHE_BOTH", + "SESS_CACHE_NO_AUTO_CLEAR", + "SESS_CACHE_NO_INTERNAL_LOOKUP", + "SESS_CACHE_NO_INTERNAL_STORE", + "SESS_CACHE_NO_INTERNAL", + "SSL_ST_CONNECT", + "SSL_ST_ACCEPT", + "SSL_ST_MASK", + "SSL_CB_LOOP", + "SSL_CB_EXIT", + "SSL_CB_READ", + "SSL_CB_WRITE", + "SSL_CB_ALERT", + "SSL_CB_READ_ALERT", + "SSL_CB_WRITE_ALERT", + "SSL_CB_ACCEPT_LOOP", + "SSL_CB_ACCEPT_EXIT", + "SSL_CB_CONNECT_LOOP", + "SSL_CB_CONNECT_EXIT", + "SSL_CB_HANDSHAKE_START", + "SSL_CB_HANDSHAKE_DONE", + "Error", + "WantReadError", + "WantWriteError", + "WantX509LookupError", + "ZeroReturnError", + "SysCallError", + "SSLeay_version", + "Session", + "Context", + "Connection", ] try: _buffer = buffer except NameError: + class _buffer(object): pass + OPENSSL_VERSION_NUMBER = _lib.OPENSSL_VERSION_NUMBER SSLEAY_VERSION = _lib.SSLEAY_VERSION SSLEAY_CFLAGS = _lib.SSLEAY_CFLAGS @@ -199,12 +207,9 @@ if _lib.Cryptography_HAS_SSL_ST: SSL_ST_BEFORE = _lib.SSL_ST_BEFORE SSL_ST_OK = _lib.SSL_ST_OK SSL_ST_RENEGOTIATE = _lib.SSL_ST_RENEGOTIATE - __all__.extend([ - 'SSL_ST_INIT', - 'SSL_ST_BEFORE', - 'SSL_ST_OK', - 'SSL_ST_RENEGOTIATE', - ]) + __all__.extend( + ["SSL_ST_INIT", "SSL_ST_BEFORE", "SSL_ST_OK", "SSL_ST_RENEGOTIATE"] + ) SSL_CB_LOOP = _lib.SSL_CB_LOOP SSL_CB_EXIT = _lib.SSL_CB_EXIT @@ -333,7 +338,8 @@ class _VerifyHelper(_CallbackExceptionHelper): return 0 self.callback = _ffi.callback( - "int (*)(int, X509_STORE_CTX *)", wrapper) + "int (*)(int, X509_STORE_CTX *)", wrapper + ) class _NpnAdvertiseHelper(_CallbackExceptionHelper): @@ -352,7 +358,7 @@ class _NpnAdvertiseHelper(_CallbackExceptionHelper): # Join the protocols into a Python bytestring, length-prefixing # each element. - protostr = b''.join( + protostr = b"".join( chain.from_iterable((int2byte(len(p)), p) for p in protos) ) @@ -373,7 +379,7 @@ class _NpnAdvertiseHelper(_CallbackExceptionHelper): self.callback = _ffi.callback( "int (*)(SSL *, const unsigned char **, unsigned int *, void *)", - wrapper + wrapper, ) @@ -397,9 +403,9 @@ class _NpnSelectHelper(_CallbackExceptionHelper): protolist = [] while instr: length = indexbytes(instr, 0) - proto = instr[1:length + 1] + proto = instr[1 : length + 1] protolist.append(proto) - instr = instr[length + 1:] + instr = instr[length + 1 :] # Call the callback outstr = callback(conn, protolist) @@ -420,9 +426,11 @@ class _NpnSelectHelper(_CallbackExceptionHelper): return 2 # SSL_TLSEXT_ERR_ALERT_FATAL self.callback = _ffi.callback( - ("int (*)(SSL *, unsigned char **, unsigned char *, " - "const unsigned char *, unsigned int, void *)"), - wrapper + ( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)" + ), + wrapper, ) @@ -449,15 +457,15 @@ class _ALPNSelectHelper(_CallbackExceptionHelper): protolist = [] while instr: encoded_len = indexbytes(instr, 0) - proto = instr[1:encoded_len + 1] + proto = instr[1 : encoded_len + 1] protolist.append(proto) - instr = instr[encoded_len + 1:] + instr = instr[encoded_len + 1 :] # Call the callback outbytes = callback(conn, protolist) any_accepted = True if outbytes is NO_OVERLAPPING_PROTOCOLS: - outbytes = b'' + outbytes = b"" any_accepted = False elif not isinstance(outbytes, bytes): raise TypeError( @@ -482,9 +490,11 @@ class _ALPNSelectHelper(_CallbackExceptionHelper): return _lib.SSL_TLSEXT_ERR_ALERT_FATAL self.callback = _ffi.callback( - ("int (*)(SSL *, unsigned char **, unsigned char *, " - "const unsigned char *, unsigned int, void *)"), - wrapper + ( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)" + ), + wrapper, ) @@ -596,7 +606,7 @@ class _OCSPClientCallbackHelper(_CallbackExceptionHelper): ocsp_len = _lib.SSL_get_tlsext_status_ocsp_resp(ssl, ocsp_ptr) if ocsp_len < 0: # No OCSP data. - ocsp_data = b'' + ocsp_data = b"" else: # Copy the OCSP data, then pass it to the callback. ocsp_data = _ffi.buffer(ocsp_ptr[0], ocsp_len)[:] @@ -628,7 +638,8 @@ def _asFileDescriptor(obj): raise TypeError("argument must be an int, or have a fileno() method.") elif fd < 0: raise ValueError( - "file descriptor cannot be a negative integer (%i)" % (fd,)) + "file descriptor cannot be a negative integer (%i)" % (fd,) + ) return fd @@ -643,8 +654,11 @@ def SSLeay_version(type): def _warn_npn(): - warnings.warn("NPN is deprecated. Protocols should switch to using ALPN.", - DeprecationWarning, stacklevel=3) + warnings.warn( + "NPN is deprecated. Protocols should switch to using ALPN.", + DeprecationWarning, + stacklevel=3, + ) def _make_requires(flag, error): @@ -657,11 +671,14 @@ def _make_requires(flag, error): ``Cryptography_HAS_NEXTPROTONEG``. :param error: The string to be used in the exception if the flag is false. """ + def _requires_decorator(func): if not flag: + @wraps(func) def explode(*args, **kwargs): raise NotImplementedError(error) + return explode else: return func @@ -687,6 +704,7 @@ class Session(object): .. versionadded:: 0.14 """ + pass @@ -698,6 +716,7 @@ class Context(object): :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD. """ + _methods = { SSLv2_METHOD: "SSLv2_method", SSLv3_METHOD: "SSLv3_method", @@ -709,7 +728,8 @@ class Context(object): _methods = dict( (identifier, getattr(_lib, name)) for (identifier, name) in _methods.items() - if getattr(_lib, name, None) is not None) + if getattr(_lib, name, None) is not None + ) def __init__(self, method): if not isinstance(method, integer_types): @@ -790,8 +810,10 @@ class Context(object): @wraps(callback) def wrapper(size, verify, userdata): return callback(size, verify, self._passphrase_userdata) + return _PassphraseHelper( - FILETYPE_PEM, wrapper, more_args=True, truncate=True) + FILETYPE_PEM, wrapper, more_args=True, truncate=True + ) def set_passwd_cb(self, callback, userdata=None): """ @@ -818,7 +840,8 @@ class Context(object): self._passphrase_helper = self._wrap_callback(callback) self._passphrase_callback = self._passphrase_helper.callback _lib.SSL_CTX_set_default_passwd_cb( - self._context, self._passphrase_callback) + self._context, self._passphrase_callback + ) self._passphrase_userdata = userdata def set_default_verify_paths(self): @@ -848,9 +871,9 @@ class Context(object): # First we'll check to see if any env vars have been set. If so, # we won't try to do anything else because the user has set the path # themselves. - dir_env_var = _ffi.string( - _lib.X509_get_default_cert_dir_env() - ).decode("ascii") + dir_env_var = _ffi.string(_lib.X509_get_default_cert_dir_env()).decode( + "ascii" + ) file_env_var = _ffi.string( _lib.X509_get_default_cert_file_env() ).decode("ascii") @@ -861,13 +884,12 @@ class Context(object): # to the exact values we use in our manylinux1 builds. If they are # then we know to load the fallbacks if ( - default_dir == _CRYPTOGRAPHY_MANYLINUX1_CA_DIR and - default_file == _CRYPTOGRAPHY_MANYLINUX1_CA_FILE + default_dir == _CRYPTOGRAPHY_MANYLINUX1_CA_DIR + and default_file == _CRYPTOGRAPHY_MANYLINUX1_CA_FILE ): # This is manylinux1, let's load our fallback paths self._fallback_default_verify_paths( - _CERTIFICATE_FILE_LOCATIONS, - _CERTIFICATE_PATH_LOCATIONS + _CERTIFICATE_FILE_LOCATIONS, _CERTIFICATE_PATH_LOCATIONS ) def _check_env_vars_set(self, dir_env_var, file_env_var): @@ -877,8 +899,8 @@ class Context(object): :return: bool """ return ( - os.environ.get(file_env_var) is not None or - os.environ.get(dir_env_var) is not None + os.environ.get(file_env_var) is not None + or os.environ.get(dir_env_var) is not None ) def _fallback_default_verify_paths(self, file_path, dir_path): @@ -996,7 +1018,8 @@ class Context(object): raise TypeError("filetype must be an integer") use_result = _lib.SSL_CTX_use_PrivateKey_file( - self._context, keyfile, filetype) + self._context, keyfile, filetype + ) if not use_result: self._raise_passphrase_exception() @@ -1052,11 +1075,8 @@ class Context(object): """ buf = _text_to_bytes_and_warn("buf", buf) _openssl_assert( - _lib.SSL_CTX_set_session_id_context( - self._context, - buf, - len(buf), - ) == 1 + _lib.SSL_CTX_set_session_id_context(self._context, buf, len(buf),) + == 1 ) def set_session_cache_mode(self, mode): @@ -1202,19 +1222,17 @@ class Context(object): # invalid cipher string is passed, but without the following check # for the TLS 1.3 specific cipher suites it would never error. tmpconn = Connection(self, None) - if ( - tmpconn.get_cipher_list() == [ - 'TLS_AES_256_GCM_SHA384', - 'TLS_CHACHA20_POLY1305_SHA256', - 'TLS_AES_128_GCM_SHA256' - ] - ): + if tmpconn.get_cipher_list() == [ + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_128_GCM_SHA256", + ]: raise Error( [ ( - 'SSL routines', - 'SSL_CTX_set_cipher_list', - 'no cipher match', + "SSL routines", + "SSL_CTX_set_cipher_list", + "no cipher match", ), ], ) @@ -1240,9 +1258,7 @@ class Context(object): if not isinstance(ca_name, X509Name): raise TypeError( "client CAs must be X509Name objects, not %s " - "objects" % ( - type(ca_name).__name__, - ) + "objects" % (type(ca_name).__name__,) ) copy = _lib.X509_NAME_dup(ca_name._name) _openssl_assert(copy != _ffi.NULL) @@ -1273,7 +1289,8 @@ class Context(object): raise TypeError("certificate_authority must be an X509 instance") add_result = _lib.SSL_CTX_add_client_CA( - self._context, certificate_authority._x509) + self._context, certificate_authority._x509 + ) _openssl_assert(add_result == 1) def set_timeout(self, timeout): @@ -1311,11 +1328,14 @@ class Context(object): function call. :return: None """ + @wraps(callback) def wrapper(ssl, where, return_code): callback(Connection._reverse_mapping[ssl], where, return_code) + self._info_callback = _ffi.callback( - "void (*)(const SSL *, int, int)", wrapper) + "void (*)(const SSL *, int, int)", wrapper + ) _lib.SSL_CTX_set_info_callback(self._context, self._info_callback) def get_app_data(self): @@ -1388,15 +1408,18 @@ class Context(object): .. versionadded:: 0.13 """ + @wraps(callback) def wrapper(ssl, alert, arg): callback(Connection._reverse_mapping[ssl]) return 0 self._tlsext_servername_callback = _ffi.callback( - "int (*)(SSL *, int *, void *)", wrapper) + "int (*)(SSL *, int *, void *)", wrapper + ) _lib.SSL_CTX_set_tlsext_servername_callback( - self._context, self._tlsext_servername_callback) + self._context, self._tlsext_servername_callback + ) def set_tlsext_use_srtp(self, profiles): """ @@ -1431,7 +1454,8 @@ class Context(object): self._npn_advertise_helper = _NpnAdvertiseHelper(callback) self._npn_advertise_callback = self._npn_advertise_helper.callback _lib.SSL_CTX_set_next_protos_advertised_cb( - self._context, self._npn_advertise_callback, _ffi.NULL) + self._context, self._npn_advertise_callback, _ffi.NULL + ) @_requires_npn def set_npn_select_callback(self, callback): @@ -1450,7 +1474,8 @@ class Context(object): self._npn_select_helper = _NpnSelectHelper(callback) self._npn_select_callback = self._npn_select_helper.callback _lib.SSL_CTX_set_next_proto_select_cb( - self._context, self._npn_select_callback, _ffi.NULL) + self._context, self._npn_select_callback, _ffi.NULL + ) @_requires_alpn def set_alpn_protos(self, protos): @@ -1465,7 +1490,7 @@ class Context(object): """ # Take the list of protocols and join them together, prefixing them # with their lengths. - protostr = b''.join( + protostr = b"".join( chain.from_iterable((int2byte(len(p)), p) for p in protos) ) @@ -1492,7 +1517,8 @@ class Context(object): self._alpn_select_helper = _ALPNSelectHelper(callback) self._alpn_select_callback = self._alpn_select_helper.callback _lib.SSL_CTX_set_alpn_select_cb( - self._context, self._alpn_select_callback, _ffi.NULL) + self._context, self._alpn_select_callback, _ffi.NULL + ) def _set_ocsp_callback(self, helper, data): """ @@ -1556,6 +1582,7 @@ class Context(object): class Connection(object): """ """ + _reverse_mapping = WeakValueDictionary() def __init__(self, context, socket=None): @@ -1609,7 +1636,8 @@ class Connection(object): self._from_ssl = None self._socket = socket set_result = _lib.SSL_set_fd( - self._ssl, _asFileDescriptor(self._socket)) + self._ssl, _asFileDescriptor(self._socket) + ) _openssl_assert(set_result == 1) def __getattr__(self, name): @@ -1618,9 +1646,10 @@ class Connection(object): on the Connection object. """ if self._socket is None: - raise AttributeError("'%s' object has no attribute '%s'" % ( - self.__class__.__name__, name - )) + raise AttributeError( + "'%s' object has no attribute '%s'" + % (self.__class__.__name__, name) + ) else: return getattr(self._socket, name) @@ -1777,9 +1806,7 @@ class Connection(object): # SSL_write's num arg is an int, # so we cannot send more than 2**31-1 bytes at once. result = _lib.SSL_write( - self._ssl, - data + total_sent, - min(left_to_send, 2147483647) + self._ssl, data + total_sent, min(left_to_send, 2147483647) ) self._raise_ssl_error(self._ssl, result) total_sent += result @@ -1803,6 +1830,7 @@ class Connection(object): result = _lib.SSL_read(self._ssl, buf, bufsiz) self._raise_ssl_error(self._ssl, result) return _ffi.buffer(buf, result)[:] + read = recv def recv_into(self, buffer, nbytes=None, flags=None): @@ -2069,7 +2097,8 @@ class Connection(object): :raise: NotImplementedError """ raise NotImplementedError( - "Cannot make file object of OpenSSL.SSL.Connection") + "Cannot make file object of OpenSSL.SSL.Connection" + ) def get_app_data(self): """ @@ -2182,10 +2211,16 @@ class Connection(object): context_buf = context context_len = len(context) use_context = 1 - success = _lib.SSL_export_keying_material(self._ssl, outp, olen, - label, len(label), - context_buf, context_len, - use_context) + success = _lib.SSL_export_keying_material( + self._ssl, + outp, + olen, + label, + len(label), + context_buf, + context_len, + use_context, + ) _openssl_assert(success == 1) return _ffi.buffer(outp, olen)[:] @@ -2470,7 +2505,7 @@ class Connection(object): """ # Take the list of protocols and join them together, prefixing them # with their lengths. - protostr = b''.join( + protostr = b"".join( chain.from_iterable((int2byte(len(p)), p) for p in protos) ) @@ -2493,7 +2528,7 @@ class Connection(object): _lib.SSL_get0_alpn_selected(self._ssl, data, data_len) if not data_len: - return b'' + return b"" return _ffi.buffer(data[0], data_len[0])[:] diff --git a/src/OpenSSL/__init__.py b/src/OpenSSL/__init__.py index 810d00d..11e896a 100644 --- a/src/OpenSSL/__init__.py +++ b/src/OpenSSL/__init__.py @@ -7,14 +7,26 @@ pyOpenSSL - A simple wrapper around the OpenSSL library from OpenSSL import crypto, SSL from OpenSSL.version import ( - __author__, __copyright__, __email__, __license__, __summary__, __title__, - __uri__, __version__, + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, ) __all__ = [ - "SSL", "crypto", - - "__author__", "__copyright__", "__email__", "__license__", "__summary__", - "__title__", "__uri__", "__version__", + "SSL", + "crypto", + "__author__", + "__copyright__", + "__email__", + "__license__", + "__summary__", + "__title__", + "__uri__", + "__version__", ] diff --git a/src/OpenSSL/_util.py b/src/OpenSSL/_util.py index 9f2d724..1beefe6 100644 --- a/src/OpenSSL/_util.py +++ b/src/OpenSSL/_util.py @@ -46,10 +46,13 @@ def exception_from_error_queue(exception_type): error = lib.ERR_get_error() if error == 0: break - errors.append(( - text(lib.ERR_lib_error_string(error)), - text(lib.ERR_func_error_string(error)), - text(lib.ERR_reason_error_string(error)))) + errors.append( + ( + text(lib.ERR_lib_error_string(error)), + text(lib.ERR_func_error_string(error)), + text(lib.ERR_reason_error_string(error)), + ) + ) raise exception_type(errors) @@ -59,6 +62,7 @@ def make_assert(error): Create an assert function that uses :func:`exception_from_error_queue` to raise an exception wrapped by *error*. """ + def openssl_assert(ok): """ If *ok* is not True, retrieve the error from OpenSSL and raise it. @@ -108,9 +112,13 @@ def path_string(s): if PY2: + def byte_string(s): return s + + else: + def byte_string(s): return s.encode("charmap") @@ -141,9 +149,9 @@ def text_to_bytes_and_warn(label, obj): warnings.warn( _TEXT_WARNING.format(label), category=DeprecationWarning, - stacklevel=3 + stacklevel=3, ) - return obj.encode('utf-8') + return obj.encode("utf-8") return obj diff --git a/src/OpenSSL/crypto.py b/src/OpenSSL/crypto.py index 30dd478..0744ca7 100644 --- a/src/OpenSSL/crypto.py +++ b/src/OpenSSL/crypto.py @@ -7,7 +7,8 @@ from operator import __eq__, __ne__, __lt__, __le__, __gt__, __ge__ from six import ( integer_types as _integer_types, text_type as _text_type, - PY2 as _PY2) + PY2 as _PY2, +) from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, rsa @@ -24,42 +25,42 @@ from OpenSSL._util import ( ) __all__ = [ - 'FILETYPE_PEM', - 'FILETYPE_ASN1', - 'FILETYPE_TEXT', - 'TYPE_RSA', - 'TYPE_DSA', - 'Error', - 'PKey', - 'get_elliptic_curves', - 'get_elliptic_curve', - 'X509Name', - 'X509Extension', - 'X509Req', - 'X509', - 'X509StoreFlags', - 'X509Store', - 'X509StoreContextError', - 'X509StoreContext', - 'load_certificate', - 'dump_certificate', - 'dump_publickey', - 'dump_privatekey', - 'Revoked', - 'CRL', - 'PKCS7', - 'PKCS12', - 'NetscapeSPKI', - 'load_publickey', - 'load_privatekey', - 'dump_certificate_request', - 'load_certificate_request', - 'sign', - 'verify', - 'dump_crl', - 'load_crl', - 'load_pkcs7_data', - 'load_pkcs12' + "FILETYPE_PEM", + "FILETYPE_ASN1", + "FILETYPE_TEXT", + "TYPE_RSA", + "TYPE_DSA", + "Error", + "PKey", + "get_elliptic_curves", + "get_elliptic_curve", + "X509Name", + "X509Extension", + "X509Req", + "X509", + "X509StoreFlags", + "X509Store", + "X509StoreContextError", + "X509StoreContext", + "load_certificate", + "dump_certificate", + "dump_publickey", + "dump_privatekey", + "Revoked", + "CRL", + "PKCS7", + "PKCS12", + "NetscapeSPKI", + "load_publickey", + "load_privatekey", + "dump_certificate_request", + "load_certificate_request", + "sign", + "verify", + "dump_crl", + "load_crl", + "load_pkcs7_data", + "load_pkcs12", ] FILETYPE_PEM = _lib.SSL_FILETYPE_PEM @@ -93,6 +94,7 @@ def _get_backend(): triggering this side effect unless _get_backend is called. """ from cryptography.hazmat.backends.openssl.backend import backend + return backend @@ -135,7 +137,7 @@ def _bio_to_string(bio): """ Copy the contents of an OpenSSL BIO object into a Python byte string. """ - result_buffer = _ffi.new('char**') + result_buffer = _ffi.new("char**") buffer_length = _lib.BIO_get_mem_data(bio, result_buffer) return _ffi.buffer(result_buffer[0], buffer_length)[:] @@ -172,7 +174,7 @@ def _get_asn1_time(timestamp): @return: The time value from C{timestamp} as a L{bytes} string in a certain format. Or C{None} if the object contains no time value. """ - string_timestamp = _ffi.cast('ASN1_STRING*', timestamp) + string_timestamp = _ffi.cast("ASN1_STRING*", timestamp) if _lib.ASN1_STRING_length(string_timestamp) == 0: return None elif ( @@ -195,7 +197,8 @@ def _get_asn1_time(timestamp): _untested_error("ASN1_TIME_to_generalizedtime") else: string_timestamp = _ffi.cast( - "ASN1_STRING*", generalized_timestamp[0]) + "ASN1_STRING*", generalized_timestamp[0] + ) string_data = _lib.ASN1_STRING_data(string_timestamp) string_result = _ffi.string(string_data) _lib.ASN1_GENERALIZEDTIME_free(generalized_timestamp[0]) @@ -219,6 +222,7 @@ class PKey(object): """ A class representing an DSA or RSA public key or key pair. """ + _only_public = False _initialized = True @@ -257,8 +261,15 @@ class PKey(object): .. versionadded:: 16.1.0 """ pkey = cls() - if not isinstance(crypto_key, (rsa.RSAPublicKey, rsa.RSAPrivateKey, - dsa.DSAPublicKey, dsa.DSAPrivateKey)): + if not isinstance( + crypto_key, + ( + rsa.RSAPublicKey, + rsa.RSAPrivateKey, + dsa.DSAPublicKey, + dsa.DSAPrivateKey, + ), + ): raise TypeError("Unsupported key type") pkey._pkey = crypto_key._evp_pkey @@ -375,6 +386,7 @@ class _EllipticCurve(object): instances each of which represents one curve supported by the system. @type _curves: :py:type:`NoneType` or :py:type:`set` """ + _curves = None if not _PY2: @@ -401,14 +413,12 @@ class _EllipticCurve(object): elliptic curves the underlying library supports. """ num_curves = lib.EC_get_builtin_curves(_ffi.NULL, 0) - builtin_curves = _ffi.new('EC_builtin_curve[]', num_curves) + builtin_curves = _ffi.new("EC_builtin_curve[]", num_curves) # The return value on this call should be num_curves again. We # could check it to make sure but if it *isn't* then.. what could # we do? Abort the whole process, I suppose...? -exarkun lib.EC_get_builtin_curves(builtin_curves, num_curves) - return set( - cls.from_nid(lib, c.nid) - for c in builtin_curves) + return set(cls.from_nid(lib, c.nid) for c in builtin_curves) @classmethod def _get_elliptic_curves(cls, lib): @@ -541,14 +551,16 @@ class X509Name(object): self._name = _ffi.gc(name, _lib.X509_NAME_free) def __setattr__(self, name, value): - if name.startswith('_'): + if name.startswith("_"): return super(X509Name, self).__setattr__(name, value) # Note: we really do not want str subclasses here, so we do not use # isinstance. if type(name) is not str: - raise TypeError("attribute name must be string, not '%.200s'" % ( - type(value).__name__,)) + raise TypeError( + "attribute name must be string, not '%.200s'" + % (type(value).__name__,) + ) nid = _lib.OBJ_txt2nid(_byte_string(name)) if nid == _lib.NID_undef: @@ -569,10 +581,11 @@ class X509Name(object): break if isinstance(value, _text_type): - value = value.encode('utf-8') + value = value.encode("utf-8") add_result = _lib.X509_NAME_add_entry_by_NID( - self._name, nid, _lib.MBSTRING_UTF8, value, -1, -1, 0) + self._name, nid, _lib.MBSTRING_UTF8, value, -1, -1, 0 + ) if not add_result: _raise_current_error() @@ -608,9 +621,9 @@ class X509Name(object): _openssl_assert(data_length >= 0) try: - result = _ffi.buffer( - result_buffer[0], data_length - )[:].decode('utf-8') + result = _ffi.buffer(result_buffer[0], data_length)[:].decode( + "utf-8" + ) finally: # XXX untested _lib.OPENSSL_free(result_buffer[0]) @@ -622,6 +635,7 @@ class X509Name(object): return NotImplemented result = _lib.X509_NAME_cmp(self._name, other._name) return op(result, 0) + return f __eq__ = _cmp(__eq__) @@ -639,11 +653,13 @@ class X509Name(object): """ result_buffer = _ffi.new("char[]", 512) format_result = _lib.X509_NAME_oneline( - self._name, result_buffer, len(result_buffer)) + self._name, result_buffer, len(result_buffer) + ) _openssl_assert(format_result != _ffi.NULL) return "<X509Name object '%s'>" % ( - _native(_ffi.string(result_buffer)),) + _native(_ffi.string(result_buffer)), + ) def hash(self): """ @@ -664,7 +680,7 @@ class X509Name(object): :return: The DER encoded form of this name. :rtype: :py:class:`bytes` """ - result_buffer = _ffi.new('unsigned char**') + result_buffer = _ffi.new("unsigned char**") encode_result = _lib.i2d_X509_NAME(self._name, result_buffer) _openssl_assert(encode_result >= 0) @@ -691,8 +707,9 @@ class X509Name(object): # ffi.string does not handle strings containing NULL bytes # (which may have been generated by old, broken software) - value = _ffi.buffer(_lib.ASN1_STRING_data(fval), - _lib.ASN1_STRING_length(fval))[:] + value = _ffi.buffer( + _lib.ASN1_STRING_data(fval), _lib.ASN1_STRING_length(fval) + )[:] result.append((_ffi.string(name), value)) return result @@ -793,7 +810,8 @@ class X509Extension(object): parts.append(_native(_bio_to_string(bio))) else: value = _native( - _ffi.buffer(name.d.ia5.data, name.d.ia5.length)[:]) + _ffi.buffer(name.d.ia5.data, name.d.ia5.length)[:] + ) parts.append(label + ":" + value) return ", ".join(parts) @@ -843,7 +861,7 @@ class X509Extension(object): .. versionadded:: 0.12 """ octet_result = _lib.X509_EXTENSION_get_data(self._extension) - string_result = _ffi.cast('ASN1_STRING*', octet_result) + string_result = _ffi.cast("ASN1_STRING*", octet_result) char_result = _lib.ASN1_STRING_data(string_result) result_length = _lib.ASN1_STRING_length(string_result) return _ffi.buffer(char_result, result_length)[:] @@ -869,8 +887,9 @@ class X509Req(object): .. versionadded:: 17.1.0 """ from cryptography.hazmat.backends.openssl.x509 import ( - _CertificateSigningRequest + _CertificateSigningRequest, ) + backend = _get_backend() return _CertificateSigningRequest(backend, self._req) @@ -1052,6 +1071,7 @@ class X509(object): """ An X.509 certificate. """ + def __init__(self): x509 = _lib.X509_new() _openssl_assert(x509 != _ffi.NULL) @@ -1077,6 +1097,7 @@ class X509(object): .. versionadded:: 17.1.0 """ from cryptography.hazmat.backends.openssl.x509 import _Certificate + backend = _get_backend() return _Certificate(backend, self._x509) @@ -1218,12 +1239,16 @@ class X509(object): result_length[0] = len(result_buffer) digest_result = _lib.X509_digest( - self._x509, digest, result_buffer, result_length) + self._x509, digest, result_buffer, result_length + ) _openssl_assert(digest_result == 1) - return b":".join([ - b16encode(ch).upper() for ch - in _ffi.buffer(result_buffer, result_length[0])]) + return b":".join( + [ + b16encode(ch).upper() + for ch in _ffi.buffer(result_buffer, result_length[0]) + ] + ) def subject_name_hash(self): """ @@ -1248,7 +1273,7 @@ class X509(object): hex_serial = hex(serial)[2:] if not isinstance(hex_serial, bytes): - hex_serial = hex_serial.encode('ascii') + hex_serial = hex_serial.encode("ascii") bignum_serial = _ffi.new("BIGNUM**") @@ -1259,7 +1284,8 @@ class X509(object): if bignum_serial[0] == _ffi.NULL: set_result = _lib.ASN1_INTEGER_set( - _lib.X509_get_serialNumber(self._x509), small_serial) + _lib.X509_get_serialNumber(self._x509), small_serial + ) if set_result: # TODO Not tested _raise_current_error() @@ -1524,6 +1550,7 @@ class X509StoreFlags(object): .. _OpenSSL Verification Flags: https://www.openssl.org/docs/manmaster/man3/X509_VERIFY_PARAM_set_flags.html """ + CRL_CHECK = _lib.X509_V_FLAG_CRL_CHECK CRL_CHECK_ALL = _lib.X509_V_FLAG_CRL_CHECK_ALL IGNORE_CRITICAL = _lib.X509_V_FLAG_IGNORE_CRITICAL @@ -1644,7 +1671,7 @@ class X509Store(object): param = _lib.X509_VERIFY_PARAM_new() param = _ffi.gc(param, _lib.X509_VERIFY_PARAM_free) - _lib.X509_VERIFY_PARAM_set_time(param, int(vfy_time.strftime('%s'))) + _lib.X509_VERIFY_PARAM_set_time(param, int(vfy_time.strftime("%s"))) _openssl_assert(_lib.X509_STORE_set1_param(self._store, param) != 0) @@ -1722,8 +1749,13 @@ class X509StoreContext(object): errors = [ _lib.X509_STORE_CTX_get_error(self._store_ctx), _lib.X509_STORE_CTX_get_error_depth(self._store_ctx), - _native(_ffi.string(_lib.X509_verify_cert_error_string( - _lib.X509_STORE_CTX_get_error(self._store_ctx)))), + _native( + _ffi.string( + _lib.X509_verify_cert_error_string( + _lib.X509_STORE_CTX_get_error(self._store_ctx) + ) + ) + ), ] # A context error should always be associated with a certificate, so we # expect this call to never return :class:`None`. @@ -1787,8 +1819,7 @@ def load_certificate(type, buffer): elif type == FILETYPE_ASN1: x509 = _lib.d2i_X509_bio(bio, _ffi.NULL) else: - raise ValueError( - "type argument must be FILETYPE_PEM or FILETYPE_ASN1") + raise ValueError("type argument must be FILETYPE_PEM or FILETYPE_ASN1") if x509 == _ffi.NULL: _raise_current_error() @@ -1817,7 +1848,8 @@ def dump_certificate(type, cert): else: raise ValueError( "type argument must be FILETYPE_PEM, FILETYPE_ASN1, or " - "FILETYPE_TEXT") + "FILETYPE_TEXT" + ) _openssl_assert(result_code == 1) return _bio_to_string(bio) @@ -1873,7 +1905,8 @@ def dump_privatekey(type, pkey, cipher=None, passphrase=None): if passphrase is None: raise TypeError( "if a value is given for cipher " - "one must also be given for passphrase") + "one must also be given for passphrase" + ) cipher_obj = _lib.EVP_get_cipherbyname(_byte_string(cipher)) if cipher_obj == _ffi.NULL: raise ValueError("Invalid cipher name") @@ -1883,8 +1916,14 @@ def dump_privatekey(type, pkey, cipher=None, passphrase=None): helper = _PassphraseHelper(type, passphrase) if type == FILETYPE_PEM: result_code = _lib.PEM_write_bio_PrivateKey( - bio, pkey._pkey, cipher_obj, _ffi.NULL, 0, - helper.callback, helper.callback_args) + bio, + pkey._pkey, + cipher_obj, + _ffi.NULL, + 0, + helper.callback, + helper.callback_args, + ) helper.raise_if_problem() elif type == FILETYPE_ASN1: result_code = _lib.i2d_PrivateKey_bio(bio, pkey._pkey) @@ -1892,15 +1931,13 @@ def dump_privatekey(type, pkey, cipher=None, passphrase=None): if _lib.EVP_PKEY_id(pkey._pkey) != _lib.EVP_PKEY_RSA: raise TypeError("Only RSA keys are supported for FILETYPE_TEXT") - rsa = _ffi.gc( - _lib.EVP_PKEY_get1_RSA(pkey._pkey), - _lib.RSA_free - ) + rsa = _ffi.gc(_lib.EVP_PKEY_get1_RSA(pkey._pkey), _lib.RSA_free) result_code = _lib.RSA_print(bio, rsa, 0) else: raise ValueError( "type argument must be FILETYPE_PEM, FILETYPE_ASN1, or " - "FILETYPE_TEXT") + "FILETYPE_TEXT" + ) _openssl_assert(result_code != 0) @@ -1911,6 +1948,7 @@ class Revoked(object): """ A certificate revocation. """ + # https://www.openssl.org/docs/manmaster/man5/x509v3_config.html#CRL-distribution-points # which differs from crl_reasons of crypto/x509v3/v3_enum.c that matches # OCSP_crl_reason_str. We use the latter, just like the command line @@ -1950,7 +1988,8 @@ class Revoked(object): asn1_serial = _ffi.gc( _lib.BN_to_ASN1_INTEGER(bignum_serial, _ffi.NULL), - _lib.ASN1_INTEGER_free) + _lib.ASN1_INTEGER_free, + ) _lib.X509_REVOKED_set_serialNumber(self._revoked, asn1_serial) def get_serial(self): @@ -2001,7 +2040,7 @@ class Revoked(object): elif not isinstance(reason, bytes): raise TypeError("reason must be None or a byte string") else: - reason = reason.lower().replace(b' ', b'') + reason = reason.lower().replace(b" ", b"") reason_code = [r.lower() for r in self._crl_reasons].index(reason) new_reason_ext = _lib.ASN1_ENUMERATED_new() @@ -2013,7 +2052,8 @@ class Revoked(object): self._delete_reason() add_result = _lib.X509_REVOKED_add1_ext_i2d( - self._revoked, _lib.NID_crl_reason, new_reason_ext, 0, 0) + self._revoked, _lib.NID_crl_reason, new_reason_ext, 0, 0 + ) _openssl_assert(add_result == 1) def get_reason(self): @@ -2095,8 +2135,9 @@ class CRL(object): .. versionadded:: 17.1.0 """ from cryptography.hazmat.backends.openssl.x509 import ( - _CertificateRevocationList + _CertificateRevocationList, ) + backend = _get_backend() return _CertificateRevocationList(backend, self._crl) @@ -2236,13 +2277,15 @@ class CRL(object): digest_obj = _lib.EVP_get_digestbyname(digest) _openssl_assert(digest_obj != _ffi.NULL) _lib.X509_CRL_set_issuer_name( - self._crl, _lib.X509_get_subject_name(issuer_cert._x509)) + self._crl, _lib.X509_get_subject_name(issuer_cert._x509) + ) _lib.X509_CRL_sort(self._crl) result = _lib.X509_CRL_sign(self._crl, issuer_key._pkey, digest_obj) _openssl_assert(result != 0) - def export(self, cert, key, type=FILETYPE_PEM, days=100, - digest=_UNSPECIFIED): + def export( + self, cert, key, type=FILETYPE_PEM, days=100, digest=_UNSPECIFIED + ): """ Export the CRL as a string. @@ -2500,10 +2543,17 @@ class PKCS12(object): cert = self._cert._x509 pkcs12 = _lib.PKCS12_create( - passphrase, friendlyname, pkey, cert, cacerts, + passphrase, + friendlyname, + pkey, + cert, + cacerts, _lib.NID_pbe_WithSHA1And3_Key_TripleDES_CBC, _lib.NID_pbe_WithSHA1And3_Key_TripleDES_CBC, - iter, maciter, 0) + iter, + maciter, + 0, + ) if pkcs12 == _ffi.NULL: _raise_current_error() pkcs12 = _ffi.gc(pkcs12, _lib.PKCS12_free) @@ -2667,7 +2717,7 @@ class _PassphraseHelper(object): "passphrase returned by callback is too long" ) for i in range(len(result)): - buf[i] = result[i:i + 1] + buf[i] = result[i : i + 1] return len(result) except Exception as e: self._problems.append(e) @@ -2692,7 +2742,8 @@ def load_publickey(type, buffer): if type == FILETYPE_PEM: evp_pkey = _lib.PEM_read_bio_PUBKEY( - bio, _ffi.NULL, _ffi.NULL, _ffi.NULL) + bio, _ffi.NULL, _ffi.NULL, _ffi.NULL + ) elif type == FILETYPE_ASN1: evp_pkey = _lib.d2i_PUBKEY_bio(bio, _ffi.NULL) else: @@ -2728,7 +2779,8 @@ def load_privatekey(type, buffer, passphrase=None): helper = _PassphraseHelper(type, passphrase) if type == FILETYPE_PEM: evp_pkey = _lib.PEM_read_bio_PrivateKey( - bio, _ffi.NULL, helper.callback, helper.callback_args) + bio, _ffi.NULL, helper.callback, helper.callback_args + ) helper.raise_if_problem() elif type == FILETYPE_ASN1: evp_pkey = _lib.d2i_PrivateKey_bio(bio, _ffi.NULL) @@ -2827,7 +2879,8 @@ def sign(pkey, data, digest): signature_buffer = _ffi.new("unsigned char[]", length) signature_length = _ffi.new("unsigned int *") final_result = _lib.EVP_SignFinal( - md_ctx, signature_buffer, signature_length, pkey._pkey) + md_ctx, signature_buffer, signature_length, pkey._pkey + ) _openssl_assert(final_result == 1) return _ffi.buffer(signature_buffer, signature_length[0])[:] @@ -2891,7 +2944,8 @@ def dump_crl(type, crl): else: raise ValueError( "type argument must be FILETYPE_PEM, FILETYPE_ASN1, or " - "FILETYPE_TEXT") + "FILETYPE_TEXT" + ) _openssl_assert(ret == 1) return _bio_to_string(bio) @@ -3061,4 +3115,4 @@ _lib.SSL_load_error_strings() # Set the default string mask to match OpenSSL upstream (since 2005) and # RFC5280 recommendations. -_lib.ASN1_STRING_set_default_mask_asc(b'utf8only') +_lib.ASN1_STRING_set_default_mask_asc(b"utf8only") diff --git a/src/OpenSSL/version.py b/src/OpenSSL/version.py index 339b9ae..76de33a 100644 --- a/src/OpenSSL/version.py +++ b/src/OpenSSL/version.py @@ -7,8 +7,14 @@ pyOpenSSL - A simple wrapper around the OpenSSL library """ __all__ = [ - "__author__", "__copyright__", "__email__", "__license__", "__summary__", - "__title__", "__uri__", "__version__", + "__author__", + "__copyright__", + "__email__", + "__license__", + "__summary__", + "__title__", + "__uri__", + "__version__", ] __version__ = "20.0.0.dev" diff --git a/tests/conftest.py b/tests/conftest.py index 366624e..5bae6b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ def pytest_report_header(config): return "OpenSSL: {openssl}\ncryptography: {cryptography}".format( openssl=OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION), - cryptography=cryptography.__version__ + cryptography=cryptography.__version__, ) diff --git a/tests/memdbg.py b/tests/memdbg.py index 6e608a7..590b72d 100644 --- a/tests/memdbg.py +++ b/tests/memdbg.py @@ -5,8 +5,8 @@ import traceback from cffi import api as _api -sys.modules['ssl'] = None -sys.modules['_hashlib'] = None +sys.modules["ssl"] = None +sys.modules["_hashlib"] = None _ffi = _api.FFI() @@ -16,18 +16,22 @@ _ffi.cdef( void free(void *ptr); void *realloc(void *ptr, size_t size); - int CRYPTO_set_mem_functions(void *(*m)(size_t),void *(*r)(void *,size_t), void (*f)(void *)); + int CRYPTO_set_mem_functions( + void *(*m)(size_t),void *(*r)(void *,size_t), void (*f)(void *)); int backtrace(void **buffer, int size); char **backtrace_symbols(void *const *buffer, int size); void backtrace_symbols_fd(void *const *buffer, int size, int fd); - """) # noqa + """ +) # noqa _api = _ffi.verify( """ #include <openssl/crypto.h> #include <stdlib.h> #include <execinfo.h> - """, libraries=["crypto"]) + """, + libraries=["crypto"], +) C = _ffi.dlopen(None) verbose = False @@ -80,8 +84,8 @@ def free(p): if _api.CRYPTO_set_mem_functions(malloc, realloc, free): - log('Enabled memory debugging') + log("Enabled memory debugging") heap = {} else: - log('Failed to enable memory debugging') + log("Failed to enable memory debugging") heap = None diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 2a0c967..75f4a5a 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -26,7 +26,7 @@ from OpenSSL.crypto import ( X509Store, X509StoreFlags, X509StoreContext, - X509StoreContextError + X509StoreContextError, ) from OpenSSL.crypto import X509Req from OpenSSL.crypto import X509Extension @@ -40,7 +40,11 @@ from OpenSSL.crypto import PKCS12, load_pkcs12 from OpenSSL.crypto import CRL, Revoked, dump_crl, load_crl from OpenSSL.crypto import NetscapeSPKI from OpenSSL.crypto import ( - sign, verify, get_elliptic_curve, get_elliptic_curves) + sign, + verify, + get_elliptic_curve, + get_elliptic_curves, +) from .util import EqualityTestsMixin, is_consistent_type, WARNING_TYPE_EXPECTED @@ -162,7 +166,8 @@ h0VtBuQoHPtjqZXF59oX6hMMmGLMs9pV0UA3fJs5MYA4/V5ZcQy0Ie0QoJNejLzE -----END CERTIFICATE----- """ -server_key_pem = normalize_privatekey_pem(b"""-----BEGIN RSA PRIVATE KEY----- +server_key_pem = normalize_privatekey_pem( + b"""-----BEGIN RSA PRIVATE KEY----- MIICWwIBAAKBgQC+pvhuud1dLaQQvzipdtlcTotgr5SuE2LvSx0gz/bg1U3u1eQ+ U5eqsxaEUceaX5p5Kk+QflvW8qdjVNxQuYS5uc0gK2+OZnlIYxCf4n5GYGzVIx3Q SBj/TAEFB2WuVinZBiCbxgL7PFM1Kpa+EwVkCAduPpSflJJPwkYGrK2MHQIDAQAB @@ -177,7 +182,8 @@ FwwOhpahld+vqhYk+pfuWWUpQciE+Bu7ZQJASjfT4sQv4qbbKK/scePicnDdx9th NaeNCFfH3aeTrX0LyQJAMBWjWmeKM2G2sCExheeQK0ROnaBC8itCECD4Jsve4nqf r50+LF74iLXFwqysVCebPKMOpDWp/qQ1BbJQIPs7/A== -----END RSA PRIVATE KEY----- -""") +""" +) intermediate_server_cert_pem = b"""-----BEGIN CERTIFICATE----- MIICWDCCAcGgAwIBAgIRAPQFY9jfskSihdiNSNdt6GswDQYJKoZIhvcNAQENBQAw @@ -229,7 +235,8 @@ JRgjHbWutZfZvbSHXr9n7PIphG1Ojg== -----END CERTIFICATE----- """ -client_key_pem = normalize_privatekey_pem(b"""-----BEGIN RSA PRIVATE KEY----- +client_key_pem = normalize_privatekey_pem( + b"""-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDAZh/SRtNm5ntMT4qb6YzEpTroMlq2rn+GrRHRiZ+xkCw/CGNh btPir7/QxaUj26BSmQrHw1bGKEbPsWiW7bdXSespl+xKiku4G/KvnnmWdeJHqsiX eUZtqurMELcPQAw9xPHEuhqqUJvvEoMTsnCEqGM+7DtboCRajYyHfluARQIDAQAB @@ -244,7 +251,8 @@ si6xwT7GzMDkk/ko684AV3KPc/h6G0yGtFIrMg7J3uExpR/VdH2KgwMkZXisSMvw JJEQjOMCVsEJlRk54WWjAkEAzoZNH6UhDdBK5F38rVt/y4SEHgbSfJHIAmPS32Kq f6GGcfNpip0Uk7q7udTKuX7Q/buZi/C4YW7u3VKAquv9NA== -----END RSA PRIVATE KEY----- -""") +""" +) cleartextCertificatePEM = b"""-----BEGIN CERTIFICATE----- MIIC6TCCAlKgAwIBAgIIPQzE4MbeufQwDQYJKoZIhvcNAQEFBQAwWDELMAkGA1UE @@ -266,7 +274,8 @@ lEqxh3aFEUx9IOQ4sgnx1/NOFXBpkRtivl6O0Ec= -----END CERTIFICATE----- """ -cleartextPrivateKeyPEM = normalize_privatekey_pem(b"""\ +cleartextPrivateKeyPEM = normalize_privatekey_pem( + b"""\ -----BEGIN RSA PRIVATE KEY----- MIICXQIBAAKBgQD5mkLpi7q6ROdu7khB3S9aanA0Zls7vvfGOmB80/yeylhGpsjA jWen0VtSQke/NlEPGtO38tsV7CsuFnSmschvAnGrcJl76b0UOOHUgDTIoRxC6QDU @@ -282,7 +291,8 @@ ttXigLnCqR486JDPTi9ZscoZkZ+w7y6e/hH8t6d5Vjt48JVyfjPIaJY+km58LcN3 6AWSeGAdtRFHVzR7oHjVAkB4hutvxiOeiIVQNBhM6RSI9aBPMI21DoX2JRoxvNW2 cbvAhow217X9V0dVerEOKxnNYspXRrh36h7k4mQA+sDq -----END RSA PRIVATE KEY----- -""") +""" +) cleartextCertificateRequestPEM = b"""-----BEGIN CERTIFICATE REQUEST----- MIIBnjCCAQcCAQAwXjELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAklMMRAwDgYDVQQH @@ -359,7 +369,8 @@ Ho4EzbYCOaEAMQA= -----END PKCS7----- """ -pkcs7DataASN1 = base64.b64decode(b""" +pkcs7DataASN1 = base64.b64decode( + b""" MIIDNwYJKoZIhvcNAQcCoIIDKDCCAyQCAQExADALBgkqhkiG9w0BBwGgggMKMIID BjCCAm+gAwIBAgIBATANBgkqhkiG9w0BAQQFADB7MQswCQYDVQQGEwJTRzERMA8G A1UEChMITTJDcnlwdG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQDExtN @@ -378,7 +389,8 @@ bYIBADANBgkqhkiG9w0BAQQFAAOBgQA7/CqT6PoHycTdhEStWNZde7M/2Yc6BoJu VwnW8YxGO8Sn6UJ4FeffZNcYZddSDKosw8LtPOeWoK3JINjAk5jiPQ2cww++7QGG /g5NDjxFZNDJP1dGiLAxPW6JXwov4v0FmdzfLOZ01jDcgQQZqEpYlgpuI5JEWUQ9 Ho4EzbYCOaEAMQA= -""") +""" +) crlData = b"""\ -----BEGIN X509 CRL----- @@ -606,8 +618,8 @@ class TestX509Ext(object): # This isn't necessarily the best string representation. Perhaps it # will be changed/improved in the future. assert ( - str(X509Extension(b'basicConstraints', True, b'CA:false')) == - 'CA:FALSE' + str(X509Extension(b"basicConstraints", True, b"CA:false")) + == "CA:FALSE" ) def test_type(self): @@ -616,30 +628,40 @@ class TestX509Ext(object): """ assert is_consistent_type( X509Extension, - 'X509Extension', b'basicConstraints', True, b'CA:true') + "X509Extension", + b"basicConstraints", + True, + b"CA:true", + ) def test_construction(self): """ `X509Extension` accepts an extension type name, a critical flag, and an extension value and returns an `X509Extension` instance. """ - basic = X509Extension(b'basicConstraints', True, b'CA:true') + basic = X509Extension(b"basicConstraints", True, b"CA:true") assert isinstance(basic, X509Extension) - comment = X509Extension(b'nsComment', False, b'pyOpenSSL unit test') + comment = X509Extension(b"nsComment", False, b"pyOpenSSL unit test") assert isinstance(comment, X509Extension) - @pytest.mark.parametrize('type_name, critical, value', [ - (b'thisIsMadeUp', False, b'hi'), - (b'basicConstraints', False, b'blah blah'), - - # Exercise a weird one (an extension which uses the r2i method). This - # exercises the codepath that requires a non-NULL ctx to be passed to - # X509V3_EXT_nconf. It can't work now because we provide no - # configuration database. It might be made to work in the future. - (b'proxyCertInfo', True, - b'language:id-ppl-anyLanguage,pathlen:1,policy:text:AB') - ]) + @pytest.mark.parametrize( + "type_name, critical, value", + [ + (b"thisIsMadeUp", False, b"hi"), + (b"basicConstraints", False, b"blah blah"), + # Exercise a weird one (an extension which uses the r2i method). + # This exercises the codepath that requires a non-NULL ctx to be + # passed to X509V3_EXT_nconf. It can't work now because we provide + # no configuration database. It might be made to work in the + # future. + ( + b"proxyCertInfo", + True, + b"language:id-ppl-anyLanguage,pathlen:1,policy:text:AB", + ), + ], + ) def test_invalid_extension(self, type_name, critical, value): """ `X509Extension` raises something if it is passed a bad @@ -648,19 +670,19 @@ class TestX509Ext(object): with pytest.raises(Error): X509Extension(type_name, critical, value) - @pytest.mark.parametrize('critical_flag', [True, False]) + @pytest.mark.parametrize("critical_flag", [True, False]) def test_get_critical(self, critical_flag): """ `X509ExtensionType.get_critical` returns the value of the extension's critical flag. """ - ext = X509Extension(b'basicConstraints', critical_flag, b'CA:true') + ext = X509Extension(b"basicConstraints", critical_flag, b"CA:true") assert ext.get_critical() == critical_flag - @pytest.mark.parametrize('short_name, value', [ - (b'basicConstraints', b'CA:true'), - (b'nsComment', b'foo bar'), - ]) + @pytest.mark.parametrize( + "short_name, value", + [(b"basicConstraints", b"CA:true"), (b"nsComment", b"foo bar")], + ) def test_get_short_name(self, short_name, value): """ `X509ExtensionType.get_short_name` returns a string giving the @@ -674,9 +696,9 @@ class TestX509Ext(object): `X509Extension.get_data` returns a string giving the data of the extension. """ - ext = X509Extension(b'basicConstraints', True, b'CA:true') + ext = X509Extension(b"basicConstraints", True, b"CA:true") # Expect to get back the DER encoded form of CA:true. - assert ext.get_data() == b'0\x03\x01\x01\xff' + assert ext.get_data() == b"0\x03\x01\x01\xff" def test_unused_subject(self, x509_data): """ @@ -685,13 +707,14 @@ class TestX509Ext(object): """ pkey, x509 = x509_data ext1 = X509Extension( - b'basicConstraints', False, b'CA:TRUE', subject=x509) + b"basicConstraints", False, b"CA:TRUE", subject=x509 + ) x509.add_extensions([ext1]) - x509.sign(pkey, 'sha1') + x509.sign(pkey, "sha1") # This is a little lame. Can we think of a better way? text = dump_certificate(FILETYPE_TEXT, x509) - assert b'X509v3 Basic Constraints:' in text - assert b'CA:TRUE' in text + assert b"X509v3 Basic Constraints:" in text + assert b"CA:TRUE" in text def test_subject(self, x509_data): """ @@ -700,11 +723,12 @@ class TestX509Ext(object): """ pkey, x509 = x509_data ext3 = X509Extension( - b'subjectKeyIdentifier', False, b'hash', subject=x509) + b"subjectKeyIdentifier", False, b"hash", subject=x509 + ) x509.add_extensions([ext3]) - x509.sign(pkey, 'sha1') + x509.sign(pkey, "sha1") text = dump_certificate(FILETYPE_TEXT, x509) - assert b'X509v3 Subject Key Identifier:' in text + assert b"X509v3 Subject Key Identifier:" in text def test_missing_subject(self): """ @@ -712,14 +736,9 @@ class TestX509Ext(object): is given no value, something happens. """ with pytest.raises(Error): - X509Extension(b'subjectKeyIdentifier', False, b'hash') - - @pytest.mark.parametrize('bad_obj', [ - True, - object(), - "hello", - [], - ]) + X509Extension(b"subjectKeyIdentifier", False, b"hash") + + @pytest.mark.parametrize("bad_obj", [True, object(), "hello", []]) def test_invalid_subject(self, bad_obj): """ If the `subject` parameter is given a value which is not an @@ -727,7 +746,8 @@ class TestX509Ext(object): """ with pytest.raises(TypeError): X509Extension( - 'basicConstraints', False, 'CA:TRUE', subject=bad_obj) + "basicConstraints", False, "CA:TRUE", subject=bad_obj + ) def test_unused_issuer(self, x509_data): """ @@ -736,12 +756,13 @@ class TestX509Ext(object): """ pkey, x509 = x509_data ext1 = X509Extension( - b'basicConstraints', False, b'CA:TRUE', issuer=x509) + b"basicConstraints", False, b"CA:TRUE", issuer=x509 + ) x509.add_extensions([ext1]) - x509.sign(pkey, 'sha1') + x509.sign(pkey, "sha1") text = dump_certificate(FILETYPE_TEXT, x509) - assert b'X509v3 Basic Constraints:' in text - assert b'CA:TRUE' in text + assert b"X509v3 Basic Constraints:" in text + assert b"CA:TRUE" in text def test_issuer(self, x509_data): """ @@ -750,13 +771,13 @@ class TestX509Ext(object): """ pkey, x509 = x509_data ext2 = X509Extension( - b'authorityKeyIdentifier', False, b'issuer:always', - issuer=x509) + b"authorityKeyIdentifier", False, b"issuer:always", issuer=x509 + ) x509.add_extensions([ext2]) - x509.sign(pkey, 'sha1') + x509.sign(pkey, "sha1") text = dump_certificate(FILETYPE_TEXT, x509) - assert b'X509v3 Authority Key Identifier:' in text - assert b'DirName:/CN=Yoda root CA' in text + assert b"X509v3 Authority Key Identifier:" in text + assert b"DirName:/CN=Yoda root CA" in text def test_missing_issuer(self): """ @@ -765,15 +786,10 @@ class TestX509Ext(object): """ with pytest.raises(Error): X509Extension( - b'authorityKeyIdentifier', - False, b'keyid:always,issuer:always') - - @pytest.mark.parametrize('bad_obj', [ - True, - object(), - "hello", - [], - ]) + b"authorityKeyIdentifier", False, b"keyid:always,issuer:always" + ) + + @pytest.mark.parametrize("bad_obj", [True, object(), "hello", []]) def test_invalid_issuer(self, bad_obj): """ If the `issuer` parameter is given a value which is not an @@ -781,8 +797,11 @@ class TestX509Ext(object): """ with pytest.raises(TypeError): X509Extension( - 'basicConstraints', False, 'keyid:always,issuer:always', - issuer=bad_obj) + "basicConstraints", + False, + "keyid:always,issuer:always", + issuer=bad_obj, + ) class TestPKey(object): @@ -850,7 +869,7 @@ class TestPKey(object): """ `PKey` can be used to create instances of that type. """ - assert is_consistent_type(PKey, 'PKey') + assert is_consistent_type(PKey, "PKey") def test_construction(self): """ @@ -992,6 +1011,7 @@ def x509_name(**attrs): # Make the order stable - order matters! def key(attr): return attr[1] + attrs.sort(key=key) for k, v in attrs: setattr(name, k, v) @@ -1099,6 +1119,7 @@ class TestX509Name(object): """ `X509Name` instances should compare based on their NIDs. """ + def _equality(a, b, assert_true, assert_false): assert_true(a == b) assert_false(a != b) @@ -1122,30 +1143,28 @@ class TestX509Name(object): assert_equal(x509_name(), x509_name()) # Instances with equal NIDs should compare equal to each other. - assert_equal(x509_name(commonName="foo"), - x509_name(commonName="foo")) + assert_equal(x509_name(commonName="foo"), x509_name(commonName="foo")) # Instance with equal NIDs set using different aliases should compare # equal to each other. - assert_equal(x509_name(commonName="foo"), - x509_name(CN="foo")) + assert_equal(x509_name(commonName="foo"), x509_name(CN="foo")) # Instances with more than one NID with the same values should compare # equal to each other. - assert_equal(x509_name(CN="foo", organizationalUnitName="bar"), - x509_name(commonName="foo", OU="bar")) + assert_equal( + x509_name(CN="foo", organizationalUnitName="bar"), + x509_name(commonName="foo", OU="bar"), + ) def assert_not_equal(a, b): _equality(a, b, assert_false, assert_true) # Instances with different values for the same NID should not compare # equal to each other. - assert_not_equal(x509_name(CN="foo"), - x509_name(CN="bar")) + assert_not_equal(x509_name(CN="foo"), x509_name(CN="bar")) # Instances with different NIDs should not compare equal to each other. - assert_not_equal(x509_name(CN="foo"), - x509_name(OU="foo")) + assert_not_equal(x509_name(CN="foo"), x509_name(OU="foo")) assert_not_equal(x509_name(), object()) @@ -1165,8 +1184,7 @@ class TestX509Name(object): # An X509Name with a NID with a value which sorts less than the value # of the same NID on another X509Name compares less than the other # X509Name. - assert_less_than(x509_name(CN="abc"), - x509_name(CN="def")) + assert_less_than(x509_name(CN="abc"), x509_name(CN="def")) def assert_greater_than(a, b): _inequality(a, b, assert_false, assert_true) @@ -1174,8 +1192,7 @@ class TestX509Name(object): # An X509Name with a NID with a value which sorts greater than the # value of the same NID on another X509Name compares greater than the # other X509Name. - assert_greater_than(x509_name(CN="def"), - x509_name(CN="abc")) + assert_greater_than(x509_name(CN="def"), x509_name(CN="abc")) def test_hash(self): """ @@ -1192,9 +1209,10 @@ class TestX509Name(object): `X509Name.der` returns the DER encoded form of the name. """ a = x509_name(CN="foo", C="US") - assert (a.der() == - b'0\x1b1\x0b0\t\x06\x03U\x04\x06\x13\x02US' - b'1\x0c0\n\x06\x03U\x04\x03\x0c\x03foo') + assert ( + a.der() == b"0\x1b1\x0b0\t\x06\x03U\x04\x06\x13\x02US" + b"1\x0c0\n\x06\x03U\x04\x03\x0c\x03foo" + ) def test_get_components(self): """ @@ -1225,8 +1243,8 @@ class TestX509Name(object): cert = load_certificate(FILETYPE_PEM, nulbyteSubjectAltNamePEM) subject = cert.get_subject() components = subject.get_components() - ccn = [value for name, value in components if name == b'CN'] - assert ccn[0] == b'null.python.org\x00example.org' + ccn = [value for name, value in components if name == b"CN"] + assert ccn[0] == b"null.python.org\x00example.org" def test_set_attribute_failure(self): """ @@ -1295,7 +1313,7 @@ class _PKeyInteractionTestsMixin: request.set_pubkey(key) request.sign(key, GOOD_DIGEST) # If the type has a verify method, cover that too. - if getattr(request, 'verify', None) is not None: + if getattr(request, "verify", None) is not None: pub = request.get_pubkey() assert request.verify(pub) # Make another key that won't verify. @@ -1320,7 +1338,7 @@ class TestX509Req(_PKeyInteractionTestsMixin): """ `X509Req` can be used to create instances of that type. """ - assert is_consistent_type(X509Req, 'X509Req') + assert is_consistent_type(X509Req, "X509Req") def test_construction(self): """ @@ -1372,13 +1390,14 @@ class TestX509Req(_PKeyInteractionTestsMixin): and adds them to the X509 request. """ request = X509Req() - request.add_extensions([ - X509Extension(b'basicConstraints', True, b'CA:false')]) + request.add_extensions( + [X509Extension(b"basicConstraints", True, b"CA:false")] + ) exts = request.get_extensions() assert len(exts) == 1 - assert exts[0].get_short_name() == b'basicConstraints' + assert exts[0].get_short_name() == b"basicConstraints" assert exts[0].get_critical() == 1 - assert exts[0].get_data() == b'0\x00' + assert exts[0].get_data() == b"0\x00" def test_get_extensions(self): """ @@ -1388,17 +1407,20 @@ class TestX509Req(_PKeyInteractionTestsMixin): request = X509Req() exts = request.get_extensions() assert exts == [] - request.add_extensions([ - X509Extension(b'basicConstraints', True, b'CA:true'), - X509Extension(b'keyUsage', False, b'digitalSignature')]) + request.add_extensions( + [ + X509Extension(b"basicConstraints", True, b"CA:true"), + X509Extension(b"keyUsage", False, b"digitalSignature"), + ] + ) exts = request.get_extensions() assert len(exts) == 2 - assert exts[0].get_short_name() == b'basicConstraints' + assert exts[0].get_short_name() == b"basicConstraints" assert exts[0].get_critical() == 1 - assert exts[0].get_data() == b'0\x03\x01\x01\xff' - assert exts[1].get_short_name() == b'keyUsage' + assert exts[0].get_data() == b"0\x03\x01\x01\xff" + assert exts[1].get_short_name() == b"keyUsage" assert exts[1].get_critical() == 0 - assert exts[1].get_data() == b'\x03\x02\x07\x80' + assert exts[1].get_data() == b"\x03\x02\x07\x80" def test_add_extensions_wrong_args(self): """ @@ -1477,6 +1499,7 @@ class TestX509(_PKeyInteractionTestsMixin): """ Tests for `OpenSSL.crypto.X509`. """ + pemData = cleartextCertificatePEM + cleartextPrivateKeyPEM extpem = """ @@ -1510,7 +1533,7 @@ WpOdIpB8KksUTCzV591Nr1wd """ `X509` can be used to create instances of that type. """ - assert is_consistent_type(X509, 'X509') + assert is_consistent_type(X509, "X509") def test_construction(self): """ @@ -1518,7 +1541,7 @@ WpOdIpB8KksUTCzV591Nr1wd """ certificate = X509() assert isinstance(certificate, X509) - assert type(certificate).__name__ == 'X509' + assert type(certificate).__name__ == "X509" assert type(certificate) == X509 def test_set_version_wrong_args(self): @@ -1565,8 +1588,8 @@ WpOdIpB8KksUTCzV591Nr1wd validity period to it. """ certificate = X509() - set = getattr(certificate, 'set_not' + which) - get = getattr(certificate, 'get_not' + which) + set = getattr(certificate, "set_not" + which) + get = getattr(certificate, "get_not" + which) # Starts with no value. assert get() is None @@ -1650,8 +1673,8 @@ WpOdIpB8KksUTCzV591Nr1wd current time plus the number of seconds passed in. """ cert = load_certificate(FILETYPE_PEM, self.pemData) - not_before_min = ( - datetime.utcnow().replace(microsecond=0) + timedelta(seconds=100) + not_before_min = datetime.utcnow().replace(microsecond=0) + timedelta( + seconds=100 ) cert.gmtime_adj_notBefore(100) not_before = datetime.strptime( @@ -1676,8 +1699,8 @@ WpOdIpB8KksUTCzV591Nr1wd to be the current time plus the number of seconds passed in. """ cert = load_certificate(FILETYPE_PEM, self.pemData) - not_after_min = ( - datetime.utcnow().replace(microsecond=0) + timedelta(seconds=100) + not_after_min = datetime.utcnow().replace(microsecond=0) + timedelta( + seconds=100 ) cert.gmtime_adj_notAfter(100) not_after = datetime.strptime( @@ -1724,8 +1747,9 @@ WpOdIpB8KksUTCzV591Nr1wd # digest will not product the same digest). # Digest verified with the command: # openssl x509 -in root_cert.pem -noout -fingerprint -md5 - cert.digest("MD5") == - b"19:B3:05:26:2B:F8:F2:FF:0B:8F:21:07:A8:28:B8:75") + cert.digest("MD5") + == b"19:B3:05:26:2B:F8:F2:FF:0B:8F:21:07:A8:28:B8:75" + ) def _extcert(self, pkey, extensions): cert = X509() @@ -1740,9 +1764,10 @@ WpOdIpB8KksUTCzV591Nr1wd cert.set_notAfter(when) cert.add_extensions(extensions) - cert.sign(pkey, 'sha1') + cert.sign(pkey, "sha1") return load_certificate( - FILETYPE_PEM, dump_certificate(FILETYPE_PEM, cert)) + FILETYPE_PEM, dump_certificate(FILETYPE_PEM, cert) + ) def test_extension_count(self): """ @@ -1750,10 +1775,11 @@ WpOdIpB8KksUTCzV591Nr1wd that are present in the certificate. """ pkey = load_privatekey(FILETYPE_PEM, client_key_pem) - ca = X509Extension(b'basicConstraints', True, b'CA:FALSE') - key = X509Extension(b'keyUsage', True, b'digitalSignature') + ca = X509Extension(b"basicConstraints", True, b"CA:FALSE") + key = X509Extension(b"keyUsage", True, b"digitalSignature") subjectAltName = X509Extension( - b'subjectAltName', True, b'DNS:example.com') + b"subjectAltName", True, b"DNS:example.com" + ) # Try a certificate with no extensions at all. c = self._extcert(pkey, []) @@ -1773,27 +1799,28 @@ WpOdIpB8KksUTCzV591Nr1wd `X509Extension` corresponding to the extension at that index. """ pkey = load_privatekey(FILETYPE_PEM, client_key_pem) - ca = X509Extension(b'basicConstraints', True, b'CA:FALSE') - key = X509Extension(b'keyUsage', True, b'digitalSignature') + ca = X509Extension(b"basicConstraints", True, b"CA:FALSE") + key = X509Extension(b"keyUsage", True, b"digitalSignature") subjectAltName = X509Extension( - b'subjectAltName', False, b'DNS:example.com') + b"subjectAltName", False, b"DNS:example.com" + ) cert = self._extcert(pkey, [ca, key, subjectAltName]) ext = cert.get_extension(0) assert isinstance(ext, X509Extension) assert ext.get_critical() - assert ext.get_short_name() == b'basicConstraints' + assert ext.get_short_name() == b"basicConstraints" ext = cert.get_extension(1) assert isinstance(ext, X509Extension) assert ext.get_critical() - assert ext.get_short_name() == b'keyUsage' + assert ext.get_short_name() == b"keyUsage" ext = cert.get_extension(2) assert isinstance(ext, X509Extension) assert not ext.get_critical() - assert ext.get_short_name() == b'subjectAltName' + assert ext.get_short_name() == b"subjectAltName" with pytest.raises(IndexError): cert.get_extension(-1) @@ -1811,13 +1838,14 @@ WpOdIpB8KksUTCzV591Nr1wd cert = load_certificate(FILETYPE_PEM, nulbyteSubjectAltNamePEM) ext = cert.get_extension(3) - assert ext.get_short_name() == b'subjectAltName' + assert ext.get_short_name() == b"subjectAltName" assert ( b"DNS:altnull.python.org\x00example.com, " b"email:null@python.org\x00user@example.org, " b"URI:http://null.python.org\x00http://example.org, " - b"IP Address:192.0.2.1, IP Address:2001:DB8:0:0:0:0:0:1\n" == - str(ext).encode("ascii")) + b"IP Address:192.0.2.1, IP Address:2001:DB8:0:0:0:0:0:1\n" + == str(ext).encode("ascii") + ) def test_invalid_digest_algorithm(self): """ @@ -1835,10 +1863,13 @@ WpOdIpB8KksUTCzV591Nr1wd cert = load_certificate(FILETYPE_PEM, self.pemData) subj = cert.get_subject() assert isinstance(subj, X509Name) - assert ( - subj.get_components() == - [(b'C', b'US'), (b'ST', b'IL'), (b'L', b'Chicago'), - (b'O', b'Testing'), (b'CN', b'Testing Root CA')]) + assert subj.get_components() == [ + (b"C", b"US"), + (b"ST", b"IL"), + (b"L", b"Chicago"), + (b"O", b"Testing"), + (b"CN", b"Testing Root CA"), + ] def test_set_subject_wrong_args(self): """ @@ -1856,12 +1887,13 @@ WpOdIpB8KksUTCzV591Nr1wd """ cert = X509() name = cert.get_subject() - name.C = 'AU' - name.OU = 'Unit Tests' + name.C = "AU" + name.OU = "Unit Tests" cert.set_subject(name) - assert ( - cert.get_subject().get_components() == - [(b'C', b'AU'), (b'OU', b'Unit Tests')]) + assert cert.get_subject().get_components() == [ + (b"C", b"AU"), + (b"OU", b"Unit Tests"), + ] def test_get_issuer(self): """ @@ -1871,10 +1903,13 @@ WpOdIpB8KksUTCzV591Nr1wd subj = cert.get_issuer() assert isinstance(subj, X509Name) comp = subj.get_components() - assert ( - comp == - [(b'C', b'US'), (b'ST', b'IL'), (b'L', b'Chicago'), - (b'O', b'Testing'), (b'CN', b'Testing Root CA')]) + assert comp == [ + (b"C", b"US"), + (b"ST", b"IL"), + (b"L", b"Chicago"), + (b"O", b"Testing"), + (b"CN", b"Testing Root CA"), + ] def test_set_issuer_wrong_args(self): """ @@ -1892,12 +1927,13 @@ WpOdIpB8KksUTCzV591Nr1wd """ cert = X509() name = cert.get_issuer() - name.C = 'AU' - name.OU = 'Unit Tests' + name.C = "AU" + name.OU = "Unit Tests" cert.set_issuer(name) - assert ( - cert.get_issuer().get_components() == - [(b'C', b'AU'), (b'OU', b'Unit Tests')]) + assert cert.get_issuer().get_components() == [ + (b"C", b"AU"), + (b"OU", b"Unit Tests"), + ] def test_get_pubkey_uninitialized(self): """ @@ -2004,7 +2040,7 @@ class TestX509Store(object): """ `X509Store` is a type object. """ - assert is_consistent_type(X509Store, 'X509Store') + assert is_consistent_type(X509Store, "X509Store") def test_add_cert(self): """ @@ -2014,7 +2050,7 @@ class TestX509Store(object): store = X509Store() store.add_cert(cert) - @pytest.mark.parametrize('cert', [None, 1.0, 'cert', object()]) + @pytest.mark.parametrize("cert", [None, 1.0, "cert", object()]) def test_add_cert_wrong_args(self, cert): """ `X509Store.add_cert` raises `TypeError` if passed a non-X509 object @@ -2039,13 +2075,14 @@ class TestPKCS12(object): """ Test for `OpenSSL.crypto.PKCS12` and `OpenSSL.crypto.load_pkcs12`. """ + pemData = cleartextCertificatePEM + cleartextPrivateKeyPEM def test_type(self): """ `PKCS12` is a type object. """ - assert is_consistent_type(PKCS12, 'PKCS12') + assert is_consistent_type(PKCS12, "PKCS12") def test_empty_construction(self): """ @@ -2068,13 +2105,13 @@ class TestPKCS12(object): for bad_arg in [3, PKey(), X509]: with pytest.raises(TypeError): p12.set_certificate(bad_arg) - for bad_arg in [3, 'legbone', X509()]: + for bad_arg in [3, "legbone", X509()]: with pytest.raises(TypeError): p12.set_privatekey(bad_arg) for bad_arg in [3, X509(), (3, 4), (PKey(),)]: with pytest.raises(TypeError): p12.set_ca_certificates(bad_arg) - for bad_arg in [6, ('foo', 'bar')]: + for bad_arg in [6, ("foo", "bar")]: with pytest.raises(TypeError): p12.set_friendlyname(bad_arg) @@ -2134,12 +2171,13 @@ class TestPKCS12(object): # it to. At some point, hopefully this will change so that # p12.get_certificate() is actually what returns the loaded # certificate. - assert ( - cleartextCertificatePEM == - dump_certificate(FILETYPE_PEM, p12.get_ca_certificates()[0])) + assert cleartextCertificatePEM == dump_certificate( + FILETYPE_PEM, p12.get_ca_certificates()[0] + ) - def gen_pkcs12(self, cert_pem=None, key_pem=None, ca_pem=None, - friendly_name=None): + def gen_pkcs12( + self, cert_pem=None, key_pem=None, ca_pem=None, friendly_name=None + ): """ Generate a PKCS12 object with components from PEM. Verify that the set functions return None. @@ -2161,27 +2199,48 @@ class TestPKCS12(object): assert ret is None return p12 - def check_recovery(self, p12_str, key=None, cert=None, ca=None, passwd=b"", - extra=()): + def check_recovery( + self, p12_str, key=None, cert=None, ca=None, passwd=b"", extra=() + ): """ Use openssl program to confirm three components are recoverable from a PKCS12 string. """ if key: recovered_key = _runopenssl( - p12_str, b"pkcs12", b"-nocerts", b"-nodes", b"-passin", - b"pass:" + passwd, *extra) - assert recovered_key[-len(key):] == key + p12_str, + b"pkcs12", + b"-nocerts", + b"-nodes", + b"-passin", + b"pass:" + passwd, + *extra + ) + assert recovered_key[-len(key) :] == key if cert: recovered_cert = _runopenssl( - p12_str, b"pkcs12", b"-clcerts", b"-nodes", b"-passin", - b"pass:" + passwd, b"-nokeys", *extra) - assert recovered_cert[-len(cert):] == cert + p12_str, + b"pkcs12", + b"-clcerts", + b"-nodes", + b"-passin", + b"pass:" + passwd, + b"-nokeys", + *extra + ) + assert recovered_cert[-len(cert) :] == cert if ca: recovered_cert = _runopenssl( - p12_str, b"pkcs12", b"-cacerts", b"-nodes", b"-passin", - b"pass:" + passwd, b"-nokeys", *extra) - assert recovered_cert[-len(ca):] == ca + p12_str, + b"pkcs12", + b"-cacerts", + b"-nodes", + b"-passin", + b"pass:" + passwd, + b"-nokeys", + *extra + ) + assert recovered_cert[-len(ca) :] == ca def verify_pkcs12_container(self, p12): """ @@ -2193,9 +2252,11 @@ class TestPKCS12(object): """ cert_pem = dump_certificate(FILETYPE_PEM, p12.get_certificate()) key_pem = dump_privatekey(FILETYPE_PEM, p12.get_privatekey()) - assert ( - (client_cert_pem, client_key_pem, None) == - (cert_pem, key_pem, p12.get_ca_certificates())) + assert (client_cert_pem, client_key_pem, None) == ( + cert_pem, + key_pem, + p12.get_ca_certificates(), + ) def test_load_pkcs12(self): """ @@ -2210,7 +2271,7 @@ class TestPKCS12(object): b"-export", b"-clcerts", b"-passout", - b"pass:" + passwd + b"pass:" + passwd, ) p12 = load_pkcs12(p12_str, passphrase=passwd) self.verify_pkcs12_container(p12) @@ -2223,15 +2284,21 @@ class TestPKCS12(object): """ pem = client_key_pem + client_cert_pem passwd = b"whatever" - p12_str = _runopenssl(pem, b"pkcs12", b"-export", b"-clcerts", - b"-passout", b"pass:" + passwd) + p12_str = _runopenssl( + pem, + b"pkcs12", + b"-export", + b"-clcerts", + b"-passout", + b"pass:" + passwd, + ) with pytest.warns(DeprecationWarning) as w: simplefilter("always") p12 = load_pkcs12(p12_str, passphrase=b"whatever".decode("ascii")) - assert ( - "{0} for passphrase is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + msg = "{0} for passphrase is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) + assert msg == str(w[-1].message) self.verify_pkcs12_container(p12) @@ -2243,7 +2310,8 @@ class TestPKCS12(object): """ pem = client_key_pem + client_cert_pem p12_str = _runopenssl( - pem, b"pkcs12", b"-export", b"-clcerts", b"-passout", b"pass:") + pem, b"pkcs12", b"-export", b"-clcerts", b"-passout", b"pass:" + ) p12 = load_pkcs12(p12_str) self.verify_pkcs12_container(p12) @@ -2262,7 +2330,8 @@ class TestPKCS12(object): extracted and examined. """ self.verify_pkcs12_container( - self._dump_and_load(dump_passphrase=None, load_passphrase=b'')) + self._dump_and_load(dump_passphrase=None, load_passphrase=b"") + ) def test_load_pkcs12_null_passphrase_load_null(self): """ @@ -2271,7 +2340,8 @@ class TestPKCS12(object): extracted and examined. """ self.verify_pkcs12_container( - self._dump_and_load(dump_passphrase=None, load_passphrase=None)) + self._dump_and_load(dump_passphrase=None, load_passphrase=None) + ) def test_load_pkcs12_empty_passphrase_load_empty(self): """ @@ -2280,7 +2350,8 @@ class TestPKCS12(object): extracted and examined. """ self.verify_pkcs12_container( - self._dump_and_load(dump_passphrase=b'', load_passphrase=b'')) + self._dump_and_load(dump_passphrase=b"", load_passphrase=b"") + ) def test_load_pkcs12_empty_passphrase_load_null(self): """ @@ -2289,17 +2360,18 @@ class TestPKCS12(object): extracted and examined. """ self.verify_pkcs12_container( - self._dump_and_load(dump_passphrase=b'', load_passphrase=None)) + self._dump_and_load(dump_passphrase=b"", load_passphrase=None) + ) def test_load_pkcs12_garbage(self): """ `load_pkcs12` raises `OpenSSL.crypto.Error` when passed a string which is not a PKCS12 dump. """ - passwd = 'whatever' + passwd = "whatever" with pytest.raises(Error) as err: - load_pkcs12(b'fruit loops', passwd) - assert err.value.args[0][0][0] == 'asn1 encoding routines' + load_pkcs12(b"fruit loops", passwd) + assert err.value.args[0][0][0] == "asn1 encoding routines" assert len(err.value.args[0][0]) == 3 def test_replace(self): @@ -2329,7 +2401,7 @@ class TestPKCS12(object): """ passwd = b'Dogmeat[]{}!@#$%^&*()~`?/.,<>-_+=";:' p12 = self.gen_pkcs12(server_cert_pem, server_key_pem, root_cert_pem) - for friendly_name in [b'Serverlicious', None, b'###']: + for friendly_name in [b"Serverlicious", None, b"###"]: p12.set_friendlyname(friendly_name) assert p12.get_friendlyname() == friendly_name dumped_p12 = p12.export(passphrase=passwd, iter=2, maciter=3) @@ -2340,8 +2412,12 @@ class TestPKCS12(object): # does not store the friendly name in the cert's # alias, which we could then extract. self.check_recovery( - dumped_p12, key=server_key_pem, cert=server_cert_pem, - ca=root_cert_pem, passwd=passwd) + dumped_p12, + key=server_key_pem, + cert=server_cert_pem, + ca=root_cert_pem, + passwd=passwd, + ) def test_various_empty_passphrases(self): """ @@ -2355,8 +2431,12 @@ class TestPKCS12(object): dumped_p12_nopw = p12.export(iter=9, maciter=4) for dumped_p12 in [dumped_p12_empty, dumped_p12_none, dumped_p12_nopw]: self.check_recovery( - dumped_p12, key=client_key_pem, cert=client_cert_pem, - ca=root_cert_pem, passwd=passwd) + dumped_p12, + key=client_key_pem, + cert=client_cert_pem, + ca=root_cert_pem, + passwd=passwd, + ) def test_removing_ca_cert(self): """ @@ -2375,8 +2455,12 @@ class TestPKCS12(object): p12 = self.gen_pkcs12(server_cert_pem, server_key_pem, root_cert_pem) dumped_p12 = p12.export(maciter=-1, passphrase=passwd, iter=2) self.check_recovery( - dumped_p12, key=server_key_pem, cert=server_cert_pem, - passwd=passwd, extra=(b"-nomacver",)) + dumped_p12, + key=server_key_pem, + cert=server_cert_pem, + passwd=passwd, + extra=(b"-nomacver",), + ) def test_load_without_mac(self): """ @@ -2402,14 +2486,14 @@ class TestPKCS12(object): """ A PKCS12 with an empty CA certificates list can be exported. """ - passwd = b'Hobie 18' + passwd = b"Hobie 18" p12 = self.gen_pkcs12(server_cert_pem, server_key_pem) p12.set_ca_certificates([]) assert () == p12.get_ca_certificates() dumped_p12 = p12.export(passphrase=passwd, iter=3) self.check_recovery( - dumped_p12, key=server_key_pem, cert=server_cert_pem, - passwd=passwd) + dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=passwd + ) def test_export_without_args(self): """ @@ -2418,7 +2502,8 @@ class TestPKCS12(object): p12 = self.gen_pkcs12(server_cert_pem, server_key_pem, root_cert_pem) dumped_p12 = p12.export() # no args self.check_recovery( - dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=b"") + dumped_p12, key=server_key_pem, cert=server_cert_pem, passwd=b"" + ) def test_export_without_bytes(self): """ @@ -2429,15 +2514,15 @@ class TestPKCS12(object): with pytest.warns(DeprecationWarning) as w: simplefilter("always") dumped_p12 = p12.export(passphrase=b"randomtext".decode("ascii")) - assert ( - "{0} for passphrase is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + msg = "{0} for passphrase is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) + assert msg == str(w[-1].message) self.check_recovery( dumped_p12, key=server_key_pem, cert=server_cert_pem, - passwd=b"randomtext" + passwd=b"randomtext", ) def test_key_cert_mismatch(self): @@ -2468,6 +2553,7 @@ class TestLoadPublicKey(object): """ Tests for :func:`load_publickey`. """ + def test_loading_works(self): """ load_publickey loads public keys and sets correct attributes. @@ -2496,7 +2582,7 @@ class TestLoadPublicKey(object): """ load_publickey works with text strings, not just bytes. """ - serialized = cleartextPublicKeyPEM.decode('ascii') + serialized = cleartextPublicKeyPEM.decode("ascii") key = load_publickey(FILETYPE_PEM, serialized) dumped_pem = dump_publickey(FILETYPE_PEM, key) @@ -2522,7 +2608,8 @@ class TestFunction(object): """ with pytest.raises(TypeError): load_privatekey( - FILETYPE_PEM, encryptedPrivateKeyPEMPassphrase, object()) + FILETYPE_PEM, encryptedPrivateKeyPEMPassphrase, object() + ) def test_load_privatekey_wrongPassphrase(self): """ @@ -2550,8 +2637,10 @@ class TestFunction(object): string if given the passphrase. """ key = load_privatekey( - FILETYPE_PEM, encryptedPrivateKeyPEM, - encryptedPrivateKeyPEMPassphrase) + FILETYPE_PEM, + encryptedPrivateKeyPEM, + encryptedPrivateKeyPEMPassphrase, + ) assert isinstance(key, PKey) def test_load_privatekey_passphrase_exception(self): @@ -2559,6 +2648,7 @@ class TestFunction(object): If the passphrase callback raises an exception, that exception is raised by `load_privatekey`. """ + def cb(ignored): raise ArithmeticError @@ -2576,6 +2666,7 @@ class TestFunction(object): def cb(*a): called.append(None) return b"quack" + with pytest.raises(Error) as err: load_privatekey(FILETYPE_PEM, encryptedPrivateKeyPEM, cb) assert called @@ -2592,6 +2683,7 @@ class TestFunction(object): def cb(writing): called.append(writing) return encryptedPrivateKeyPEMPassphrase + key = load_privatekey(FILETYPE_PEM, encryptedPrivateKeyPEM, cb) assert isinstance(key, PKey) assert called == [False] @@ -2603,7 +2695,8 @@ class TestFunction(object): """ with pytest.raises(ValueError): load_privatekey( - FILETYPE_PEM, encryptedPrivateKeyPEM, lambda *args: 3) + FILETYPE_PEM, encryptedPrivateKeyPEM, lambda *args: 3 + ) def test_dump_privatekey_wrong_args(self): """ @@ -2664,6 +2757,7 @@ class TestFunction(object): `crypto.load_privatekey` should raise an error when the passphrase provided by the callback is too long, not silently truncate it. """ + def cb(ignored): return "a" * 1025 @@ -2709,7 +2803,8 @@ class TestFunction(object): assert dumped_pem2 == cleartextCertificatePEM dumped_text = dump_certificate(FILETYPE_TEXT, cert) good_text = _runopenssl( - dumped_pem, b"x509", b"-noout", b"-text", b"-nameopt", b"") + dumped_pem, b"x509", b"-noout", b"-text", b"-nameopt", b"" + ) assert dumped_text == good_text def test_dump_certificate_bad_type(self): @@ -2788,7 +2883,8 @@ class TestFunction(object): `dump_certificate_request` writes a PEM, DER, and text. """ req = load_certificate_request( - FILETYPE_PEM, cleartextCertificateRequestPEM) + FILETYPE_PEM, cleartextCertificateRequestPEM + ) dumped_pem = dump_certificate_request(FILETYPE_PEM, req) assert dumped_pem == cleartextCertificateRequestPEM dumped_der = dump_certificate_request(FILETYPE_ASN1, req) @@ -2799,7 +2895,8 @@ class TestFunction(object): assert dumped_pem2 == cleartextCertificateRequestPEM dumped_text = dump_certificate_request(FILETYPE_TEXT, req) good_text = _runopenssl( - dumped_pem, b"req", b"-noout", b"-text", b"-nameopt", b"") + dumped_pem, b"req", b"-noout", b"-text", b"-nameopt", b"" + ) assert dumped_text == good_text with pytest.raises(ValueError): dump_certificate_request(100, req) @@ -2815,6 +2912,7 @@ class TestFunction(object): def cb(writing): called.append(writing) return passphrase + key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) pem = dump_privatekey(FILETYPE_PEM, key, GOOD_CIPHER, cb) assert isinstance(pem, bytes) @@ -2829,6 +2927,7 @@ class TestFunction(object): `dump_privatekey` should not overwrite the exception raised by the passphrase callback. """ + def cb(ignored): raise ArithmeticError @@ -2841,6 +2940,7 @@ class TestFunction(object): `crypto.dump_privatekey` should raise an error when the passphrase provided by the callback is too long, not silently truncate it. """ + def cb(ignored): return "a" * 1025 @@ -2946,7 +3046,7 @@ class TestPKCS7(object): type name. """ pkcs7 = load_pkcs7_data(FILETYPE_PEM, pkcs7Data) - assert pkcs7.get_type_name() == b'pkcs7-signedData' + assert pkcs7.get_type_name() == b"pkcs7-signedData" def test_attribute(self): """ @@ -2973,7 +3073,7 @@ class TestNetscapeSPKI(_PKeyInteractionTestsMixin): """ `NetscapeSPKI` can be used to create instances of that type. """ - assert is_consistent_type(NetscapeSPKI, 'NetscapeSPKI') + assert is_consistent_type(NetscapeSPKI, "NetscapeSPKI") def test_construction(self): """ @@ -3004,6 +3104,7 @@ class TestRevoked(object): """ Tests for `OpenSSL.crypto.Revoked`. """ + def test_ignores_unsupported_revoked_cert_extension_get_reason(self): """ The get_reason method on the Revoked class checks to see if the @@ -3013,7 +3114,7 @@ class TestRevoked(object): crl = load_crl(FILETYPE_PEM, crlDataUnsupportedExtension) revoked = crl.get_revoked() reason = revoked[1].get_reason() - assert reason == b'Unspecified' + assert reason == b"Unspecified" def test_ignores_unsupported_revoked_cert_extension_set_new_reason(self): crl = load_crl(FILETYPE_PEM, crlDataUnsupportedExtension) @@ -3030,7 +3131,7 @@ class TestRevoked(object): revoked = Revoked() assert isinstance(revoked, Revoked) assert type(revoked) == Revoked - assert revoked.get_serial() == b'00' + assert revoked.get_serial() == b"00" assert revoked.get_rev_date() is None assert revoked.get_reason() is None @@ -3040,17 +3141,17 @@ class TestRevoked(object): `OpenSSL.crypto.Revoked`. Confirm errors are handled with grace. """ revoked = Revoked() - ret = revoked.set_serial(b'10b') + ret = revoked.set_serial(b"10b") assert ret is None ser = revoked.get_serial() - assert ser == b'010B' + assert ser == b"010B" - revoked.set_serial(b'31ppp') # a type error would be nice + revoked.set_serial(b"31ppp") # a type error would be nice ser = revoked.get_serial() - assert ser == b'31' + assert ser == b"31" with pytest.raises(ValueError): - revoked.set_serial(b'pqrst') + revoked.set_serial(b"pqrst") with pytest.raises(TypeError): revoked.set_serial(100) @@ -3081,15 +3182,15 @@ class TestRevoked(object): ret = revoked.set_reason(r) assert ret is None reason = revoked.get_reason() - assert ( - reason.lower().replace(b' ', b'') == - r.lower().replace(b' ', b'')) + assert reason.lower().replace(b" ", b"") == r.lower().replace( + b" ", b"" + ) r = reason # again with the resp of get revoked.set_reason(None) assert revoked.get_reason() is None - @pytest.mark.parametrize('reason', [object(), 1.0, u'foo']) + @pytest.mark.parametrize("reason", [object(), 1.0, u"foo"]) def test_set_reason_wrong_args(self, reason): """ `Revoked.set_reason` raises `TypeError` if called with an argument @@ -3106,13 +3207,14 @@ class TestRevoked(object): """ revoked = Revoked() with pytest.raises(ValueError): - revoked.set_reason(b'blue') + revoked.set_reason(b"blue") class TestCRL(object): """ Tests for `OpenSSL.crypto.CRL`. """ + cert = load_certificate(FILETYPE_PEM, cleartextCertificatePEM) pkey = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) @@ -3121,9 +3223,11 @@ class TestCRL(object): intermediate_cert = load_certificate(FILETYPE_PEM, intermediate_cert_pem) intermediate_key = load_privatekey(FILETYPE_PEM, intermediate_key_pem) intermediate_server_cert = load_certificate( - FILETYPE_PEM, intermediate_server_cert_pem) + FILETYPE_PEM, intermediate_server_cert_pem + ) intermediate_server_key = load_privatekey( - FILETYPE_PEM, intermediate_server_key_pem) + FILETYPE_PEM, intermediate_server_key_pem + ) def test_construction(self): """ @@ -3142,8 +3246,8 @@ class TestCRL(object): revoked = Revoked() now = datetime.now().strftime("%Y%m%d%H%M%SZ").encode("ascii") revoked.set_rev_date(now) - revoked.set_serial(b'3ab') - revoked.set_reason(b'sUpErSeDEd') + revoked.set_serial(b"3ab") + revoked.set_reason(b"sUpErSeDEd") crl.add_revoked(revoked) return crl @@ -3160,13 +3264,17 @@ class TestCRL(object): crl = x509.load_pem_x509_crl(dumped_crl, backend) revoked = crl.get_revoked_certificate_by_serial_number(0x03AB) assert revoked is not None - assert crl.issuer == x509.Name([ - x509.NameAttribute(x509.NameOID.COUNTRY_NAME, u"US"), - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, u"IL"), - x509.NameAttribute(x509.NameOID.LOCALITY_NAME, u"Chicago"), - x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, u"Testing"), - x509.NameAttribute(x509.NameOID.COMMON_NAME, u"Testing Root CA"), - ]) + assert crl.issuer == x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, u"US"), + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, u"IL"), + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, u"Chicago"), + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, u"Testing"), + x509.NameAttribute( + x509.NameOID.COMMON_NAME, u"Testing Root CA" + ), + ] + ) def test_export_der(self): """ @@ -3183,13 +3291,17 @@ class TestCRL(object): crl = x509.load_der_x509_crl(dumped_crl, backend) revoked = crl.get_revoked_certificate_by_serial_number(0x03AB) assert revoked is not None - assert crl.issuer == x509.Name([ - x509.NameAttribute(x509.NameOID.COUNTRY_NAME, u"US"), - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, u"IL"), - x509.NameAttribute(x509.NameOID.LOCALITY_NAME, u"Chicago"), - x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, u"Testing"), - x509.NameAttribute(x509.NameOID.COMMON_NAME, u"Testing Root CA"), - ]) + assert crl.issuer == x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, u"US"), + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, u"IL"), + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, u"Chicago"), + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, u"Testing"), + x509.NameAttribute( + x509.NameOID.COMMON_NAME, u"Testing Root CA" + ), + ] + ) # Flaky because we compare the output of running commands which sometimes # varies by 1 second @@ -3206,8 +3318,14 @@ class TestCRL(object): self.cert, self.pkey, FILETYPE_ASN1, digest=b"md5" ) text = _runopenssl( - dumped_crl, b"crl", b"-noout", b"-text", b"-inform", b"DER", - b"-nameopt", b"" + dumped_crl, + b"crl", + b"-noout", + b"-text", + b"-inform", + b"DER", + b"-nameopt", + b"", ) # text format @@ -3224,7 +3342,7 @@ class TestCRL(object): crl = self._get_crl() dumped_crl = crl.export(self.cert, self.pkey, digest=b"sha1") text = _runopenssl(dumped_crl, b"crl", b"-noout", b"-text") - text.index(b'Signature Algorithm: sha1') + text.index(b"Signature Algorithm: sha1") def test_export_md5_digest(self): """ @@ -3237,7 +3355,7 @@ class TestCRL(object): assert 0 == len(catcher) dumped_crl = crl.export(self.cert, self.pkey, digest=b"md5") text = _runopenssl(dumped_crl, b"crl", b"-noout", b"-text") - text.index(b'Signature Algorithm: md5') + text.index(b"Signature Algorithm: md5") def test_export_default_digest(self): """ @@ -3303,7 +3421,8 @@ class TestCRL(object): crl = CRL() with pytest.raises(ValueError): crl.export( - self.cert, self.pkey, FILETYPE_PEM, 10, b"strange-digest") + self.cert, self.pkey, FILETYPE_PEM, 10, b"strange-digest" + ) def test_get_revoked(self): """ @@ -3315,18 +3434,18 @@ class TestCRL(object): revoked = Revoked() now = datetime.now().strftime("%Y%m%d%H%M%SZ").encode("ascii") revoked.set_rev_date(now) - revoked.set_serial(b'3ab') + revoked.set_serial(b"3ab") crl.add_revoked(revoked) - revoked.set_serial(b'100') - revoked.set_reason(b'sUpErSeDEd') + revoked.set_serial(b"100") + revoked.set_reason(b"sUpErSeDEd") crl.add_revoked(revoked) revs = crl.get_revoked() assert len(revs) == 2 assert type(revs[0]) == Revoked assert type(revs[1]) == Revoked - assert revs[0].get_serial() == b'03AB' - assert revs[1].get_serial() == b'0100' + assert revs[0].get_serial() == b"03AB" + assert revs[1].get_serial() == b"0100" assert revs[0].get_rev_date() == now assert revs[1].get_rev_date() == now @@ -3338,19 +3457,19 @@ class TestCRL(object): crl = load_crl(FILETYPE_PEM, crlData) revs = crl.get_revoked() assert len(revs) == 2 - assert revs[0].get_serial() == b'03AB' + assert revs[0].get_serial() == b"03AB" assert revs[0].get_reason() is None - assert revs[1].get_serial() == b'0100' - assert revs[1].get_reason() == b'Superseded' + assert revs[1].get_serial() == b"0100" + assert revs[1].get_reason() == b"Superseded" der = _runopenssl(crlData, b"crl", b"-outform", b"DER") crl = load_crl(FILETYPE_ASN1, der) revs = crl.get_revoked() assert len(revs) == 2 - assert revs[0].get_serial() == b'03AB' + assert revs[0].get_serial() == b"03AB" assert revs[0].get_reason() is None - assert revs[1].get_serial() == b'0100' - assert revs[1].get_reason() == b'Superseded' + assert revs[1].get_serial() == b"0100" + assert revs[1].get_reason() == b"Superseded" def test_load_crl_bad_filetype(self): """ @@ -3375,7 +3494,7 @@ class TestCRL(object): """ crl = load_crl(FILETYPE_PEM, crlData) assert isinstance(crl.get_issuer(), X509Name) - assert crl.get_issuer().CN == 'Testing Root CA' + assert crl.get_issuer().CN == "Testing Root CA" def test_dump_crl(self): """ @@ -3398,15 +3517,15 @@ class TestCRL(object): # FIXME: This string splicing is an unfortunate implementation # detail that has been reported in # https://github.com/pyca/pyopenssl/issues/258 - serial = hex(cert.get_serial_number())[2:].encode('utf-8') + serial = hex(cert.get_serial_number())[2:].encode("utf-8") revoked.set_serial(serial) - revoked.set_reason(b'unspecified') - revoked.set_rev_date(b'20140601000000Z') + revoked.set_reason(b"unspecified") + revoked.set_rev_date(b"20140601000000Z") crl.add_revoked(revoked) crl.set_version(1) - crl.set_lastUpdate(b'20140601000000Z') - crl.set_nextUpdate(b'20180601000000Z') - crl.sign(issuer_cert, issuer_key, digest=b'sha512') + crl.set_lastUpdate(b"20140601000000Z") + crl.set_nextUpdate(b"20180601000000Z") + crl.sign(issuer_cert, issuer_key, digest=b"sha512") return crl def test_verify_with_revoked(self): @@ -3418,17 +3537,20 @@ class TestCRL(object): store.add_cert(self.root_cert) store.add_cert(self.intermediate_cert) root_crl = self._make_test_crl( - self.root_cert, self.root_key, certs=[self.intermediate_cert]) + self.root_cert, self.root_key, certs=[self.intermediate_cert] + ) intermediate_crl = self._make_test_crl( - self.intermediate_cert, self.intermediate_key, certs=[]) + self.intermediate_cert, self.intermediate_key, certs=[] + ) store.add_crl(root_crl) store.add_crl(intermediate_crl) store.set_flags( - X509StoreFlags.CRL_CHECK | X509StoreFlags.CRL_CHECK_ALL) + X509StoreFlags.CRL_CHECK | X509StoreFlags.CRL_CHECK_ALL + ) store_ctx = X509StoreContext(store, self.intermediate_server_cert) with pytest.raises(X509StoreContextError) as err: store_ctx.verify_certificate() - assert err.value.args[0][2] == 'certificate revoked' + assert err.value.args[0][2] == "certificate revoked" def test_verify_with_missing_crl(self): """ @@ -3439,15 +3561,17 @@ class TestCRL(object): store.add_cert(self.root_cert) store.add_cert(self.intermediate_cert) root_crl = self._make_test_crl( - self.root_cert, self.root_key, certs=[self.intermediate_cert]) + self.root_cert, self.root_key, certs=[self.intermediate_cert] + ) store.add_crl(root_crl) store.set_flags( - X509StoreFlags.CRL_CHECK | X509StoreFlags.CRL_CHECK_ALL) + X509StoreFlags.CRL_CHECK | X509StoreFlags.CRL_CHECK_ALL + ) store_ctx = X509StoreContext(store, self.intermediate_server_cert) with pytest.raises(X509StoreContextError) as err: store_ctx.verify_certificate() - assert err.value.args[0][2] == 'unable to get certificate CRL' - assert err.value.certificate.get_subject().CN == 'intermediate-service' + assert err.value.args[0][2] == "unable to get certificate CRL" + assert err.value.certificate.get_subject().CN == "intermediate-service" def test_convert_from_cryptography(self): crypto_crl = x509.load_pem_x509_crl(crlData, backend) @@ -3468,10 +3592,12 @@ class TestX509StoreContext(object): """ Tests for `OpenSSL.crypto.X509StoreContext`. """ + root_cert = load_certificate(FILETYPE_PEM, root_cert_pem) intermediate_cert = load_certificate(FILETYPE_PEM, intermediate_cert_pem) intermediate_server_cert = load_certificate( - FILETYPE_PEM, intermediate_server_cert_pem) + FILETYPE_PEM, intermediate_server_cert_pem + ) def test_valid(self): """ @@ -3516,8 +3642,8 @@ class TestX509StoreContext(object): with pytest.raises(X509StoreContextError) as exc: store_ctx.verify_certificate() - assert exc.value.args[0][2] == 'self signed certificate' - assert exc.value.certificate.get_subject().CN == 'Testing Root CA' + assert exc.value.args[0][2] == "self signed certificate" + assert exc.value.certificate.get_subject().CN == "Testing Root CA" def test_invalid_chain_no_root(self): """ @@ -3531,8 +3657,8 @@ class TestX509StoreContext(object): with pytest.raises(X509StoreContextError) as exc: store_ctx.verify_certificate() - assert exc.value.args[0][2] == 'unable to get issuer certificate' - assert exc.value.certificate.get_subject().CN == 'intermediate' + assert exc.value.args[0][2] == "unable to get issuer certificate" + assert exc.value.certificate.get_subject().CN == "intermediate" def test_invalid_chain_no_intermediate(self): """ @@ -3546,8 +3672,8 @@ class TestX509StoreContext(object): with pytest.raises(X509StoreContextError) as exc: store_ctx.verify_certificate() - assert exc.value.args[0][2] == 'unable to get local issuer certificate' - assert exc.value.certificate.get_subject().CN == 'intermediate-service' + assert exc.value.args[0][2] == "unable to get local issuer certificate" + assert exc.value.certificate.get_subject().CN == "intermediate-service" def test_modification_pre_verify(self): """ @@ -3564,8 +3690,8 @@ class TestX509StoreContext(object): with pytest.raises(X509StoreContextError) as exc: store_ctx.verify_certificate() - assert exc.value.args[0][2] == 'unable to get issuer certificate' - assert exc.value.certificate.get_subject().CN == 'intermediate' + assert exc.value.args[0][2] == "unable to get issuer certificate" + assert exc.value.certificate.get_subject().CN == "intermediate" store_ctx.set_store(store_good) assert store_ctx.verify_certificate() is None @@ -3581,7 +3707,7 @@ class TestX509StoreContext(object): expire_time = self.intermediate_server_cert.get_notAfter() expire_datetime = datetime.strptime( - expire_time.decode('utf-8'), '%Y%m%d%H%M%SZ' + expire_time.decode("utf-8"), "%Y%m%d%H%M%SZ" ) store.set_time(expire_datetime) @@ -3589,7 +3715,7 @@ class TestX509StoreContext(object): with pytest.raises(X509StoreContextError) as exc: store_ctx.verify_certificate() - assert exc.value.args[0][2] == 'certificate has expired' + assert exc.value.args[0][2] == "certificate has expired" class TestSignVerify(object): @@ -3606,7 +3732,8 @@ class TestSignVerify(object): b"thirteen. Winston Smith, his chin nuzzled into his breast in an " b"effort to escape the vile wind, slipped quickly through the " b"glass doors of Victory Mansions, though not quickly enough to " - b"prevent a swirl of gritty dust from entering along with him.") + b"prevent a swirl of gritty dust from entering along with him." + ) # sign the content with this private key priv_key = load_privatekey(FILETYPE_PEM, root_key_pem) @@ -3615,7 +3742,7 @@ class TestSignVerify(object): # certificate unrelated to priv_key, used to trigger an error bad_cert = load_certificate(FILETYPE_PEM, server_cert_pem) - for digest in ['md5', 'sha1']: + for digest in ["md5", "sha1"]: sig = sign(priv_key, content, digest) # Verify the signature of content, will throw an exception if @@ -3654,22 +3781,20 @@ class TestSignVerify(object): priv_key = load_privatekey(FILETYPE_PEM, root_key_pem) cert = load_certificate(FILETYPE_PEM, root_cert_pem) - for digest in ['md5', 'sha1']: + for digest in ["md5", "sha1"]: with pytest.warns(DeprecationWarning) as w: simplefilter("always") sig = sign(priv_key, content, digest) - assert ( - "{0} for data is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + assert "{0} for data is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) == str(w[-1].message) with pytest.warns(DeprecationWarning) as w: simplefilter("always") verify(cert, sig, content, digest) - assert ( - "{0} for data is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + assert "{0} for data is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) == str(w[-1].message) def test_sign_verify_ecdsa(self): """ @@ -3708,7 +3833,8 @@ class TestSignVerify(object): b"thirteen. Winston Smith, his chin nuzzled into his breast in an " b"effort to escape the vile wind, slipped quickly through the " b"glass doors of Victory Mansions, though not quickly enough to " - b"prevent a swirl of gritty dust from entering along with him.") + b"prevent a swirl of gritty dust from entering along with him." + ) priv_key = load_privatekey(FILETYPE_PEM, large_key_pem) sign(priv_key, content, "sha1") @@ -3780,6 +3906,7 @@ class TestEllipticCurveEquality(EqualityTestsMixin): """ Tests `_EllipticCurve`'s implementation of ``==`` and ``!=``. """ + curve_factory = EllipticCurveFactory() if curve_factory.curve_name is None: @@ -3804,6 +3931,7 @@ class TestEllipticCurveHash(object): Tests for `_EllipticCurve`'s implementation of hashing (thus use as an item in a `dict` or `set`). """ + curve_factory = EllipticCurveFactory() if curve_factory.curve_name is None: @@ -3824,7 +3952,7 @@ class TestEllipticCurveHash(object): does not contain that curve. """ curve = get_elliptic_curve(self.curve_factory.curve_name) - curves = set([ - get_elliptic_curve(self.curve_factory.another_curve_name) - ]) + curves = set( + [get_elliptic_curve(self.curve_factory.another_curve_name)] + ) assert curve not in curves diff --git a/tests/test_rand.py b/tests/test_rand.py index e04a24c..763d711 100644 --- a/tests/test_rand.py +++ b/tests/test_rand.py @@ -11,11 +11,7 @@ from OpenSSL import rand class TestRand(object): - - @pytest.mark.parametrize('args', [ - (b"foo", None), - (None, 3), - ]) + @pytest.mark.parametrize("args", [(b"foo", None), (None, 3)]) def test_add_wrong_args(self, args): """ `OpenSSL.rand.add` raises `TypeError` if called with arguments not of @@ -28,7 +24,7 @@ class TestRand(object): """ `OpenSSL.rand.add` adds entropy to the PRNG. """ - rand.add(b'hamburger', 3) + rand.add(b"hamburger", 3) def test_status(self): """ diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 2cee928..ba5b638 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -11,7 +11,13 @@ import uuid from gc import collect, get_referrers from errno import ( - EAFNOSUPPORT, ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN) + EAFNOSUPPORT, + ECONNREFUSED, + EINPROGRESS, + EWOULDBLOCK, + EPIPE, + ESHUTDOWN, +) from sys import platform, getfilesystemencoding from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket from os import makedirs @@ -45,49 +51,93 @@ from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN from OpenSSL.SSL import ( - SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, - TLSv1_1_METHOD, TLSv1_2_METHOD) + SSLv2_METHOD, + SSLv3_METHOD, + SSLv23_METHOD, + TLSv1_METHOD, + TLSv1_1_METHOD, + TLSv1_2_METHOD, +) from OpenSSL.SSL import OP_SINGLE_DH_USE, OP_NO_SSLv2, OP_NO_SSLv3 from OpenSSL.SSL import ( - VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE, VERIFY_NONE) + VERIFY_PEER, + VERIFY_FAIL_IF_NO_PEER_CERT, + VERIFY_CLIENT_ONCE, + VERIFY_NONE, +) from OpenSSL import SSL from OpenSSL.SSL import ( - SESS_CACHE_OFF, SESS_CACHE_CLIENT, SESS_CACHE_SERVER, SESS_CACHE_BOTH, - SESS_CACHE_NO_AUTO_CLEAR, SESS_CACHE_NO_INTERNAL_LOOKUP, - SESS_CACHE_NO_INTERNAL_STORE, SESS_CACHE_NO_INTERNAL) + SESS_CACHE_OFF, + SESS_CACHE_CLIENT, + SESS_CACHE_SERVER, + SESS_CACHE_BOTH, + SESS_CACHE_NO_AUTO_CLEAR, + SESS_CACHE_NO_INTERNAL_LOOKUP, + SESS_CACHE_NO_INTERNAL_STORE, + SESS_CACHE_NO_INTERNAL, +) from OpenSSL.SSL import ( - Error, SysCallError, WantReadError, WantWriteError, ZeroReturnError) -from OpenSSL.SSL import ( - Context, Session, Connection, SSLeay_version) + Error, + SysCallError, + WantReadError, + WantWriteError, + ZeroReturnError, +) +from OpenSSL.SSL import Context, Session, Connection, SSLeay_version from OpenSSL.SSL import _make_requires from OpenSSL._util import ffi as _ffi, lib as _lib from OpenSSL.SSL import ( - OP_NO_QUERY_MTU, OP_COOKIE_EXCHANGE, OP_NO_TICKET, OP_NO_COMPRESSION, - MODE_RELEASE_BUFFERS, NO_OVERLAPPING_PROTOCOLS) + OP_NO_QUERY_MTU, + OP_COOKIE_EXCHANGE, + OP_NO_TICKET, + OP_NO_COMPRESSION, + MODE_RELEASE_BUFFERS, + NO_OVERLAPPING_PROTOCOLS, +) from OpenSSL.SSL import ( - SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, - SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, - SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, - SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, - SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE) + SSL_ST_CONNECT, + SSL_ST_ACCEPT, + SSL_ST_MASK, + SSL_CB_LOOP, + SSL_CB_EXIT, + SSL_CB_READ, + SSL_CB_WRITE, + SSL_CB_ALERT, + SSL_CB_READ_ALERT, + SSL_CB_WRITE_ALERT, + SSL_CB_ACCEPT_LOOP, + SSL_CB_ACCEPT_EXIT, + SSL_CB_CONNECT_LOOP, + SSL_CB_CONNECT_EXIT, + SSL_CB_HANDSHAKE_START, + SSL_CB_HANDSHAKE_DONE, +) try: from OpenSSL.SSL import ( - SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE + SSL_ST_INIT, + SSL_ST_BEFORE, + SSL_ST_OK, + SSL_ST_RENEGOTIATE, ) except ImportError: SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None from .util import WARNING_TYPE_EXPECTED, NON_ASCII, is_consistent_type from .test_crypto import ( - cleartextCertificatePEM, cleartextPrivateKeyPEM, - client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, - root_cert_pem) + cleartextCertificatePEM, + cleartextPrivateKeyPEM, + client_cert_pem, + client_key_pem, + server_cert_pem, + server_key_pem, + root_cert_pem, +) # openssl dhparam 1024 -out dh-1024.pem (note that 1024 is a small number of @@ -148,7 +198,7 @@ def socket_pair(): """ # Connect a pair of sockets port = socket_any_family() - port.bind(('', 0)) + port.bind(("", 0)) port.listen(1) client = socket(port.family) client.setblocking(False) @@ -191,8 +241,8 @@ def _create_certificate_chain(): 2. A new intermediate certificate signed by cacert (icert) 3. A new server certificate signed by icert (scert) """ - caext = X509Extension(b'basicConstraints', False, b'CA:true') - not_after_date = (datetime.date.today() + datetime.timedelta(days=365)) + caext = X509Extension(b"basicConstraints", False, b"CA:true") + not_after_date = datetime.date.today() + datetime.timedelta(days=365) not_after = not_after_date.strftime("%Y%m%d%H%M%SZ").encode("ascii") # Step 1 @@ -233,8 +283,9 @@ def _create_certificate_chain(): scert.set_pubkey(skey) scert.set_notBefore(b"20000101000000Z") scert.set_notAfter(not_after) - scert.add_extensions([ - X509Extension(b'basicConstraints', True, b'CA:false')]) + scert.add_extensions( + [X509Extension(b"basicConstraints", True, b"CA:false")] + ) scert.set_serial_number(0) scert.sign(ikey, "sha1") @@ -293,8 +344,10 @@ def interact_in_memory(client_conn, server_conn): # Copy stuff from each side's send buffer to the other side's # receive buffer. - for (read, write) in [(client_conn, server_conn), - (server_conn, client_conn)]: + for (read, write) in [ + (client_conn, server_conn), + (server_conn, client_conn), + ]: # Give the side a chance to generate some more bytes, or succeed. try: @@ -344,6 +397,7 @@ class TestVersion(object): Tests for version information exposed by `OpenSSL.SSL.SSLeay_version` and `OpenSSL.SSL.OPENSSL_VERSION_NUMBER`. """ + def test_OPENSSL_VERSION_NUMBER(self): """ `OPENSSL_VERSION_NUMBER` is an integer with status in the low byte and @@ -357,8 +411,13 @@ class TestVersion(object): number of version strings based on that indicator. """ versions = {} - for t in [SSLEAY_VERSION, SSLEAY_CFLAGS, SSLEAY_BUILT_ON, - SSLEAY_PLATFORM, SSLEAY_DIR]: + for t in [ + SSLEAY_VERSION, + SSLEAY_CFLAGS, + SSLEAY_BUILT_ON, + SSLEAY_PLATFORM, + SSLEAY_DIR, + ]: version = SSLeay_version(t) versions[version] = t assert isinstance(version, bytes) @@ -371,19 +430,17 @@ def ca_file(tmpdir): Create a valid PEM file with CA certificates and return the path. """ key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend() + public_exponent=65537, key_size=2048, backend=default_backend() ) public_key = key.public_key() builder = x509.CertificateBuilder() - builder = builder.subject_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), - ])) - builder = builder.issuer_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), - ])) + builder = builder.subject_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")]) + ) + builder = builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")]) + ) one_day = datetime.timedelta(1, 0, 0) builder = builder.not_valid_before(datetime.datetime.today() - one_day) builder = builder.not_valid_after(datetime.datetime.today() + one_day) @@ -394,15 +451,12 @@ def ca_file(tmpdir): ) certificate = builder.sign( - private_key=key, algorithm=hashes.SHA256(), - backend=default_backend() + private_key=key, algorithm=hashes.SHA256(), backend=default_backend() ) ca_file = tmpdir.join("test.pem") ca_file.write_binary( - certificate.public_bytes( - encoding=serialization.Encoding.PEM, - ) + certificate.public_bytes(encoding=serialization.Encoding.PEM,) ) return str(ca_file).encode("ascii") @@ -420,10 +474,11 @@ class TestContext(object): """ Unit tests for `OpenSSL.SSL.Context`. """ - @pytest.mark.parametrize("cipher_string", [ - b"hello world:AES128-SHA", - u"hello world:AES128-SHA", - ]) + + @pytest.mark.parametrize( + "cipher_string", + [b"hello world:AES128-SHA", u"hello world:AES128-SHA"], + ) def test_set_cipher_list(self, context, cipher_string): """ `Context.set_cipher_list` accepts both byte and unicode strings @@ -453,14 +508,8 @@ class TestContext(object): with pytest.raises(Error) as excinfo: context.set_cipher_list(b"imaginary-cipher") assert excinfo.value.args == ( - [ - ( - 'SSL routines', - 'SSL_CTX_set_cipher_list', - 'no cipher match', - ), - ], - ) + [("SSL routines", "SSL_CTX_set_cipher_list", "no cipher match",)], + ) def test_load_client_ca(self, context, ca_file): """ @@ -484,9 +533,7 @@ class TestContext(object): """ Passing the path as unicode raises a warning but works. """ - pytest.deprecated_call( - context.load_client_ca, ca_file.decode("ascii") - ) + pytest.deprecated_call(context.load_client_ca, ca_file.decode("ascii")) def test_set_session_id(self, context): """ @@ -502,9 +549,11 @@ class TestContext(object): context.set_session_id(b"abc" * 1000) assert [ - ("SSL routines", - "SSL_CTX_set_session_id_context", - "ssl session id context too long") + ( + "SSL routines", + "SSL_CTX_set_session_id_context", + "ssl session id context too long", + ) ] == e.value.args[0] def test_set_session_id_unicode(self, context): @@ -542,7 +591,7 @@ class TestContext(object): """ `Context` can be used to create instances of that type. """ - assert is_consistent_type(Context, 'Context', TLSv1_METHOD) + assert is_consistent_type(Context, "Context", TLSv1_METHOD) def test_use_privatekey(self): """ @@ -573,14 +622,12 @@ class TestContext(object): key.generate_key(TYPE_RSA, 512) with open(pemfile, "wt") as pem: - pem.write( - dump_privatekey(FILETYPE_PEM, key).decode("ascii") - ) + pem.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii")) ctx = Context(TLSv1_METHOD) ctx.use_privatekey_file(pemfile, filetype) - @pytest.mark.parametrize('filetype', [object(), "", None, 1.0]) + @pytest.mark.parametrize("filetype", [object(), "", None, 1.0]) def test_wrong_privatekey_file_wrong_args(self, tmpfile, filetype): """ `Context.use_privatekey_file` raises `TypeError` when called with @@ -596,8 +643,7 @@ class TestContext(object): instance giving the file name to ``Context.use_privatekey_file``. """ self._use_privatekey_file_test( - tmpfile + NON_ASCII.encode(getfilesystemencoding()), - FILETYPE_PEM, + tmpfile + NON_ASCII.encode(getfilesystemencoding()), FILETYPE_PEM, ) def test_use_privatekey_file_unicode(self, tmpfile): @@ -606,8 +652,7 @@ class TestContext(object): instance giving the file name to ``Context.use_privatekey_file``. """ self._use_privatekey_file_test( - tmpfile.decode(getfilesystemencoding()) + NON_ASCII, - FILETYPE_PEM, + tmpfile.decode(getfilesystemencoding()) + NON_ASCII, FILETYPE_PEM, ) def test_use_certificate_wrong_args(self): @@ -814,8 +859,8 @@ class TestContext(object): key = PKey() key.generate_key(TYPE_RSA, 512) pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase) - with open(tmpfile, 'w') as fObj: - fObj.write(pem.decode('ascii')) + with open(tmpfile, "w") as fObj: + fObj.write(pem.decode("ascii")) return tmpfile def test_set_passwd_cb_wrong_args(self): @@ -839,6 +884,7 @@ class TestContext(object): def passphraseCallback(maxlen, verify, extra): calledWith.append((maxlen, verify, extra)) return passphrase + context = Context(TLSv1_METHOD) context.set_passwd_cb(passphraseCallback) context.use_privatekey_file(pemFile) @@ -926,12 +972,15 @@ class TestContext(object): def info(conn, where, ret): called.append((conn, where, ret)) + context = Context(TLSv1_METHOD) context.set_info_callback(info) context.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) context.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) serverSSL = Connection(context, server) serverSSL.set_accept_state() @@ -944,10 +993,13 @@ class TestContext(object): # assert it is called with the right Connection instance. It would # also be good to assert *something* about `where` and `ret`. notConnections = [ - conn for (conn, where, ret) in called - if not isinstance(conn, Connection)] - assert [] == notConnections, ( - "Some info callback arguments were not Connection instances.") + conn + for (conn, where, ret) in called + if not isinstance(conn, Connection) + ] + assert ( + [] == notConnections + ), "Some info callback arguments were not Connection instances." def _load_verify_locations_test(self, *args): """ @@ -963,16 +1015,19 @@ class TestContext(object): # connection will fail. clientContext.set_verify( VERIFY_PEER, - lambda conn, cert, errno, depth, preverify_ok: preverify_ok) + lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + ) clientSSL = Connection(clientContext, client) clientSSL.set_connect_state() serverContext = Context(TLSv1_METHOD) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) serverSSL = Connection(serverContext, server) serverSSL.set_accept_state() @@ -984,7 +1039,7 @@ class TestContext(object): handshake(clientSSL, serverSSL) cert = clientSSL.get_peer_certificate() - assert cert.get_subject().CN == 'Testing Root CA' + assert cert.get_subject().CN == "Testing Root CA" def _load_verify_cafile(self, cafile): """ @@ -993,8 +1048,8 @@ class TestContext(object): certificate is used as a trust root for the purposes of verifying connections created using that `Context`. """ - with open(cafile, 'w') as fObj: - fObj.write(cleartextCertificatePEM.decode('ascii')) + with open(cafile, "w") as fObj: + fObj.write(cleartextCertificatePEM.decode("ascii")) self._load_verify_locations_test(cafile) @@ -1035,10 +1090,10 @@ class TestContext(object): # Hash values computed manually with c_rehash to avoid depending on # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other # from OpenSSL 1.0.0. - for name in [b'c7adac82.0', b'c3705638.0']: + for name in [b"c7adac82.0", b"c3705638.0"]: cafile = join_bytes_or_unicode(capath, name) - with open(cafile, 'w') as fObj: - fObj.write(cleartextCertificatePEM.decode('ascii')) + with open(cafile, "w") as fObj: + fObj.write(cleartextCertificatePEM.decode("ascii")) self._load_verify_locations_test(None, capath) @@ -1074,7 +1129,7 @@ class TestContext(object): @pytest.mark.skipif( not platform.startswith("linux"), reason="Loading fallback paths is a linux-specific behavior to " - "accommodate pyca/cryptography manylinux1 wheels" + "accommodate pyca/cryptography manylinux1 wheels", ) def test_fallback_default_verify_paths(self, monkeypatch): """ @@ -1092,12 +1147,12 @@ class TestContext(object): monkeypatch.setattr( SSL, "_CRYPTOGRAPHY_MANYLINUX1_CA_FILE", - _ffi.string(_lib.X509_get_default_cert_file()) + _ffi.string(_lib.X509_get_default_cert_file()), ) monkeypatch.setattr( SSL, "_CRYPTOGRAPHY_MANYLINUX1_CA_DIR", - _ffi.string(_lib.X509_get_default_cert_dir()) + _ffi.string(_lib.X509_get_default_cert_dir()), ) context.set_default_verify_paths() store = context.get_cert_store() @@ -1127,9 +1182,9 @@ class TestContext(object): monkeypatch.setattr( _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1 ) - dir_env_var = _ffi.string( - _lib.X509_get_default_cert_dir_env() - ).decode("ascii") + dir_env_var = _ffi.string(_lib.X509_get_default_cert_dir_env()).decode( + "ascii" + ) file_env_var = _ffi.string( _lib.X509_get_default_cert_file_env() ).decode("ascii") @@ -1138,16 +1193,14 @@ class TestContext(object): context.set_default_verify_paths() monkeypatch.setattr( - context, - "_fallback_default_verify_paths", - raiser(SystemError) + context, "_fallback_default_verify_paths", raiser(SystemError) ) context.set_default_verify_paths() @pytest.mark.skipif( platform == "win32", reason="set_default_verify_paths appears not to work on Windows. " - "See LP#404343 and LP#404344." + "See LP#404343 and LP#404344.", ) def test_set_default_verify_paths(self): """ @@ -1165,7 +1218,8 @@ class TestContext(object): context.set_default_verify_paths() context.set_verify( VERIFY_PEER, - lambda conn, cert, errno, depth, preverify_ok: preverify_ok) + lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + ) client = socket_any_family() client.connect(("encrypted.google.com", 443)) @@ -1183,9 +1237,7 @@ class TestContext(object): """ context = Context(TLSv1_METHOD) context._fallback_default_verify_paths([], []) - context._fallback_default_verify_paths( - ["/not/a/file"], ["/not/a/dir"] - ) + context._fallback_default_verify_paths(["/not/a/file"], ["/not/a/dir"]) def test_add_extra_chain_cert_invalid_cert(self): """ @@ -1225,9 +1277,11 @@ class TestContext(object): """ serverContext = Context(TLSv1_METHOD) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) serverConnection = Connection(serverContext, None) class VerifyCallback(object): @@ -1254,9 +1308,11 @@ class TestContext(object): """ serverContext = Context(TLSv1_METHOD) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) serverConnection = Connection(serverContext, None) def verify_cb_get_subject(conn, cert, errnum, depth, ok): @@ -1278,14 +1334,17 @@ class TestContext(object): """ serverContext = Context(TLSv1_2_METHOD) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM) + ) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, cleartextCertificatePEM) + ) clientContext = Context(TLSv1_2_METHOD) def verify_callback(*args): raise Exception("silly verify failure") + clientContext.set_verify(VERIFY_PEER, verify_callback) with pytest.raises(Exception) as exc: @@ -1310,17 +1369,17 @@ class TestContext(object): # Dump the CA certificate to a file because that's the only way to load # it as a trusted CA in the client context. - for cert, name in [(cacert, 'ca.pem'), - (icert, 'i.pem'), - (scert, 's.pem')]: - with tmpdir.join(name).open('w') as f: - f.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii')) - - for key, name in [(cakey, 'ca.key'), - (ikey, 'i.key'), - (skey, 's.key')]: - with tmpdir.join(name).open('w') as f: - f.write(dump_privatekey(FILETYPE_PEM, key).decode('ascii')) + for cert, name in [ + (cacert, "ca.pem"), + (icert, "i.pem"), + (scert, "s.pem"), + ]: + with tmpdir.join(name).open("w") as f: + f.write(dump_certificate(FILETYPE_PEM, cert).decode("ascii")) + + for key, name in [(cakey, "ca.key"), (ikey, "i.key"), (skey, "s.key")]: + with tmpdir.join(name).open("w") as f: + f.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii")) # Create the server context serverContext = Context(TLSv1_METHOD) @@ -1332,7 +1391,8 @@ class TestContext(object): # Create the client clientContext = Context(TLSv1_METHOD) clientContext.set_verify( - VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) + VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb + ) clientContext.load_verify_locations(str(tmpdir.join("ca.pem"))) # Try it out. @@ -1356,14 +1416,14 @@ class TestContext(object): caFile = join_bytes_or_unicode(certdir, "ca.pem") # Write out the chain file. - with open(chainFile, 'wb') as fObj: + with open(chainFile, "wb") as fObj: # Most specific to least general. fObj.write(dump_certificate(FILETYPE_PEM, scert)) fObj.write(dump_certificate(FILETYPE_PEM, icert)) fObj.write(dump_certificate(FILETYPE_PEM, cacert)) - with open(caFile, 'w') as fObj: - fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii')) + with open(caFile, "w") as fObj: + fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode("ascii")) serverContext = Context(TLSv1_METHOD) serverContext.use_certificate_chain_file(chainFile) @@ -1371,7 +1431,8 @@ class TestContext(object): clientContext = Context(TLSv1_METHOD) clientContext.set_verify( - VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) + VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb + ) clientContext.load_verify_locations(caFile) self._handshake_test(serverContext, clientContext) @@ -1423,10 +1484,11 @@ class TestContext(object): context = Context(TLSv1_METHOD) assert context.get_verify_mode() == 0 context.set_verify( - VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None) + VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None + ) assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) - @pytest.mark.parametrize('mode', [None, 1.0, object(), 'mode']) + @pytest.mark.parametrize("mode", [None, 1.0, object(), "mode"]) def test_set_verify_wrong_mode_arg(self, mode): """ `Context.set_verify` raises `TypeError` if the first argument is @@ -1436,7 +1498,7 @@ class TestContext(object): with pytest.raises(TypeError): context.set_verify(mode=mode, callback=lambda *args: None) - @pytest.mark.parametrize('callback', [None, 1.0, 'mode', ('foo', 'bar')]) + @pytest.mark.parametrize("callback", [None, 1.0, "mode", ("foo", "bar")]) def test_set_verify_wrong_callable_arg(self, callback): """ `Context.set_verify` raises `TypeError` if the second argument @@ -1547,7 +1609,7 @@ class TestContext(object): """ context = Context(TLSv1_METHOD) with pytest.raises(TypeError): - context.set_tlsext_use_srtp(text_type('SRTP_AES128_CM_SHA1_80')) + context.set_tlsext_use_srtp(text_type("SRTP_AES128_CM_SHA1_80")) def test_set_tlsext_use_srtp_invalid_profile(self): """ @@ -1557,7 +1619,7 @@ class TestContext(object): """ context = Context(TLSv1_METHOD) with pytest.raises(Error): - context.set_tlsext_use_srtp(b'SRTP_BOGUS') + context.set_tlsext_use_srtp(b"SRTP_BOGUS") def test_set_tlsext_use_srtp_valid(self): """ @@ -1566,7 +1628,7 @@ class TestContext(object): It does not return anything. """ context = Context(TLSv1_METHOD) - assert context.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80') is None + assert context.set_tlsext_use_srtp(b"SRTP_AES128_CM_SHA1_80") is None class TestServerNameCallback(object): @@ -1574,11 +1636,13 @@ class TestServerNameCallback(object): Tests for `Context.set_tlsext_servername_callback` and its interaction with `Connection`. """ + def test_old_callback_forgotten(self): """ If `Context.set_tlsext_servername_callback` is used to specify a new callback, the one it replaces is dereferenced. """ + def callback(connection): # pragma: no cover pass @@ -1616,6 +1680,7 @@ class TestServerNameCallback(object): def servername(conn): args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) context.set_tlsext_servername_callback(servername) @@ -1627,7 +1692,8 @@ class TestServerNameCallback(object): # Necessary to actually accept the connection context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(context, None) @@ -1651,13 +1717,15 @@ class TestServerNameCallback(object): def servername(conn): args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) context.set_tlsext_servername_callback(servername) # Necessary to actually accept the connection context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(context, None) @@ -1679,6 +1747,7 @@ class TestNextProtoNegotiation(object): """ Test for Next Protocol Negotiation in PyOpenSSL. """ + def test_npn_success(self): """ Tests that clients and servers that agree on the negotiated next @@ -1690,11 +1759,11 @@ class TestNextProtoNegotiation(object): def advertise(conn): advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] + return [b"http/1.1", b"spdy/2"] def select(conn, options): select_args.append((conn, options)) - return b'spdy/2' + return b"spdy/2" server_context = Context(TLSv1_METHOD) server_context.set_npn_advertise_callback(advertise) @@ -1704,9 +1773,11 @@ class TestNextProtoNegotiation(object): # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1718,10 +1789,10 @@ class TestNextProtoNegotiation(object): interact_in_memory(server, client) assert advertise_args == [(server,)] - assert select_args == [(client, [b'http/1.1', b'spdy/2'])] + assert select_args == [(client, [b"http/1.1", b"spdy/2"])] - assert server.get_next_proto_negotiated() == b'spdy/2' - assert client.get_next_proto_negotiated() == b'spdy/2' + assert server.get_next_proto_negotiated() == b"spdy/2" + assert client.get_next_proto_negotiated() == b"spdy/2" def test_npn_client_fail(self): """ @@ -1733,11 +1804,11 @@ class TestNextProtoNegotiation(object): def advertise(conn): advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] + return [b"http/1.1", b"spdy/2"] def select(conn, options): select_args.append((conn, options)) - return b'' + return b"" server_context = Context(TLSv1_METHOD) server_context.set_npn_advertise_callback(advertise) @@ -1747,9 +1818,11 @@ class TestNextProtoNegotiation(object): # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1763,7 +1836,7 @@ class TestNextProtoNegotiation(object): interact_in_memory(server, client) assert advertise_args == [(server,)] - assert select_args == [(client, [b'http/1.1', b'spdy/2'])] + assert select_args == [(client, [b"http/1.1", b"spdy/2"])] def test_npn_select_error(self): """ @@ -1774,7 +1847,7 @@ class TestNextProtoNegotiation(object): def advertise(conn): advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] + return [b"http/1.1", b"spdy/2"] def select(conn, options): raise TypeError @@ -1787,9 +1860,11 @@ class TestNextProtoNegotiation(object): # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1801,7 +1876,9 @@ class TestNextProtoNegotiation(object): # If the callback throws an exception it should be raised here. with pytest.raises(TypeError): interact_in_memory(server, client) - assert advertise_args == [(server,), ] + assert advertise_args == [ + (server,), + ] def test_npn_advertise_error(self): """ @@ -1818,7 +1895,7 @@ class TestNextProtoNegotiation(object): Assert later that no args are actually appended. """ select_args.append((conn, options)) - return b'' + return b"" server_context = Context(TLSv1_METHOD) server_context.set_npn_advertise_callback(advertise) @@ -1828,9 +1905,11 @@ class TestNextProtoNegotiation(object): # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1849,6 +1928,7 @@ class TestApplicationLayerProtoNegotiation(object): """ Tests for ALPN in PyOpenSSL. """ + def test_alpn_success(self): """ Clients and servers that agree on the negotiated ALPN protocol can @@ -1859,19 +1939,21 @@ class TestApplicationLayerProtoNegotiation(object): def select(conn, options): select_args.append((conn, options)) - return b'spdy/2' + return b"spdy/2" client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(TLSv1_METHOD) server_context.set_alpn_select_callback(select) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1882,10 +1964,10 @@ class TestApplicationLayerProtoNegotiation(object): interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] - assert server.get_alpn_proto_negotiated() == b'spdy/2' - assert client.get_alpn_proto_negotiated() == b'spdy/2' + assert server.get_alpn_proto_negotiated() == b"spdy/2" + assert client.get_alpn_proto_negotiated() == b"spdy/2" def test_alpn_set_on_connection(self): """ @@ -1896,7 +1978,7 @@ class TestApplicationLayerProtoNegotiation(object): def select(conn, options): select_args.append((conn, options)) - return b'spdy/2' + return b"spdy/2" # Setup the client context but don't set any ALPN protocols. client_context = Context(TLSv1_METHOD) @@ -1906,9 +1988,11 @@ class TestApplicationLayerProtoNegotiation(object): # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1916,15 +2000,15 @@ class TestApplicationLayerProtoNegotiation(object): # Set the ALPN protocols on the client connection. client = Connection(client_context, None) - client.set_alpn_protos([b'http/1.1', b'spdy/2']) + client.set_alpn_protos([b"http/1.1", b"spdy/2"]) client.set_connect_state() interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] - assert server.get_alpn_proto_negotiated() == b'spdy/2' - assert client.get_alpn_proto_negotiated() == b'spdy/2' + assert server.get_alpn_proto_negotiated() == b"spdy/2" + assert client.get_alpn_proto_negotiated() == b"spdy/2" def test_alpn_server_fail(self): """ @@ -1935,19 +2019,21 @@ class TestApplicationLayerProtoNegotiation(object): def select(conn, options): select_args.append((conn, options)) - return b'' + return b"" client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(TLSv1_METHOD) server_context.set_alpn_select_callback(select) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1960,7 +2046,7 @@ class TestApplicationLayerProtoNegotiation(object): with pytest.raises(Error): interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] def test_alpn_no_server_overlap(self): """ @@ -1975,16 +2061,18 @@ class TestApplicationLayerProtoNegotiation(object): return NO_OVERLAPPING_PROTOCOLS client_context = Context(SSLv23_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(SSLv23_METHOD) server_context.set_alpn_select_callback(refusal) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1996,9 +2084,9 @@ class TestApplicationLayerProtoNegotiation(object): # Do the dance. interact_in_memory(server, client) - assert refusal_args == [(server, [b'http/1.1', b'spdy/2'])] + assert refusal_args == [(server, [b"http/1.1", b"spdy/2"])] - assert client.get_alpn_proto_negotiated() == b'' + assert client.get_alpn_proto_negotiated() == b"" def test_alpn_select_cb_returns_invalid_value(self): """ @@ -2013,16 +2101,18 @@ class TestApplicationLayerProtoNegotiation(object): return u"can't return unicode" client_context = Context(SSLv23_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(SSLv23_METHOD) server_context.set_alpn_select_callback(invalid_cb) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -2035,9 +2125,9 @@ class TestApplicationLayerProtoNegotiation(object): with pytest.raises(TypeError): interact_in_memory(server, client) - assert invalid_cb_args == [(server, [b'http/1.1', b'spdy/2'])] + assert invalid_cb_args == [(server, [b"http/1.1", b"spdy/2"])] - assert client.get_alpn_proto_negotiated() == b'' + assert client.get_alpn_proto_negotiated() == b"" def test_alpn_no_server(self): """ @@ -2045,15 +2135,17 @@ class TestApplicationLayerProtoNegotiation(object): because the server doesn't offer ALPN, no protocol is negotiated. """ client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(TLSv1_METHOD) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -2065,7 +2157,7 @@ class TestApplicationLayerProtoNegotiation(object): # Do the dance. interact_in_memory(server, client) - assert client.get_alpn_proto_negotiated() == b'' + assert client.get_alpn_proto_negotiated() == b"" def test_alpn_callback_exception(self): """ @@ -2078,16 +2170,18 @@ class TestApplicationLayerProtoNegotiation(object): raise TypeError() client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(TLSv1_METHOD) server_context.set_alpn_select_callback(select) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -2098,13 +2192,14 @@ class TestApplicationLayerProtoNegotiation(object): with pytest.raises(TypeError): interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] class TestSession(object): """ Unit tests for :py:obj:`OpenSSL.SSL.Session`. """ + def test_construction(self): """ :py:class:`Session` can be constructed with no arguments, creating @@ -2118,6 +2213,7 @@ class TestConnection(object): """ Unit tests for `OpenSSL.SSL.Connection`. """ + # XXX get_peer_certificate -> None # XXX sock_shutdown # XXX master_key -> TypeError @@ -2137,9 +2233,9 @@ class TestConnection(object): `Connection` can be used to create instances of that type. """ ctx = Context(TLSv1_METHOD) - assert is_consistent_type(Connection, 'Connection', ctx, None) + assert is_consistent_type(Connection, "Connection", ctx, None) - @pytest.mark.parametrize('bad_context', [object(), 'context', None, 1]) + @pytest.mark.parametrize("bad_context", [object(), "context", None, 1]) def test_wrong_args(self, bad_context): """ `Connection.__init__` raises `TypeError` if called with a non-`Context` @@ -2148,7 +2244,7 @@ class TestConnection(object): with pytest.raises(TypeError): Connection(bad_context) - @pytest.mark.parametrize('bad_bio', [object(), None, 1, [1, 2, 3]]) + @pytest.mark.parametrize("bad_bio", [object(), None, 1, [1, 2, 3]]) def test_bio_write_wrong_args(self, bad_bio): """ `Connection.bio_write` raises `TypeError` if called with a non-bytes @@ -2166,10 +2262,10 @@ class TestConnection(object): """ context = Context(TLSv1_METHOD) connection = Connection(context, None) - connection.bio_write(b'xy') - connection.bio_write(bytearray(b'za')) + connection.bio_write(b"xy") + connection.bio_write(bytearray(b"za")) with pytest.warns(DeprecationWarning): - connection.bio_write(u'deprecated') + connection.bio_write(u"deprecated") def test_get_context(self): """ @@ -2241,10 +2337,10 @@ class TestConnection(object): passed. """ server, client = loopback() - server.send(b'xy') - assert client.recv(2, MSG_PEEK) == b'xy' - assert client.recv(2, MSG_PEEK) == b'xy' - assert client.recv(2) == b'xy' + server.send(b"xy") + assert client.recv(2, MSG_PEEK) == b"xy" + assert client.recv(2, MSG_PEEK) == b"xy" + assert client.recv(2) == b"xy" def test_connect_wrong_args(self): """ @@ -2276,7 +2372,7 @@ class TestConnection(object): `Connection.connect` establishes a connection to the specified address. """ port = socket_any_family() - port.bind(('', 0)) + port.bind(("", 0)) port.listen(3) clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family)) @@ -2285,7 +2381,7 @@ class TestConnection(object): @pytest.mark.skipif( platform == "darwin", - reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4" + reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4", ) def test_connect_ex(self): """ @@ -2293,7 +2389,7 @@ class TestConnection(object): errno instead of raising an exception. """ port = socket_any_family() - port.bind(('', 0)) + port.bind(("", 0)) port.listen(3) clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family)) @@ -2313,7 +2409,7 @@ class TestConnection(object): ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) port = socket_any_family() portSSL = Connection(ctx, port) - portSSL.bind(('', 0)) + portSSL.bind(("", 0)) portSSL.listen(3) clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family)) @@ -2375,9 +2471,11 @@ class TestConnection(object): server_ctx = Context(TLSv1_METHOD) client_ctx = Context(TLSv1_METHOD) server_ctx.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_ctx.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) server = Connection(server_ctx, None) client = Connection(client_ctx, None) handshake_in_memory(client, server) @@ -2407,10 +2505,12 @@ class TestConnection(object): client = loopback_client_factory(client) assert server.get_state_string() in [ - b"before/accept initialization", b"before SSL initialization" + b"before/accept initialization", + b"before SSL initialization", ] assert client.get_state_string() in [ - b"before/connect initialization", b"before SSL initialization" + b"before/connect initialization", + b"before SSL initialization", ] def test_app_data(self): @@ -2565,17 +2665,17 @@ class TestConnection(object): server.set_accept_state() return server - originalServer, originalClient = loopback( - server_factory=makeServer) + originalServer, originalClient = loopback(server_factory=makeServer) originalSession = originalClient.get_session() def makeClient(socket): client = loopback_client_factory(socket) client.set_session(originalSession) return client + resumedServer, resumedClient = loopback( - server_factory=makeServer, - client_factory=makeClient) + server_factory=makeServer, client_factory=makeClient + ) # This is a proxy: in general, we have no access to any unique # identifier for the session (new enough versions of OpenSSL expose @@ -2621,7 +2721,8 @@ class TestConnection(object): return client originalServer, originalClient = loopback( - server_factory=makeServer, client_factory=makeOriginalClient) + server_factory=makeServer, client_factory=makeOriginalClient + ) originalSession = originalClient.get_session() def makeClient(socket): @@ -2657,7 +2758,8 @@ class TestConnection(object): raise else: pytest.fail( - "Failed to fill socket buffer, cannot test BIO want write") + "Failed to fill socket buffer, cannot test BIO want write" + ) ctx = Context(TLSv1_METHOD) conn = Connection(ctx, client_socket) @@ -2736,8 +2838,10 @@ class TestConnection(object): name of the currently used cipher. """ server, client = loopback() - server_cipher_name, client_cipher_name = \ - server.get_cipher_name(), client.get_cipher_name() + server_cipher_name, client_cipher_name = ( + server.get_cipher_name(), + client.get_cipher_name(), + ) assert isinstance(server_cipher_name, text_type) assert isinstance(client_cipher_name, text_type) @@ -2759,8 +2863,10 @@ class TestConnection(object): the protocol name of the currently used cipher. """ server, client = loopback() - server_cipher_version, client_cipher_version = \ - server.get_cipher_version(), client.get_cipher_version() + server_cipher_version, client_cipher_version = ( + server.get_cipher_version(), + client.get_cipher_version(), + ) assert isinstance(server_cipher_version, text_type) assert isinstance(client_cipher_version, text_type) @@ -2782,8 +2888,10 @@ class TestConnection(object): of the currently used cipher. """ server, client = loopback() - server_cipher_bits, client_cipher_bits = \ - server.get_cipher_bits(), client.get_cipher_bits() + server_cipher_bits, client_cipher_bits = ( + server.get_cipher_bits(), + client.get_cipher_bits(), + ) assert isinstance(server_cipher_bits, int) assert isinstance(client_cipher_bits, int) @@ -2828,7 +2936,7 @@ class TestConnection(object): with pytest.raises(WantReadError): conn.bio_read(1024) - @pytest.mark.parametrize('bufsize', [1.0, None, object(), 'bufsize']) + @pytest.mark.parametrize("bufsize", [1.0, None, object(), "bufsize"]) def test_bio_read_wrong_args(self, bufsize): """ `Connection.bio_read` raises `TypeError` if passed a non-integer @@ -2859,6 +2967,7 @@ class TestConnectionGetCipherList(object): """ Tests for `Connection.get_cipher_list`. """ + def test_result(self): """ `Connection.get_cipher_list` returns a list of `bytes` giving the @@ -2875,14 +2984,16 @@ class VeryLarge(bytes): """ Mock object so that we don't have to allocate 2**31 bytes """ + def __len__(self): - return 2**31 + return 2 ** 31 class TestConnectionSend(object): """ Tests for `Connection.send`. """ + def test_wrong_args(self): """ When called with arguments other than string argument for its first @@ -2900,9 +3011,9 @@ class TestConnectionSend(object): and returns the number of bytes sent. """ server, client = loopback() - count = server.send(b'xy') + count = server.send(b"xy") assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" def test_text(self): """ @@ -2913,12 +3024,11 @@ class TestConnectionSend(object): with pytest.warns(DeprecationWarning) as w: simplefilter("always") count = server.send(b"xy".decode("ascii")) - assert ( - "{0} for buf is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + assert "{0} for buf is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) == str(w[-1].message) assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" def test_short_memoryview(self): """ @@ -2927,9 +3037,9 @@ class TestConnectionSend(object): of bytes sent. """ server, client = loopback() - count = server.send(memoryview(b'xy')) + count = server.send(memoryview(b"xy")) assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" def test_short_bytearray(self): """ @@ -2937,9 +3047,9 @@ class TestConnectionSend(object): it and returns the number of bytes sent. """ server, client = loopback() - count = server.send(bytearray(b'xy')) + count = server.send(bytearray(b"xy")) assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" @skip_if_py3 def test_short_buffer(self): @@ -2949,13 +3059,13 @@ class TestConnectionSend(object): of bytes sent. """ server, client = loopback() - count = server.send(buffer(b'xy')) + count = server.send(buffer(b"xy")) # noqa: F821 assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" @pytest.mark.skipif( - sys.maxsize < 2**31, - reason="sys.maxsize < 2**31 - test requires 64 bit" + sys.maxsize < 2 ** 31, + reason="sys.maxsize < 2**31 - test requires 64 bit", ) def test_buf_too_large(self): """ @@ -2981,6 +3091,7 @@ class TestConnectionRecvInto(object): """ Tests for `Connection.recv_into`. """ + def _no_length_test(self, factory): """ Assert that when the given buffer is passed to `Connection.recv_into`, @@ -2990,10 +3101,10 @@ class TestConnectionRecvInto(object): output_buffer = factory(5) server, client = loopback() - server.send(b'xy') + server.send(b"xy") assert client.recv_into(output_buffer) == 2 - assert output_buffer == bytearray(b'xy\x00\x00\x00') + assert output_buffer == bytearray(b"xy\x00\x00\x00") def test_bytearray_no_length(self): """ @@ -3011,10 +3122,10 @@ class TestConnectionRecvInto(object): output_buffer = factory(10) server, client = loopback() - server.send(b'abcdefghij') + server.send(b"abcdefghij") assert client.recv_into(output_buffer, 5) == 5 - assert output_buffer == bytearray(b'abcde\x00\x00\x00\x00\x00') + assert output_buffer == bytearray(b"abcde\x00\x00\x00\x00\x00") def test_bytearray_respects_length(self): """ @@ -3033,12 +3144,12 @@ class TestConnectionRecvInto(object): output_buffer = factory(5) server, client = loopback() - server.send(b'abcdefghij') + server.send(b"abcdefghij") assert client.recv_into(output_buffer) == 5 - assert output_buffer == bytearray(b'abcde') + assert output_buffer == bytearray(b"abcde") rest = client.recv(5) - assert b'fghij' == rest + assert b"fghij" == rest def test_bytearray_doesnt_overfill(self): """ @@ -3059,12 +3170,12 @@ class TestConnectionRecvInto(object): def test_peek(self): server, client = loopback() - server.send(b'xy') + server.send(b"xy") for _ in range(2): output_buffer = bytearray(5) assert client.recv_into(output_buffer, flags=MSG_PEEK) == 2 - assert output_buffer == bytearray(b'xy\x00\x00\x00') + assert output_buffer == bytearray(b"xy\x00\x00\x00") def test_memoryview_no_length(self): """ @@ -3103,6 +3214,7 @@ class TestConnectionSendall(object): """ Tests for `Connection.sendall`. """ + def test_wrong_args(self): """ When called with arguments other than a string argument for its first @@ -3120,8 +3232,8 @@ class TestConnectionSendall(object): passed to it. """ server, client = loopback() - server.sendall(b'x') - assert client.recv(1) == b'x' + server.sendall(b"x") + assert client.recv(1) == b"x" def test_text(self): """ @@ -3132,10 +3244,9 @@ class TestConnectionSendall(object): with pytest.warns(DeprecationWarning) as w: simplefilter("always") server.sendall(b"x".decode("ascii")) - assert ( - "{0} for buf is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + assert "{0} for buf is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) == str(w[-1].message) assert client.recv(1) == b"x" def test_short_memoryview(self): @@ -3144,8 +3255,8 @@ class TestConnectionSendall(object): `Connection.sendall` transmits all of them. """ server, client = loopback() - server.sendall(memoryview(b'x')) - assert client.recv(1) == b'x' + server.sendall(memoryview(b"x")) + assert client.recv(1) == b"x" @skip_if_py3 def test_short_buffers(self): @@ -3154,9 +3265,9 @@ class TestConnectionSendall(object): `Connection.sendall` transmits all of them. """ server, client = loopback() - count = server.sendall(buffer(b'xy')) + count = server.sendall(buffer(b"xy")) # noqa: F821 assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" def test_long(self): """ @@ -3167,7 +3278,7 @@ class TestConnectionSendall(object): # Should be enough, underlying SSL_write should only do 16k at a time. # On Windows, after 32k of bytes the write will block (forever # - because no one is yet reading). - message = b'x' * (1024 * 32 - 1) + b'y' + message = b"x" * (1024 * 32 - 1) + b"y" server.sendall(message) accum = [] received = 0 @@ -3175,7 +3286,7 @@ class TestConnectionSendall(object): data = client.recv(1024) accum.append(data) received += len(data) - assert message == b''.join(accum) + assert message == b"".join(accum) def test_closed(self): """ @@ -3196,6 +3307,7 @@ class TestConnectionRenegotiate(object): """ Tests for SSL renegotiation APIs. """ + def test_total_renegotiations(self): """ `Connection.total_renegotiations` returns `0` before any renegotiations @@ -3239,12 +3351,13 @@ class TestError(object): """ Unit tests for `OpenSSL.SSL.Error`. """ + def test_type(self): """ `Error` is an exception type. """ assert issubclass(Error, Exception) - assert Error.__name__ == 'Error' + assert Error.__name__ == "Error" class TestConstants(object): @@ -3255,9 +3368,10 @@ class TestConstants(object): OpenSSL APIs. The only assertions it seems can be made about them is their values. """ + @pytest.mark.skipif( OP_NO_QUERY_MTU is None, - reason="OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old" + reason="OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old", ) def test_op_no_query_mtu(self): """ @@ -3269,7 +3383,7 @@ class TestConstants(object): @pytest.mark.skipif( OP_COOKIE_EXCHANGE is None, reason="OP_COOKIE_EXCHANGE unavailable - " - "OpenSSL version may be too old" + "OpenSSL version may be too old", ) def test_op_cookie_exchange(self): """ @@ -3280,7 +3394,7 @@ class TestConstants(object): @pytest.mark.skipif( OP_NO_TICKET is None, - reason="OP_NO_TICKET unavailable - OpenSSL version may be too old" + reason="OP_NO_TICKET unavailable - OpenSSL version may be too old", ) def test_op_no_ticket(self): """ @@ -3291,7 +3405,9 @@ class TestConstants(object): @pytest.mark.skipif( OP_NO_COMPRESSION is None, - reason="OP_NO_COMPRESSION unavailable - OpenSSL version may be too old" + reason=( + "OP_NO_COMPRESSION unavailable - OpenSSL version may be too old" + ), ) def test_op_no_compression(self): """ @@ -3365,6 +3481,7 @@ class TestMemoryBIO(object): """ Tests for `OpenSSL.SSL.Connection` using a memory BIO. """ + def _server(self, sock): """ Create a new server-side SSL `Connection` object wrapped around `sock`. @@ -3375,13 +3492,15 @@ class TestMemoryBIO(object): server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) server_ctx.set_verify( VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, - verify_cb + verify_cb, ) server_store = server_ctx.get_cert_store() server_ctx.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_ctx.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) server_ctx.check_privatekey() server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) # Here the Connection is actually created. If None is passed as the @@ -3400,13 +3519,15 @@ class TestMemoryBIO(object): client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) client_ctx.set_verify( VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, - verify_cb + verify_cb, ) client_store = client_ctx.get_cert_store() client_ctx.use_privatekey( - load_privatekey(FILETYPE_PEM, client_key_pem)) + load_privatekey(FILETYPE_PEM, client_key_pem) + ) client_ctx.use_certificate( - load_certificate(FILETYPE_PEM, client_cert_pem)) + load_certificate(FILETYPE_PEM, client_cert_pem) + ) client_ctx.check_privatekey() client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) client_conn = Connection(client_ctx, sock) @@ -3443,39 +3564,41 @@ class TestMemoryBIO(object): assert client_conn.client_random() != client_conn.server_random() # Export key material for other uses. - cekm = client_conn.export_keying_material(b'LABEL', 32) - sekm = server_conn.export_keying_material(b'LABEL', 32) + cekm = client_conn.export_keying_material(b"LABEL", 32) + sekm = server_conn.export_keying_material(b"LABEL", 32) assert cekm is not None assert sekm is not None assert cekm == sekm assert len(sekm) == 32 # Export key material for other uses with additional context. - cekmc = client_conn.export_keying_material(b'LABEL', 32, b'CONTEXT') - sekmc = server_conn.export_keying_material(b'LABEL', 32, b'CONTEXT') + cekmc = client_conn.export_keying_material(b"LABEL", 32, b"CONTEXT") + sekmc = server_conn.export_keying_material(b"LABEL", 32, b"CONTEXT") assert cekmc is not None assert sekmc is not None assert cekmc == sekmc assert cekmc != cekm assert sekmc != sekm # Export with alternate label - cekmt = client_conn.export_keying_material(b'test', 32, b'CONTEXT') - sekmt = server_conn.export_keying_material(b'test', 32, b'CONTEXT') + cekmt = client_conn.export_keying_material(b"test", 32, b"CONTEXT") + sekmt = server_conn.export_keying_material(b"test", 32, b"CONTEXT") assert cekmc != cekmt assert sekmc != sekmt # Here are the bytes we'll try to send. - important_message = b'One if by land, two if by sea.' + important_message = b"One if by land, two if by sea." server_conn.write(important_message) - assert ( - interact_in_memory(client_conn, server_conn) == - (client_conn, important_message)) + assert interact_in_memory(client_conn, server_conn) == ( + client_conn, + important_message, + ) client_conn.write(important_message[::-1]) - assert ( - interact_in_memory(client_conn, server_conn) == - (server_conn, important_message[::-1])) + assert interact_in_memory(client_conn, server_conn) == ( + server_conn, + important_message[::-1], + ) def test_socket_connect(self): """ @@ -3608,9 +3731,11 @@ class TestMemoryBIO(object): client sides, `Connection.get_client_ca_list` returns an empty list after the connection is set up. """ + def no_ca(ctx): ctx.set_client_ca_list([]) return [] + self._check_client_ca_list(no_ca) def test_set_one_ca_list(self): @@ -3627,6 +3752,7 @@ class TestMemoryBIO(object): def single_ca(ctx): ctx.set_client_ca_list([cadesc]) return [cadesc] + self._check_client_ca_list(single_ca) def test_set_multiple_ca_list(self): @@ -3647,6 +3773,7 @@ class TestMemoryBIO(object): L = [sedesc, cldesc] ctx.set_client_ca_list(L) return L + self._check_client_ca_list(multiple_ca) def test_reset_ca_list(self): @@ -3667,6 +3794,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([sedesc, cldesc]) ctx.set_client_ca_list([cadesc]) return [cadesc] + self._check_client_ca_list(changed_ca) def test_mutated_ca_list(self): @@ -3686,6 +3814,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([cadesc]) L.append(sedesc) return [cadesc] + self._check_client_ca_list(mutated_ca) def test_add_client_ca_wrong_args(self): @@ -3708,6 +3837,7 @@ class TestMemoryBIO(object): def single_ca(ctx): ctx.add_client_ca(cacert) return [cadesc] + self._check_client_ca_list(single_ca) def test_multiple_add_client_ca(self): @@ -3725,6 +3855,7 @@ class TestMemoryBIO(object): ctx.add_client_ca(cacert) ctx.add_client_ca(secert) return [cadesc, sedesc] + self._check_client_ca_list(multiple_ca) def test_set_and_add_client_ca(self): @@ -3745,6 +3876,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([cadesc, sedesc]) ctx.add_client_ca(clcert) return [cadesc, sedesc, cldesc] + self._check_client_ca_list(mixed_set_add_ca) def test_set_after_add_client_ca(self): @@ -3765,6 +3897,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([cadesc]) ctx.add_client_ca(secert) return [cadesc, sedesc] + self._check_client_ca_list(set_replaces_add_ca) @@ -3772,6 +3905,7 @@ class TestInfoConstants(object): """ Tests for assorted constants exposed for use in info callbacks. """ + def test_integers(self): """ All of the info constants are integers. @@ -3781,17 +3915,31 @@ class TestInfoConstants(object): info callback matches up with the constant exposed by OpenSSL.SSL. """ for const in [ - SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, - SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, - SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, - SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, - SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE + SSL_ST_CONNECT, + SSL_ST_ACCEPT, + SSL_ST_MASK, + SSL_CB_LOOP, + SSL_CB_EXIT, + SSL_CB_READ, + SSL_CB_WRITE, + SSL_CB_ALERT, + SSL_CB_READ_ALERT, + SSL_CB_WRITE_ALERT, + SSL_CB_ACCEPT_LOOP, + SSL_CB_ACCEPT_EXIT, + SSL_CB_CONNECT_LOOP, + SSL_CB_CONNECT_EXIT, + SSL_CB_HANDSHAKE_START, + SSL_CB_HANDSHAKE_DONE, ]: assert isinstance(const, int) # These constants don't exist on OpenSSL 1.1.0 for const in [ - SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE + SSL_ST_INIT, + SSL_ST_BEFORE, + SSL_ST_OK, + SSL_ST_RENEGOTIATE, ]: assert const is None or isinstance(const, int) @@ -3801,6 +3949,7 @@ class TestRequires(object): Tests for the decorator factory used to conditionally raise NotImplementedError when older OpenSSLs are used. """ + def test_available(self): """ When the OpenSSL functionality is available the decorated functions @@ -3838,6 +3987,7 @@ class TestOCSP(object): """ Tests for PyOpenSSL's OCSP stapling support. """ + sample_ocsp_data = b"this is totally ocsp data" def _client_connection(self, callback, data, request_ocsp=True): @@ -3882,6 +4032,7 @@ class TestOCSP(object): the client does not send the OCSP request, neither callback gets called. """ + def ocsp_callback(*args, **kwargs): # pragma: nocover pytest.fail("Should not be called") @@ -3907,7 +4058,7 @@ class TestOCSP(object): handshake_in_memory(client, server) assert len(called) == 1 - assert called[0] == b'' + assert called[0] == b"" def test_client_receives_servers_data(self): """ @@ -3990,7 +4141,7 @@ class TestOCSP(object): client_calls = [] def server_callback(*args): - return b'' + return b"" def client_callback(conn, ocsp_data, ignored): client_calls.append(ocsp_data) @@ -4001,12 +4152,13 @@ class TestOCSP(object): handshake_in_memory(client, server) assert len(client_calls) == 1 - assert client_calls[0] == b'' + assert client_calls[0] == b"" def test_client_returns_false_terminates_handshake(self): """ If the client returns False from its callback, the handshake fails. """ + def server_callback(*args): return self.sample_ocsp_data @@ -4023,6 +4175,7 @@ class TestOCSP(object): """ The callbacks thrown in the client callback bubble up to the caller. """ + class SentinelException(Exception): pass @@ -4042,6 +4195,7 @@ class TestOCSP(object): """ The callbacks thrown in the server callback bubble up to the caller. """ + class SentinelException(Exception): pass @@ -4061,8 +4215,9 @@ class TestOCSP(object): """ The server callback must return a bytestring, or a TypeError is thrown. """ + def server_callback(*args): - return self.sample_ocsp_data.decode('ascii') + return self.sample_ocsp_data.decode("ascii") def client_callback(*args): # pragma: nocover pytest.fail("Should not be called") diff --git a/tests/test_util.py b/tests/test_util.py index 91847e0..6224448 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -7,6 +7,7 @@ class TestErrors(object): """ Tests for handling of certain OpenSSL error cases. """ + def test_exception_from_error_queue_nonexistent_reason(self): """ :func:`exception_from_error_queue` raises ``ValueError`` when it diff --git a/tests/util.py b/tests/util.py index 65b905a..75d2c8d 100644 --- a/tests/util.py +++ b/tests/util.py @@ -59,7 +59,7 @@ class EqualityTestsMixin(object): An object compares equal to itself using the C{==} operator. """ o = self.anInstance() - assert (o == o) + assert o == o def test_identicalNe(self): """ @@ -75,7 +75,7 @@ class EqualityTestsMixin(object): """ a = self.anInstance() b = self.anInstance() - assert (a == b) + assert a == b def test_sameNe(self): """ @@ -102,7 +102,7 @@ class EqualityTestsMixin(object): """ a = self.anInstance() b = self.anotherInstance() - assert (a != b) + assert a != b def test_anotherTypeEq(self): """ @@ -120,13 +120,14 @@ class EqualityTestsMixin(object): """ a = self.anInstance() b = object() - assert (a != b) + assert a != b def test_delegatedEq(self): """ The result of comparison using C{==} is delegated to the right-hand operand if it is of an unrelated type. """ + class Delegate(object): def __eq__(self, other): # Do something crazy and obvious. @@ -141,6 +142,7 @@ class EqualityTestsMixin(object): The result of comparison using C{!=} is delegated to the right-hand operand if it is of an unrelated type. """ + class Delegate(object): def __ne__(self, other): # Do something crazy and obvious. @@ -52,11 +52,14 @@ commands = rm -rf ./urllib3 [testenv:flake8] +basepython = python3 deps = - flake8 + black + flake8 skip_install = true commands = - flake8 src tests setup.py + black --check . + flake8 src tests setup.py [testenv:pypi-readme] deps = @@ -85,3 +88,7 @@ skip_install = true commands = coverage combine coverage report + +[flake8] +ignore = E203,W503,W504 +select = E,W,F,I |