diff options
author | Jean-Paul Calderone <exarkun@divmod.com> | 2011-06-12 17:20:31 -0400 |
---|---|---|
committer | Jean-Paul Calderone <exarkun@divmod.com> | 2011-06-12 17:20:31 -0400 |
commit | 0da7acc9079aeba61a764cb84b243e4742dae92d (patch) | |
tree | 95228a0bcac56023c8a82de8a8deff1af5fdaee5 /OpenSSL/test | |
parent | 9828f666f9eed90290093e962181d6065bb2017c (diff) | |
parent | c27733339ba8f678479db744e75a994639cba61e (diff) | |
download | pyopenssl-0da7acc9079aeba61a764cb84b243e4742dae92d.tar.gz |
Catch up to trunk
Diffstat (limited to 'OpenSSL/test')
-rw-r--r-- | OpenSSL/test/__init__.py | 3 | ||||
-rw-r--r-- | OpenSSL/test/test_crypto.py | 211 | ||||
-rw-r--r-- | OpenSSL/test/test_rand.py | 3 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 509 | ||||
-rw-r--r-- | OpenSSL/test/util.py | 28 |
5 files changed, 664 insertions, 90 deletions
diff --git a/OpenSSL/test/__init__.py b/OpenSSL/test/__init__.py index ab9c4cb..ccb4e9a 100644 --- a/OpenSSL/test/__init__.py +++ b/OpenSSL/test/__init__.py @@ -1,4 +1,5 @@ -# Copyright (C) Jean-Paul Calderone 2008, All rights reserved +# Copyright (C) Jean-Paul Calderone +# See LICENSE for details. """ Package containing unit tests for L{OpenSSL}. diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py index 07c172f..71da5c3 100644 --- a/OpenSSL/test/test_crypto.py +++ b/OpenSSL/test/test_crypto.py @@ -1,4 +1,5 @@ -# Copyright (C) Jean-Paul Calderone 2008, All rights reserved +# Copyright (c) Jean-Paul Calderone +# See LICENSE file for details. """ Unit tests for L{OpenSSL.crypto}. @@ -25,6 +26,13 @@ from OpenSSL.crypto import NetscapeSPKI, NetscapeSPKIType from OpenSSL.crypto import sign, verify from OpenSSL.test.util import TestCase, bytes, b +def normalize_certificate_pem(pem): + return dump_certificate(FILETYPE_PEM, load_certificate(FILETYPE_PEM, pem)) + + +def normalize_privatekey_pem(pem): + return dump_privatekey(FILETYPE_PEM, load_privatekey(FILETYPE_PEM, pem)) + root_cert_pem = b("""-----BEGIN CERTIFICATE----- MIIC7TCCAlagAwIBAgIIPQzE4MbeufQwDQYJKoZIhvcNAQEFBQAwWDELMAkGA1UE @@ -79,7 +87,7 @@ uzujnS8YXWvM7DM1Ilozk4MzPug8jzFp5uhKCQ== -----END CERTIFICATE----- """) -server_key_pem = b("""-----BEGIN RSA PRIVATE KEY----- +server_key_pem = normalize_privatekey_pem(b("""-----BEGIN RSA PRIVATE KEY----- MIICWwIBAAKBgQC+pvhuud1dLaQQvzipdtlcTotgr5SuE2LvSx0gz/bg1U3u1eQ+ U5eqsxaEUceaX5p5Kk+QflvW8qdjVNxQuYS5uc0gK2+OZnlIYxCf4n5GYGzVIx3Q SBj/TAEFB2WuVinZBiCbxgL7PFM1Kpa+EwVkCAduPpSflJJPwkYGrK2MHQIDAQAB @@ -94,7 +102,7 @@ FwwOhpahld+vqhYk+pfuWWUpQciE+Bu7ZQJASjfT4sQv4qbbKK/scePicnDdx9th NaeNCFfH3aeTrX0LyQJAMBWjWmeKM2G2sCExheeQK0ROnaBC8itCECD4Jsve4nqf r50+LF74iLXFwqysVCebPKMOpDWp/qQ1BbJQIPs7/A== -----END RSA PRIVATE KEY----- -""") +""")) client_cert_pem = b("""-----BEGIN CERTIFICATE----- MIICJjCCAY+gAwIBAgIJAKxpFI5lODkjMA0GCSqGSIb3DQEBBQUAMFgxCzAJBgNV @@ -112,7 +120,7 @@ PSTJCjJOn3xo2NTKRgV1gaoTf2EhL+RG8TQ= -----END CERTIFICATE----- """) -client_key_pem = b("""-----BEGIN RSA PRIVATE KEY----- +client_key_pem = normalize_privatekey_pem(b("""-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDAZh/SRtNm5ntMT4qb6YzEpTroMlq2rn+GrRHRiZ+xkCw/CGNh btPir7/QxaUj26BSmQrHw1bGKEbPsWiW7bdXSespl+xKiku4G/KvnnmWdeJHqsiX eUZtqurMELcPQAw9xPHEuhqqUJvvEoMTsnCEqGM+7DtboCRajYyHfluARQIDAQAB @@ -127,7 +135,7 @@ si6xwT7GzMDkk/ko684AV3KPc/h6G0yGtFIrMg7J3uExpR/VdH2KgwMkZXisSMvw JJEQjOMCVsEJlRk54WWjAkEAzoZNH6UhDdBK5F38rVt/y4SEHgbSfJHIAmPS32Kq f6GGcfNpip0Uk7q7udTKuX7Q/buZi/C4YW7u3VKAquv9NA== -----END RSA PRIVATE KEY----- -""") +""")) cleartextCertificatePEM = b("""-----BEGIN CERTIFICATE----- MIIC7TCCAlagAwIBAgIIPQzE4MbeufQwDQYJKoZIhvcNAQEFBQAwWDELMAkGA1UE @@ -149,7 +157,8 @@ w/njVbKMXrvc83qmTdGl3TAM0fxQIpqgcglFLveEBgzn -----END CERTIFICATE----- """) -cleartextPrivateKeyPEM = b("""-----BEGIN RSA PRIVATE KEY----- +cleartextPrivateKeyPEM = normalize_privatekey_pem(b("""\ +-----BEGIN RSA PRIVATE KEY----- MIICXQIBAAKBgQD5mkLpi7q6ROdu7khB3S9aanA0Zls7vvfGOmB80/yeylhGpsjA jWen0VtSQke/NlEPGtO38tsV7CsuFnSmschvAnGrcJl76b0UOOHUgDTIoRxC6QDU 3claegwsrBA+sJEBbqx5RdXbIRGicPG/8qQ4Zm1SKOgotcbwiaor2yxZ2wIDAQAB @@ -164,7 +173,7 @@ ttXigLnCqR486JDPTi9ZscoZkZ+w7y6e/hH8t6d5Vjt48JVyfjPIaJY+km58LcN3 6AWSeGAdtRFHVzR7oHjVAkB4hutvxiOeiIVQNBhM6RSI9aBPMI21DoX2JRoxvNW2 cbvAhow217X9V0dVerEOKxnNYspXRrh36h7k4mQA+sDq -----END RSA PRIVATE KEY----- -""") +""")) cleartextCertificateRequestPEM = b("""-----BEGIN CERTIFICATE REQUEST----- MIIBnjCCAQcCAQAwXjELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAklMMRAwDgYDVQQH @@ -352,6 +361,26 @@ class X509ExtTests(TestCase): self.assertEqual(ext.get_short_name(), b('nsComment')) + def test_get_data(self): + """ + L{X509Extension.get_data} returns a string giving the data of the + extension. + """ + ext = X509Extension(b('basicConstraints'), True, b('CA:true')) + # Expect to get back the DER encoded form of CA:true. + self.assertEqual(ext.get_data(), b('0\x03\x01\x01\xff')) + + + def test_get_data_wrong_args(self): + """ + L{X509Extension.get_data} raises L{TypeError} if passed any arguments. + """ + ext = X509Extension(b('basicConstraints'), True, b('CA:true')) + self.assertRaises(TypeError, ext.get_data, None) + self.assertRaises(TypeError, ext.get_data, "foo") + self.assertRaises(TypeError, ext.get_data, 7) + + def test_unused_subject(self): """ The C{subject} parameter to L{X509Extension} may be provided for an @@ -597,6 +626,33 @@ class X509NameTests(TestCase): name, type(name), X509NameType)) + def test_onlyStringAttributes(self): + """ + Attempting to set a non-L{str} attribute name on an L{X509NameType} + instance causes L{TypeError} to be raised. + """ + name = self._x509name() + # Beyond these cases, you may also think that unicode should be + # rejected. Sorry, you're wrong. unicode is automatically converted to + # str outside of the control of X509Name, so there's no way to reject + # it. + self.assertRaises(TypeError, setattr, name, None, "hello") + self.assertRaises(TypeError, setattr, name, 30, "hello") + class evil(str): + pass + self.assertRaises(TypeError, setattr, name, evil(), "hello") + + + def test_setInvalidAttribute(self): + """ + Attempting to set any attribute name on an L{X509NameType} instance for + which no corresponding NID is defined causes L{AttributeError} to be + raised. + """ + name = self._x509name() + self.assertRaises(AttributeError, setattr, name, "no such thing", None) + + def test_attributes(self): """ L{X509NameType} instances have attributes for each standard (?) @@ -946,6 +1002,26 @@ class X509Tests(TestCase, _PKeyInteractionTestsMixin): """ pemData = cleartextCertificatePEM + cleartextPrivateKeyPEM + extpem = """ +-----BEGIN CERTIFICATE----- +MIIC3jCCAkegAwIBAgIJAJHFjlcCgnQzMA0GCSqGSIb3DQEBBQUAMEcxCzAJBgNV +BAYTAlNFMRUwEwYDVQQIEwxXZXN0ZXJib3R0b20xEjAQBgNVBAoTCUNhdGFsb2dp +eDENMAsGA1UEAxMEUm9vdDAeFw0wODA0MjIxNDQ1MzhaFw0wOTA0MjIxNDQ1Mzha +MFQxCzAJBgNVBAYTAlNFMQswCQYDVQQIEwJXQjEUMBIGA1UEChMLT3Blbk1ldGFk +aXIxIjAgBgNVBAMTGW5vZGUxLm9tMi5vcGVubWV0YWRpci5vcmcwgZ8wDQYJKoZI +hvcNAQEBBQADgY0AMIGJAoGBAPIcQMrwbk2nESF/0JKibj9i1x95XYAOwP+LarwT +Op4EQbdlI9SY+uqYqlERhF19w7CS+S6oyqx0DRZSk4Y9dZ9j9/xgm2u/f136YS1u +zgYFPvfUs6PqYLPSM8Bw+SjJ+7+2+TN+Tkiof9WP1cMjodQwOmdsiRbR0/J7+b1B +hec1AgMBAAGjgcQwgcEwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNT +TCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFIdHsBcMVVMbAO7j6NCj +03HgLnHaMB8GA1UdIwQYMBaAFL2h9Bf9Mre4vTdOiHTGAt7BRY/8MEYGA1UdEQQ/ +MD2CDSouZXhhbXBsZS5vcmeCESoub20yLmV4bWFwbGUuY29thwSC7wgKgRNvbTJA +b3Blbm1ldGFkaXIub3JnMA0GCSqGSIb3DQEBBQUAA4GBALd7WdXkp2KvZ7/PuWZA +MPlIxyjS+Ly11+BNE0xGQRp9Wz+2lABtpgNqssvU156+HkKd02rGheb2tj7MX9hG +uZzbwDAZzJPjzDQDD7d3cWsrVcfIdqVU7epHqIadnOF+X0ghJ39pAm6VVadnSXCt +WpOdIpB8KksUTCzV591Nr1wd +-----END CERTIFICATE----- + """ def signable(self): """ Create and return a new L{X509}. @@ -1198,6 +1274,77 @@ class X509Tests(TestCase, _PKeyInteractionTestsMixin): b("A8:EB:07:F8:53:25:0A:F2:56:05:C5:A5:C4:C4:C7:15")) + def _extcert(self, pkey, extensions): + cert = X509() + cert.set_pubkey(pkey) + cert.get_subject().commonName = "Unit Tests" + cert.get_issuer().commonName = "Unit Tests" + when = b(datetime.now().strftime("%Y%m%d%H%M%SZ")) + cert.set_notBefore(when) + cert.set_notAfter(when) + + cert.add_extensions(extensions) + return load_certificate( + FILETYPE_PEM, dump_certificate(FILETYPE_PEM, cert)) + + + def test_extension_count(self): + """ + L{X509.get_extension_count} returns the number of extensions that are + present in the certificate. + """ + pkey = load_privatekey(FILETYPE_PEM, client_key_pem) + ca = X509Extension(b('basicConstraints'), True, b('CA:FALSE')) + key = X509Extension(b('keyUsage'), True, b('digitalSignature')) + subjectAltName = X509Extension( + b('subjectAltName'), True, b('DNS:example.com')) + + # Try a certificate with no extensions at all. + c = self._extcert(pkey, []) + self.assertEqual(c.get_extension_count(), 0) + + # And a certificate with one + c = self._extcert(pkey, [ca]) + self.assertEqual(c.get_extension_count(), 1) + + # And a certificate with several + c = self._extcert(pkey, [ca, key, subjectAltName]) + self.assertEqual(c.get_extension_count(), 3) + + + def test_get_extension(self): + """ + L{X509.get_extension} takes an integer and returns an L{X509Extension} + corresponding to the extension at that index. + """ + pkey = load_privatekey(FILETYPE_PEM, client_key_pem) + ca = X509Extension(b('basicConstraints'), True, b('CA:FALSE')) + key = X509Extension(b('keyUsage'), True, b('digitalSignature')) + subjectAltName = X509Extension( + b('subjectAltName'), False, b('DNS:example.com')) + + cert = self._extcert(pkey, [ca, key, subjectAltName]) + + ext = cert.get_extension(0) + self.assertTrue(isinstance(ext, X509Extension)) + self.assertTrue(ext.get_critical()) + self.assertEqual(ext.get_short_name(), b('basicConstraints')) + + ext = cert.get_extension(1) + self.assertTrue(isinstance(ext, X509Extension)) + self.assertTrue(ext.get_critical()) + self.assertEqual(ext.get_short_name(), b('keyUsage')) + + ext = cert.get_extension(2) + self.assertTrue(isinstance(ext, X509Extension)) + self.assertFalse(ext.get_critical()) + self.assertEqual(ext.get_short_name(), b('subjectAltName')) + + self.assertRaises(IndexError, cert.get_extension, -1) + self.assertRaises(IndexError, cert.get_extension, 4) + self.assertRaises(TypeError, cert.get_extension, "hello") + + def test_invalid_digest_algorithm(self): """ L{X509.digest} raises L{ValueError} if called with an unrecognized hash @@ -1326,7 +1473,53 @@ class X509Tests(TestCase, _PKeyInteractionTestsMixin): name. """ cert = load_certificate(FILETYPE_PEM, self.pemData) - self.assertEquals(cert.subject_name_hash(), 3350047874) + self.assertIn( + cert.subject_name_hash(), + [3350047874, # OpenSSL 0.9.8, MD5 + 3278919224, # OpenSSL 1.0.0, SHA1 + ]) + + + def test_get_signature_algorithm(self): + """ + L{X509Type.get_signature_algorithm} returns a string which means + the algorithm used to sign the certificate. + """ + cert = load_certificate(FILETYPE_PEM, self.pemData) + self.assertEqual( + b("sha1WithRSAEncryption"), cert.get_signature_algorithm()) + + + def test_get_undefined_signature_algorithm(self): + """ + L{X509Type.get_signature_algorithm} raises L{ValueError} if the + signature algorithm is undefined or unknown. + """ + # This certificate has been modified to indicate a bogus OID in the + # signature algorithm field so that OpenSSL does not recognize it. + certPEM = """\ +-----BEGIN CERTIFICATE----- +MIIC/zCCAmigAwIBAgIBATAGBgJ8BQUAMHsxCzAJBgNVBAYTAlNHMREwDwYDVQQK +EwhNMkNyeXB0bzEUMBIGA1UECxMLTTJDcnlwdG8gQ0ExJDAiBgNVBAMTG00yQ3J5 +cHRvIENlcnRpZmljYXRlIE1hc3RlcjEdMBsGCSqGSIb3DQEJARYObmdwc0Bwb3N0 +MS5jb20wHhcNMDAwOTEwMDk1MTMwWhcNMDIwOTEwMDk1MTMwWjBTMQswCQYDVQQG +EwJTRzERMA8GA1UEChMITTJDcnlwdG8xEjAQBgNVBAMTCWxvY2FsaG9zdDEdMBsG +CSqGSIb3DQEJARYObmdwc0Bwb3N0MS5jb20wXDANBgkqhkiG9w0BAQEFAANLADBI +AkEArL57d26W9fNXvOhNlZzlPOACmvwOZ5AdNgLzJ1/MfsQQJ7hHVeHmTAjM664V ++fXvwUGJLziCeBo1ysWLRnl8CQIDAQABo4IBBDCCAQAwCQYDVR0TBAIwADAsBglg +hkgBhvhCAQ0EHxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0O +BBYEFM+EgpK+eyZiwFU1aOPSbczbPSpVMIGlBgNVHSMEgZ0wgZqAFPuHI2nrnDqT +FeXFvylRT/7tKDgBoX+kfTB7MQswCQYDVQQGEwJTRzERMA8GA1UEChMITTJDcnlw +dG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQDExtNMkNyeXB0byBDZXJ0 +aWZpY2F0ZSBNYXN0ZXIxHTAbBgkqhkiG9w0BCQEWDm5ncHNAcG9zdDEuY29tggEA +MA0GCSqGSIb3DQEBBAUAA4GBADv8KpPo+gfJxN2ERK1Y1l17sz/ZhzoGgm5XCdbx +jEY7xKfpQngV599k1xhl11IMqizDwu0855agrckg2MCTmOI9DZzDD77tAYb+Dk0O +PEVk0Mk/V0aIsDE9bolfCi/i/QWZ3N8s5nTWMNyBBBmoSliWCm4jkkRZRD0ejgTN +tgI5 +-----END CERTIFICATE----- +""" + cert = load_certificate(FILETYPE_PEM, certPEM) + self.assertRaises(ValueError, cert.get_signature_algorithm) @@ -1547,7 +1740,7 @@ class PKCS12Tests(TestCase): dumped_p12 = p12.export(passphrase=passwd, iter=2, maciter=3) reloaded_p12 = load_pkcs12(dumped_p12, passwd) self.assertEqual( - p12.get_friendlyname(),reloaded_p12.get_friendlyname()) + p12.get_friendlyname(), reloaded_p12.get_friendlyname()) # We would use the openssl program to confirm the friendly # name, but it is not possible. The pkcs12 command # does not store the friendly name in the cert's diff --git a/OpenSSL/test/test_rand.py b/OpenSSL/test/test_rand.py index a785168..00fc6d1 100644 --- a/OpenSSL/test/test_rand.py +++ b/OpenSSL/test/test_rand.py @@ -1,4 +1,5 @@ -# Copyright (C) Frederick Dean 2009, All rights reserved +# Copyright (c) Frederick Dean +# See LICENSE for details. """ Unit tests for L{OpenSSL.rand}. diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 6c8579b..2ab67fd 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1,32 +1,41 @@ -# Copyright (C) Jean-Paul Calderone 2008-2010, All rights reserved +# Copyright (C) Jean-Paul Calderone +# See LICENSE for details. """ Unit tests for L{OpenSSL.SSL}. """ +from gc import collect from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK -from sys import platform +from sys import platform, version_info from socket import error, socket from os import makedirs from os.path import join from unittest import main +from weakref import ref -from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, FILETYPE_ASN1 +from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM from OpenSSL.crypto import PKey, X509, X509Extension from OpenSSL.crypto import dump_privatekey, load_privatekey from OpenSSL.crypto import dump_certificate, load_certificate +from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS +from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD from OpenSSL.SSL import OP_NO_SSLv2, OP_NO_SSLv3, OP_SINGLE_DH_USE -from OpenSSL.SSL import VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE -from OpenSSL.SSL import Error, SysCallError, WantReadError, ZeroReturnError +from OpenSSL.SSL import ( + VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE, VERIFY_NONE) +from OpenSSL.SSL import ( + Error, SysCallError, WantReadError, ZeroReturnError, SSLeay_version) from OpenSSL.SSL import Context, ContextType, Connection, ConnectionType from OpenSSL.test.util import TestCase, bytes, b -from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM -from OpenSSL.test.test_crypto import client_cert_pem, client_key_pem -from OpenSSL.test.test_crypto import server_cert_pem, server_key_pem, root_cert_pem +from OpenSSL.test.test_crypto import ( + cleartextCertificatePEM, cleartextPrivateKeyPEM) +from OpenSSL.test.test_crypto import ( + client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, + root_cert_pem) try: from OpenSSL.SSL import OP_NO_QUERY_MTU @@ -41,6 +50,13 @@ try: except ImportError: OP_NO_TICKET = None +from OpenSSL.SSL import ( + SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, SSL_ST_INIT, SSL_ST_BEFORE, + SSL_ST_OK, SSL_ST_RENEGOTIATE, + SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, + SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, + SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, + SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE) # openssl dhparam 128 -out dh-128.pem (note that 128 is a small number of bits # to use) @@ -54,6 +70,7 @@ MBYCEQCobsg29c9WZP/54oAPcwiDAgEC def verify_cb(conn, cert, errnum, depth, ok): return ok + def socket_pair(): """ Establish and return a pair of network sockets connected to each other. @@ -96,6 +113,60 @@ def handshake(client, server): conns.remove(conn) +def _create_certificate_chain(): + """ + Construct and return a chain of certificates. + + 1. A new self-signed certificate authority certificate (cacert) + 2. A new intermediate certificate signed by cacert (icert) + 3. A new server certificate signed by icert (scert) + """ + caext = X509Extension(b('basicConstraints'), False, b('CA:true')) + + # Step 1 + cakey = PKey() + cakey.generate_key(TYPE_RSA, 512) + cacert = X509() + cacert.get_subject().commonName = "Authority Certificate" + cacert.set_issuer(cacert.get_subject()) + cacert.set_pubkey(cakey) + cacert.set_notBefore(b("20000101000000Z")) + cacert.set_notAfter(b("20200101000000Z")) + cacert.add_extensions([caext]) + cacert.set_serial_number(0) + cacert.sign(cakey, "sha1") + + # Step 2 + ikey = PKey() + ikey.generate_key(TYPE_RSA, 512) + icert = X509() + icert.get_subject().commonName = "Intermediate Certificate" + icert.set_issuer(cacert.get_subject()) + icert.set_pubkey(ikey) + icert.set_notBefore(b("20000101000000Z")) + icert.set_notAfter(b("20200101000000Z")) + icert.add_extensions([caext]) + icert.set_serial_number(0) + icert.sign(cakey, "sha1") + + # Step 3 + skey = PKey() + skey.generate_key(TYPE_RSA, 512) + scert = X509() + scert.get_subject().commonName = "Server Certificate" + scert.set_issuer(icert.get_subject()) + scert.set_pubkey(skey) + scert.set_notBefore(b("20000101000000Z")) + scert.set_notAfter(b("20200101000000Z")) + scert.add_extensions([ + X509Extension(b('basicConstraints'), True, b('CA:false'))]) + scert.set_serial_number(0) + scert.sign(ikey, "sha1") + + return [(cakey, cacert), (ikey, icert), (skey, scert)] + + + class _LoopbackMixin: """ Helper mixin which defines methods for creating a connected socket pair and @@ -141,7 +212,7 @@ class _LoopbackMixin: # Give the side a chance to generate some more bytes, or # succeed. try: - bytes = read.recv(2 ** 16) + data = read.recv(2 ** 16) except WantReadError: # It didn't succeed, so we'll hope it generated some # output. @@ -149,7 +220,7 @@ class _LoopbackMixin: else: # It did succeed, so we'll stop now and let the caller deal # with it. - return (read, bytes) + return (read, data) while True: # Keep copying as long as there's more stuff there. @@ -167,6 +238,36 @@ class _LoopbackMixin: +class VersionTests(TestCase): + """ + Tests for version information exposed by + L{OpenSSL.SSL.SSLeay_version} and + L{OpenSSL.SSL.OPENSSL_VERSION_NUMBER}. + """ + def test_OPENSSL_VERSION_NUMBER(self): + """ + L{OPENSSL_VERSION_NUMBER} is an integer with status in the low + byte and the patch, fix, minor, and major versions in the + nibbles above that. + """ + self.assertTrue(isinstance(OPENSSL_VERSION_NUMBER, int)) + + + def test_SSLeay_version(self): + """ + L{SSLeay_version} takes a version type indicator and returns + one of a number of version strings based on that indicator. + """ + versions = {} + for t in [SSLEAY_VERSION, SSLEAY_CFLAGS, SSLEAY_BUILT_ON, + SSLEAY_PLATFORM, SSLEAY_DIR]: + version = SSLeay_version(t) + versions[version] = t + self.assertTrue(isinstance(version, bytes)) + self.assertEqual(len(versions), 5) + + + class ContextTests(TestCase, _LoopbackMixin): """ Unit tests for L{OpenSSL.SSL.Context}. @@ -176,8 +277,16 @@ class ContextTests(TestCase, _LoopbackMixin): L{Context} can be instantiated with one of L{SSLv2_METHOD}, L{SSLv3_METHOD}, L{SSLv23_METHOD}, or L{TLSv1_METHOD}. """ - for meth in [SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD]: + for meth in [SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD]: Context(meth) + + try: + Context(SSLv2_METHOD) + except ValueError: + # Some versions of OpenSSL have SSLv2, some don't. + # Difficult to say in advance. + pass + self.assertRaises(TypeError, Context, "") self.assertRaises(ValueError, Context, 10) @@ -512,12 +621,14 @@ class ContextTests(TestCase, _LoopbackMixin): """ capath = self.mktemp() makedirs(capath) - # Hash value computed manually with c_rehash to avoid depending on - # c_rehash in the test suite. - cafile = join(capath, 'c7adac82.0') - fObj = open(cafile, 'w') - fObj.write(cleartextCertificatePEM.decode('ascii')) - fObj.close() + # Hash values computed manually with c_rehash to avoid depending on + # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other + # from OpenSSL 1.0.0. + for name in ['c7adac82.0', 'c3705638.0']: + cafile = join(capath, name) + fObj = open(cafile, 'w') + fObj.write(cleartextCertificatePEM.decode('ascii')) + fObj.close() self._load_verify_locations_test(None, capath) @@ -590,59 +701,6 @@ class ContextTests(TestCase, _LoopbackMixin): self.assertRaises(TypeError, context.add_extra_chain_cert, object(), object()) - def _create_certificate_chain(self): - """ - Construct and return a chain of certificates. - - 1. A new self-signed certificate authority certificate (cacert) - 2. A new intermediate certificate signed by cacert (icert) - 3. A new server certificate signed by icert (scert) - """ - caext = X509Extension(b('basicConstraints'), False, b('CA:true')) - - # Step 1 - cakey = PKey() - cakey.generate_key(TYPE_RSA, 512) - cacert = X509() - cacert.get_subject().commonName = "Authority Certificate" - cacert.set_issuer(cacert.get_subject()) - cacert.set_pubkey(cakey) - cacert.set_notBefore(b("20000101000000Z")) - cacert.set_notAfter(b("20200101000000Z")) - cacert.add_extensions([caext]) - cacert.set_serial_number(0) - cacert.sign(cakey, "sha1") - - # Step 2 - ikey = PKey() - ikey.generate_key(TYPE_RSA, 512) - icert = X509() - icert.get_subject().commonName = "Intermediate Certificate" - icert.set_issuer(cacert.get_subject()) - icert.set_pubkey(ikey) - icert.set_notBefore(b("20000101000000Z")) - icert.set_notAfter(b("20200101000000Z")) - icert.add_extensions([caext]) - icert.set_serial_number(0) - icert.sign(cakey, "sha1") - - # Step 3 - skey = PKey() - skey.generate_key(TYPE_RSA, 512) - scert = X509() - scert.get_subject().commonName = "Server Certificate" - scert.set_issuer(icert.get_subject()) - scert.set_pubkey(skey) - scert.set_notBefore(b("20000101000000Z")) - scert.set_notAfter(b("20200101000000Z")) - scert.add_extensions([ - X509Extension(b('basicConstraints'), True, b('CA:false'))]) - scert.set_serial_number(0) - scert.sign(ikey, "sha1") - - return [(cakey, cacert), (ikey, icert), (skey, scert)] - - def _handshake_test(self, serverContext, clientContext): """ Verify that a client and server created with the given contexts can @@ -678,7 +736,7 @@ class ContextTests(TestCase, _LoopbackMixin): to it with a client which trusts cacert and requires verification to succeed. """ - chain = self._create_certificate_chain() + chain = _create_certificate_chain() [(cakey, cacert), (ikey, icert), (skey, scert)] = chain # Dump the CA certificate to a file because that's the only way to load @@ -719,7 +777,7 @@ class ContextTests(TestCase, _LoopbackMixin): to it with a client which trusts cacert and requires verification to succeed. """ - chain = self._create_certificate_chain() + chain = _create_certificate_chain() [(cakey, cacert), (ikey, icert), (skey, scert)] = chain # Write out the chain file. @@ -817,6 +875,106 @@ class ContextTests(TestCase, _LoopbackMixin): +class ServerNameCallbackTests(TestCase, _LoopbackMixin): + """ + Tests for L{Context.set_tlsext_servername_callback} and its interaction with + L{Connection}. + """ + def test_wrong_args(self): + """ + L{Context.set_tlsext_servername_callback} raises L{TypeError} if called + with other than one argument. + """ + context = Context(TLSv1_METHOD) + self.assertRaises(TypeError, context.set_tlsext_servername_callback) + self.assertRaises( + TypeError, context.set_tlsext_servername_callback, 1, 2) + + def test_old_callback_forgotten(self): + """ + If L{Context.set_tlsext_servername_callback} is used to specify a new + callback, the one it replaces is dereferenced. + """ + def callback(connection): + pass + + def replacement(connection): + pass + + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(callback) + + tracker = ref(callback) + del callback + + context.set_tlsext_servername_callback(replacement) + collect() + self.assertIdentical(None, tracker()) + + + def test_no_servername(self): + """ + When a client specifies no server name, the callback passed to + L{Context.set_tlsext_servername_callback} is invoked and the result of + L{Connection.get_servername} is C{None}. + """ + args = [] + def servername(conn): + args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(servername) + + # Lose our reference to it. The Context is responsible for keeping it + # alive now. + del servername + collect() + + # Necessary to actually accept the connection + context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(context, None) + server.set_accept_state() + + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, None)], args) + + + def test_servername(self): + """ + When a client specifies a server name in its hello message, the callback + passed to L{Contexts.set_tlsext_servername_callback} is invoked and the + result of L{Connection.get_servername} is that server name. + """ + args = [] + def servername(conn): + args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(servername) + + # Necessary to actually accept the connection + context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(context, None) + server.set_accept_state() + + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + client.set_tlsext_host_name(b("foo1.example.com")) + + self._interactInMemory(server, client) + + self.assertEqual([(server, b("foo1.example.com"))], args) + + + class ConnectionTests(TestCase, _LoopbackMixin): """ Unit tests for L{OpenSSL.SSL.Connection}. @@ -868,6 +1026,71 @@ class ConnectionTests(TestCase, _LoopbackMixin): self.assertRaises(TypeError, connection.get_context, None) + def test_set_context_wrong_args(self): + """ + L{Connection.set_context} raises L{TypeError} if called with a + non-L{Context} instance argument or with any number of arguments other + than 1. + """ + ctx = Context(TLSv1_METHOD) + connection = Connection(ctx, None) + self.assertRaises(TypeError, connection.set_context) + self.assertRaises(TypeError, connection.set_context, object()) + self.assertRaises(TypeError, connection.set_context, "hello") + self.assertRaises(TypeError, connection.set_context, 1) + self.assertRaises(TypeError, connection.set_context, 1, 2) + self.assertRaises( + TypeError, connection.set_context, Context(TLSv1_METHOD), 2) + self.assertIdentical(ctx, connection.get_context()) + + + def test_set_context(self): + """ + L{Connection.set_context} specifies a new L{Context} instance to be used + for the connection. + """ + original = Context(SSLv23_METHOD) + replacement = Context(TLSv1_METHOD) + connection = Connection(original, None) + connection.set_context(replacement) + self.assertIdentical(replacement, connection.get_context()) + # Lose our references to the contexts, just in case the Connection isn't + # properly managing its own contributions to their reference counts. + del original, replacement + collect() + + + def test_set_tlsext_host_name_wrong_args(self): + """ + If L{Connection.set_tlsext_host_name} is called with a non-byte string + argument or a byte string with an embedded NUL or other than one + argument, L{TypeError} is raised. + """ + conn = Connection(Context(TLSv1_METHOD), None) + self.assertRaises(TypeError, conn.set_tlsext_host_name) + self.assertRaises(TypeError, conn.set_tlsext_host_name, object()) + self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456) + self.assertRaises( + TypeError, conn.set_tlsext_host_name, b("with\0null")) + + if version_info >= (3,): + # On Python 3.x, don't accidentally implicitly convert from text. + self.assertRaises( + TypeError, + conn.set_tlsext_host_name, b("example.com").decode("ascii")) + + + def test_get_servername_wrong_args(self): + """ + L{Connection.get_servername} raises L{TypeError} if called with any + arguments. + """ + connection = Connection(Context(TLSv1_METHOD), None) + self.assertRaises(TypeError, connection.get_servername, object()) + self.assertRaises(TypeError, connection.get_servername, 1) + self.assertRaises(TypeError, connection.get_servername, "hello") + + def test_pending(self): """ L{Connection.pending} returns the number of bytes available for @@ -1047,6 +1270,68 @@ class ConnectionTests(TestCase, _LoopbackMixin): self.assertRaises(NotImplementedError, conn.makefile) + def test_get_peer_cert_chain_wrong_args(self): + """ + L{Connection.get_peer_cert_chain} raises L{TypeError} if called with any + arguments. + """ + conn = Connection(Context(TLSv1_METHOD), None) + self.assertRaises(TypeError, conn.get_peer_cert_chain, 1) + self.assertRaises(TypeError, conn.get_peer_cert_chain, "foo") + self.assertRaises(TypeError, conn.get_peer_cert_chain, object()) + self.assertRaises(TypeError, conn.get_peer_cert_chain, []) + + + def test_get_peer_cert_chain(self): + """ + L{Connection.get_peer_cert_chain} returns a list of certificates which + the connected server returned for the certification verification. + """ + chain = _create_certificate_chain() + [(cakey, cacert), (ikey, icert), (skey, scert)] = chain + + serverContext = Context(TLSv1_METHOD) + serverContext.use_privatekey(skey) + serverContext.use_certificate(scert) + serverContext.add_extra_chain_cert(icert) + serverContext.add_extra_chain_cert(cacert) + server = Connection(serverContext, None) + server.set_accept_state() + + # Create the client + clientContext = Context(TLSv1_METHOD) + clientContext.set_verify(VERIFY_NONE, verify_cb) + client = Connection(clientContext, None) + client.set_connect_state() + + self._interactInMemory(client, server) + + chain = client.get_peer_cert_chain() + self.assertEqual(len(chain), 3) + self.assertEqual( + "Server Certificate", chain[0].get_subject().CN) + self.assertEqual( + "Intermediate Certificate", chain[1].get_subject().CN) + self.assertEqual( + "Authority Certificate", chain[2].get_subject().CN) + + + def test_get_peer_cert_chain_none(self): + """ + L{Connection.get_peer_cert_chain} returns C{None} if the peer sends no + certificate chain. + """ + ctx = Context(TLSv1_METHOD) + ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + server = Connection(ctx, None) + server.set_accept_state() + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + self._interactInMemory(client, server) + self.assertIdentical(None, server.get_peer_cert_chain()) + + class ConnectionGetCipherListTests(TestCase): """ @@ -1074,6 +1359,49 @@ class ConnectionGetCipherListTests(TestCase): +class ConnectionSendTests(TestCase, _LoopbackMixin): + """ + Tests for L{Connection.send} + """ + def test_wrong_args(self): + """ + When called with arguments other than a single string, + L{Connection.send} raises L{TypeError}. + """ + connection = Connection(Context(TLSv1_METHOD), None) + self.assertRaises(TypeError, connection.send) + self.assertRaises(TypeError, connection.send, object()) + self.assertRaises(TypeError, connection.send, "foo", "bar") + + + def test_short_bytes(self): + """ + When passed a short byte string, L{Connection.send} transmits all of it + and returns the number of bytes sent. + """ + server, client = self._loopback() + count = server.send(b('xy')) + self.assertEquals(count, 2) + self.assertEquals(client.recv(2), b('xy')) + + try: + memoryview + except NameError: + "cannot test sending memoryview without memoryview" + else: + def test_short_memoryview(self): + """ + When passed a memoryview onto a small number of bytes, + L{Connection.send} transmits all of them and returns the number of + bytes sent. + """ + server, client = self._loopback() + count = server.send(memoryview(b('xy'))) + self.assertEquals(count, 2) + self.assertEquals(client.recv(2), b('xy')) + + + class ConnectionSendallTests(TestCase, _LoopbackMixin): """ Tests for L{Connection.sendall}. @@ -1099,6 +1427,21 @@ class ConnectionSendallTests(TestCase, _LoopbackMixin): self.assertEquals(client.recv(1), b('x')) + try: + memoryview + except NameError: + "cannot test sending memoryview without memoryview" + else: + def test_short_memoryview(self): + """ + When passed a memoryview onto a small number of bytes, + L{Connection.sendall} transmits all of them. + """ + server, client = self._loopback() + server.sendall(memoryview(b('x'))) + self.assertEquals(client.recv(1), b('x')) + + def test_long(self): """ L{Connection.sendall} transmits all of the bytes in the string passed to @@ -1617,6 +1960,28 @@ class MemoryBIOTests(TestCase, _LoopbackMixin): self._check_client_ca_list(set_replaces_add_ca) +class InfoConstantTests(TestCase): + """ + Tests for assorted constants exposed for use in info callbacks. + """ + def test_integers(self): + """ + All of the info constants are integers. + + This is a very weak test. It would be nice to have one that actually + verifies that as certain info events happen, the value passed to the + info callback matches up with the constant exposed by OpenSSL.SSL. + """ + for const in [ + SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, SSL_ST_INIT, + SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE, + SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, + SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, + SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, + SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE]: + + self.assertTrue(isinstance(const, int)) + if __name__ == '__main__': main() diff --git a/OpenSSL/test/util.py b/OpenSSL/test/util.py index 61246a6..643fb91 100644 --- a/OpenSSL/test/util.py +++ b/OpenSSL/test/util.py @@ -1,5 +1,5 @@ -# Copyright (C) Jean-Paul Calderone 2009, All rights reserved -# Copyright (c) 2001-2009 Twisted Matrix Laboratories. +# Copyright (C) Jean-Paul Calderone +# Copyright (C) Twisted Matrix Laboratories. # See LICENSE for details. """ @@ -15,16 +15,14 @@ import sys from OpenSSL.crypto import Error, _exception_from_error_queue - -try: - bytes = bytes -except NameError: +if sys.version_info < (3, 0): def b(s): return s bytes = str else: def b(s): - return s.encode("ascii") + return s.encode("charmap") + bytes = bytes class TestCase(TestCase): @@ -52,6 +50,22 @@ class TestCase(TestCase): self.fail("Left over errors in OpenSSL error queue: " + repr(e)) + def failUnlessIn(self, containee, container, msg=None): + """ + Fail the test if C{containee} is not found in C{container}. + + @param containee: the value that should be in C{container} + @param container: a sequence type, or in the case of a mapping type, + will follow semantics of 'if key in dict.keys()' + @param msg: if msg is None, then the failure message will be + '%r not in %r' % (first, second) + """ + if containee not in container: + raise self.failureException(msg or "%r not in %r" + % (containee, container)) + return containee + assertIn = failUnlessIn + def failUnlessIdentical(self, first, second, msg=None): """ Fail the test if C{first} is not C{second}. This is an |