summaryrefslogtreecommitdiff
path: root/OpenSSL/test
diff options
context:
space:
mode:
authorJean-Paul Calderone <exarkun@divmod.com>2011-06-12 17:20:31 -0400
committerJean-Paul Calderone <exarkun@divmod.com>2011-06-12 17:20:31 -0400
commit0da7acc9079aeba61a764cb84b243e4742dae92d (patch)
tree95228a0bcac56023c8a82de8a8deff1af5fdaee5 /OpenSSL/test
parent9828f666f9eed90290093e962181d6065bb2017c (diff)
parentc27733339ba8f678479db744e75a994639cba61e (diff)
downloadpyopenssl-0da7acc9079aeba61a764cb84b243e4742dae92d.tar.gz
Catch up to trunk
Diffstat (limited to 'OpenSSL/test')
-rw-r--r--OpenSSL/test/__init__.py3
-rw-r--r--OpenSSL/test/test_crypto.py211
-rw-r--r--OpenSSL/test/test_rand.py3
-rw-r--r--OpenSSL/test/test_ssl.py509
-rw-r--r--OpenSSL/test/util.py28
5 files changed, 664 insertions, 90 deletions
diff --git a/OpenSSL/test/__init__.py b/OpenSSL/test/__init__.py
index ab9c4cb..ccb4e9a 100644
--- a/OpenSSL/test/__init__.py
+++ b/OpenSSL/test/__init__.py
@@ -1,4 +1,5 @@
-# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
+# Copyright (C) Jean-Paul Calderone
+# See LICENSE for details.
"""
Package containing unit tests for L{OpenSSL}.
diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py
index 07c172f..71da5c3 100644
--- a/OpenSSL/test/test_crypto.py
+++ b/OpenSSL/test/test_crypto.py
@@ -1,4 +1,5 @@
-# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
+# Copyright (c) Jean-Paul Calderone
+# See LICENSE file for details.
"""
Unit tests for L{OpenSSL.crypto}.
@@ -25,6 +26,13 @@ from OpenSSL.crypto import NetscapeSPKI, NetscapeSPKIType
from OpenSSL.crypto import sign, verify
from OpenSSL.test.util import TestCase, bytes, b
+def normalize_certificate_pem(pem):
+ return dump_certificate(FILETYPE_PEM, load_certificate(FILETYPE_PEM, pem))
+
+
+def normalize_privatekey_pem(pem):
+ return dump_privatekey(FILETYPE_PEM, load_privatekey(FILETYPE_PEM, pem))
+
root_cert_pem = b("""-----BEGIN CERTIFICATE-----
MIIC7TCCAlagAwIBAgIIPQzE4MbeufQwDQYJKoZIhvcNAQEFBQAwWDELMAkGA1UE
@@ -79,7 +87,7 @@ uzujnS8YXWvM7DM1Ilozk4MzPug8jzFp5uhKCQ==
-----END CERTIFICATE-----
""")
-server_key_pem = b("""-----BEGIN RSA PRIVATE KEY-----
+server_key_pem = normalize_privatekey_pem(b("""-----BEGIN RSA PRIVATE KEY-----
MIICWwIBAAKBgQC+pvhuud1dLaQQvzipdtlcTotgr5SuE2LvSx0gz/bg1U3u1eQ+
U5eqsxaEUceaX5p5Kk+QflvW8qdjVNxQuYS5uc0gK2+OZnlIYxCf4n5GYGzVIx3Q
SBj/TAEFB2WuVinZBiCbxgL7PFM1Kpa+EwVkCAduPpSflJJPwkYGrK2MHQIDAQAB
@@ -94,7 +102,7 @@ FwwOhpahld+vqhYk+pfuWWUpQciE+Bu7ZQJASjfT4sQv4qbbKK/scePicnDdx9th
NaeNCFfH3aeTrX0LyQJAMBWjWmeKM2G2sCExheeQK0ROnaBC8itCECD4Jsve4nqf
r50+LF74iLXFwqysVCebPKMOpDWp/qQ1BbJQIPs7/A==
-----END RSA PRIVATE KEY-----
-""")
+"""))
client_cert_pem = b("""-----BEGIN CERTIFICATE-----
MIICJjCCAY+gAwIBAgIJAKxpFI5lODkjMA0GCSqGSIb3DQEBBQUAMFgxCzAJBgNV
@@ -112,7 +120,7 @@ PSTJCjJOn3xo2NTKRgV1gaoTf2EhL+RG8TQ=
-----END CERTIFICATE-----
""")
-client_key_pem = b("""-----BEGIN RSA PRIVATE KEY-----
+client_key_pem = normalize_privatekey_pem(b("""-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDAZh/SRtNm5ntMT4qb6YzEpTroMlq2rn+GrRHRiZ+xkCw/CGNh
btPir7/QxaUj26BSmQrHw1bGKEbPsWiW7bdXSespl+xKiku4G/KvnnmWdeJHqsiX
eUZtqurMELcPQAw9xPHEuhqqUJvvEoMTsnCEqGM+7DtboCRajYyHfluARQIDAQAB
@@ -127,7 +135,7 @@ si6xwT7GzMDkk/ko684AV3KPc/h6G0yGtFIrMg7J3uExpR/VdH2KgwMkZXisSMvw
JJEQjOMCVsEJlRk54WWjAkEAzoZNH6UhDdBK5F38rVt/y4SEHgbSfJHIAmPS32Kq
f6GGcfNpip0Uk7q7udTKuX7Q/buZi/C4YW7u3VKAquv9NA==
-----END RSA PRIVATE KEY-----
-""")
+"""))
cleartextCertificatePEM = b("""-----BEGIN CERTIFICATE-----
MIIC7TCCAlagAwIBAgIIPQzE4MbeufQwDQYJKoZIhvcNAQEFBQAwWDELMAkGA1UE
@@ -149,7 +157,8 @@ w/njVbKMXrvc83qmTdGl3TAM0fxQIpqgcglFLveEBgzn
-----END CERTIFICATE-----
""")
-cleartextPrivateKeyPEM = b("""-----BEGIN RSA PRIVATE KEY-----
+cleartextPrivateKeyPEM = normalize_privatekey_pem(b("""\
+-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQD5mkLpi7q6ROdu7khB3S9aanA0Zls7vvfGOmB80/yeylhGpsjA
jWen0VtSQke/NlEPGtO38tsV7CsuFnSmschvAnGrcJl76b0UOOHUgDTIoRxC6QDU
3claegwsrBA+sJEBbqx5RdXbIRGicPG/8qQ4Zm1SKOgotcbwiaor2yxZ2wIDAQAB
@@ -164,7 +173,7 @@ ttXigLnCqR486JDPTi9ZscoZkZ+w7y6e/hH8t6d5Vjt48JVyfjPIaJY+km58LcN3
6AWSeGAdtRFHVzR7oHjVAkB4hutvxiOeiIVQNBhM6RSI9aBPMI21DoX2JRoxvNW2
cbvAhow217X9V0dVerEOKxnNYspXRrh36h7k4mQA+sDq
-----END RSA PRIVATE KEY-----
-""")
+"""))
cleartextCertificateRequestPEM = b("""-----BEGIN CERTIFICATE REQUEST-----
MIIBnjCCAQcCAQAwXjELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAklMMRAwDgYDVQQH
@@ -352,6 +361,26 @@ class X509ExtTests(TestCase):
self.assertEqual(ext.get_short_name(), b('nsComment'))
+ def test_get_data(self):
+ """
+ L{X509Extension.get_data} returns a string giving the data of the
+ extension.
+ """
+ ext = X509Extension(b('basicConstraints'), True, b('CA:true'))
+ # Expect to get back the DER encoded form of CA:true.
+ self.assertEqual(ext.get_data(), b('0\x03\x01\x01\xff'))
+
+
+ def test_get_data_wrong_args(self):
+ """
+ L{X509Extension.get_data} raises L{TypeError} if passed any arguments.
+ """
+ ext = X509Extension(b('basicConstraints'), True, b('CA:true'))
+ self.assertRaises(TypeError, ext.get_data, None)
+ self.assertRaises(TypeError, ext.get_data, "foo")
+ self.assertRaises(TypeError, ext.get_data, 7)
+
+
def test_unused_subject(self):
"""
The C{subject} parameter to L{X509Extension} may be provided for an
@@ -597,6 +626,33 @@ class X509NameTests(TestCase):
name, type(name), X509NameType))
+ def test_onlyStringAttributes(self):
+ """
+ Attempting to set a non-L{str} attribute name on an L{X509NameType}
+ instance causes L{TypeError} to be raised.
+ """
+ name = self._x509name()
+ # Beyond these cases, you may also think that unicode should be
+ # rejected. Sorry, you're wrong. unicode is automatically converted to
+ # str outside of the control of X509Name, so there's no way to reject
+ # it.
+ self.assertRaises(TypeError, setattr, name, None, "hello")
+ self.assertRaises(TypeError, setattr, name, 30, "hello")
+ class evil(str):
+ pass
+ self.assertRaises(TypeError, setattr, name, evil(), "hello")
+
+
+ def test_setInvalidAttribute(self):
+ """
+ Attempting to set any attribute name on an L{X509NameType} instance for
+ which no corresponding NID is defined causes L{AttributeError} to be
+ raised.
+ """
+ name = self._x509name()
+ self.assertRaises(AttributeError, setattr, name, "no such thing", None)
+
+
def test_attributes(self):
"""
L{X509NameType} instances have attributes for each standard (?)
@@ -946,6 +1002,26 @@ class X509Tests(TestCase, _PKeyInteractionTestsMixin):
"""
pemData = cleartextCertificatePEM + cleartextPrivateKeyPEM
+ extpem = """
+-----BEGIN CERTIFICATE-----
+MIIC3jCCAkegAwIBAgIJAJHFjlcCgnQzMA0GCSqGSIb3DQEBBQUAMEcxCzAJBgNV
+BAYTAlNFMRUwEwYDVQQIEwxXZXN0ZXJib3R0b20xEjAQBgNVBAoTCUNhdGFsb2dp
+eDENMAsGA1UEAxMEUm9vdDAeFw0wODA0MjIxNDQ1MzhaFw0wOTA0MjIxNDQ1Mzha
+MFQxCzAJBgNVBAYTAlNFMQswCQYDVQQIEwJXQjEUMBIGA1UEChMLT3Blbk1ldGFk
+aXIxIjAgBgNVBAMTGW5vZGUxLm9tMi5vcGVubWV0YWRpci5vcmcwgZ8wDQYJKoZI
+hvcNAQEBBQADgY0AMIGJAoGBAPIcQMrwbk2nESF/0JKibj9i1x95XYAOwP+LarwT
+Op4EQbdlI9SY+uqYqlERhF19w7CS+S6oyqx0DRZSk4Y9dZ9j9/xgm2u/f136YS1u
+zgYFPvfUs6PqYLPSM8Bw+SjJ+7+2+TN+Tkiof9WP1cMjodQwOmdsiRbR0/J7+b1B
+hec1AgMBAAGjgcQwgcEwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNT
+TCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFIdHsBcMVVMbAO7j6NCj
+03HgLnHaMB8GA1UdIwQYMBaAFL2h9Bf9Mre4vTdOiHTGAt7BRY/8MEYGA1UdEQQ/
+MD2CDSouZXhhbXBsZS5vcmeCESoub20yLmV4bWFwbGUuY29thwSC7wgKgRNvbTJA
+b3Blbm1ldGFkaXIub3JnMA0GCSqGSIb3DQEBBQUAA4GBALd7WdXkp2KvZ7/PuWZA
+MPlIxyjS+Ly11+BNE0xGQRp9Wz+2lABtpgNqssvU156+HkKd02rGheb2tj7MX9hG
+uZzbwDAZzJPjzDQDD7d3cWsrVcfIdqVU7epHqIadnOF+X0ghJ39pAm6VVadnSXCt
+WpOdIpB8KksUTCzV591Nr1wd
+-----END CERTIFICATE-----
+ """
def signable(self):
"""
Create and return a new L{X509}.
@@ -1198,6 +1274,77 @@ class X509Tests(TestCase, _PKeyInteractionTestsMixin):
b("A8:EB:07:F8:53:25:0A:F2:56:05:C5:A5:C4:C4:C7:15"))
+ def _extcert(self, pkey, extensions):
+ cert = X509()
+ cert.set_pubkey(pkey)
+ cert.get_subject().commonName = "Unit Tests"
+ cert.get_issuer().commonName = "Unit Tests"
+ when = b(datetime.now().strftime("%Y%m%d%H%M%SZ"))
+ cert.set_notBefore(when)
+ cert.set_notAfter(when)
+
+ cert.add_extensions(extensions)
+ return load_certificate(
+ FILETYPE_PEM, dump_certificate(FILETYPE_PEM, cert))
+
+
+ def test_extension_count(self):
+ """
+ L{X509.get_extension_count} returns the number of extensions that are
+ present in the certificate.
+ """
+ pkey = load_privatekey(FILETYPE_PEM, client_key_pem)
+ ca = X509Extension(b('basicConstraints'), True, b('CA:FALSE'))
+ key = X509Extension(b('keyUsage'), True, b('digitalSignature'))
+ subjectAltName = X509Extension(
+ b('subjectAltName'), True, b('DNS:example.com'))
+
+ # Try a certificate with no extensions at all.
+ c = self._extcert(pkey, [])
+ self.assertEqual(c.get_extension_count(), 0)
+
+ # And a certificate with one
+ c = self._extcert(pkey, [ca])
+ self.assertEqual(c.get_extension_count(), 1)
+
+ # And a certificate with several
+ c = self._extcert(pkey, [ca, key, subjectAltName])
+ self.assertEqual(c.get_extension_count(), 3)
+
+
+ def test_get_extension(self):
+ """
+ L{X509.get_extension} takes an integer and returns an L{X509Extension}
+ corresponding to the extension at that index.
+ """
+ pkey = load_privatekey(FILETYPE_PEM, client_key_pem)
+ ca = X509Extension(b('basicConstraints'), True, b('CA:FALSE'))
+ key = X509Extension(b('keyUsage'), True, b('digitalSignature'))
+ subjectAltName = X509Extension(
+ b('subjectAltName'), False, b('DNS:example.com'))
+
+ cert = self._extcert(pkey, [ca, key, subjectAltName])
+
+ ext = cert.get_extension(0)
+ self.assertTrue(isinstance(ext, X509Extension))
+ self.assertTrue(ext.get_critical())
+ self.assertEqual(ext.get_short_name(), b('basicConstraints'))
+
+ ext = cert.get_extension(1)
+ self.assertTrue(isinstance(ext, X509Extension))
+ self.assertTrue(ext.get_critical())
+ self.assertEqual(ext.get_short_name(), b('keyUsage'))
+
+ ext = cert.get_extension(2)
+ self.assertTrue(isinstance(ext, X509Extension))
+ self.assertFalse(ext.get_critical())
+ self.assertEqual(ext.get_short_name(), b('subjectAltName'))
+
+ self.assertRaises(IndexError, cert.get_extension, -1)
+ self.assertRaises(IndexError, cert.get_extension, 4)
+ self.assertRaises(TypeError, cert.get_extension, "hello")
+
+
def test_invalid_digest_algorithm(self):
"""
L{X509.digest} raises L{ValueError} if called with an unrecognized hash
@@ -1326,7 +1473,53 @@ class X509Tests(TestCase, _PKeyInteractionTestsMixin):
name.
"""
cert = load_certificate(FILETYPE_PEM, self.pemData)
- self.assertEquals(cert.subject_name_hash(), 3350047874)
+ self.assertIn(
+ cert.subject_name_hash(),
+ [3350047874, # OpenSSL 0.9.8, MD5
+ 3278919224, # OpenSSL 1.0.0, SHA1
+ ])
+
+
+ def test_get_signature_algorithm(self):
+ """
+ L{X509Type.get_signature_algorithm} returns a string which means
+ the algorithm used to sign the certificate.
+ """
+ cert = load_certificate(FILETYPE_PEM, self.pemData)
+ self.assertEqual(
+ b("sha1WithRSAEncryption"), cert.get_signature_algorithm())
+
+
+ def test_get_undefined_signature_algorithm(self):
+ """
+ L{X509Type.get_signature_algorithm} raises L{ValueError} if the
+ signature algorithm is undefined or unknown.
+ """
+ # This certificate has been modified to indicate a bogus OID in the
+ # signature algorithm field so that OpenSSL does not recognize it.
+ certPEM = """\
+-----BEGIN CERTIFICATE-----
+MIIC/zCCAmigAwIBAgIBATAGBgJ8BQUAMHsxCzAJBgNVBAYTAlNHMREwDwYDVQQK
+EwhNMkNyeXB0bzEUMBIGA1UECxMLTTJDcnlwdG8gQ0ExJDAiBgNVBAMTG00yQ3J5
+cHRvIENlcnRpZmljYXRlIE1hc3RlcjEdMBsGCSqGSIb3DQEJARYObmdwc0Bwb3N0
+MS5jb20wHhcNMDAwOTEwMDk1MTMwWhcNMDIwOTEwMDk1MTMwWjBTMQswCQYDVQQG
+EwJTRzERMA8GA1UEChMITTJDcnlwdG8xEjAQBgNVBAMTCWxvY2FsaG9zdDEdMBsG
+CSqGSIb3DQEJARYObmdwc0Bwb3N0MS5jb20wXDANBgkqhkiG9w0BAQEFAANLADBI
+AkEArL57d26W9fNXvOhNlZzlPOACmvwOZ5AdNgLzJ1/MfsQQJ7hHVeHmTAjM664V
++fXvwUGJLziCeBo1ysWLRnl8CQIDAQABo4IBBDCCAQAwCQYDVR0TBAIwADAsBglg
+hkgBhvhCAQ0EHxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0O
+BBYEFM+EgpK+eyZiwFU1aOPSbczbPSpVMIGlBgNVHSMEgZ0wgZqAFPuHI2nrnDqT
+FeXFvylRT/7tKDgBoX+kfTB7MQswCQYDVQQGEwJTRzERMA8GA1UEChMITTJDcnlw
+dG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQDExtNMkNyeXB0byBDZXJ0
+aWZpY2F0ZSBNYXN0ZXIxHTAbBgkqhkiG9w0BCQEWDm5ncHNAcG9zdDEuY29tggEA
+MA0GCSqGSIb3DQEBBAUAA4GBADv8KpPo+gfJxN2ERK1Y1l17sz/ZhzoGgm5XCdbx
+jEY7xKfpQngV599k1xhl11IMqizDwu0855agrckg2MCTmOI9DZzDD77tAYb+Dk0O
+PEVk0Mk/V0aIsDE9bolfCi/i/QWZ3N8s5nTWMNyBBBmoSliWCm4jkkRZRD0ejgTN
+tgI5
+-----END CERTIFICATE-----
+"""
+ cert = load_certificate(FILETYPE_PEM, certPEM)
+ self.assertRaises(ValueError, cert.get_signature_algorithm)
@@ -1547,7 +1740,7 @@ class PKCS12Tests(TestCase):
dumped_p12 = p12.export(passphrase=passwd, iter=2, maciter=3)
reloaded_p12 = load_pkcs12(dumped_p12, passwd)
self.assertEqual(
- p12.get_friendlyname(),reloaded_p12.get_friendlyname())
+ p12.get_friendlyname(), reloaded_p12.get_friendlyname())
# We would use the openssl program to confirm the friendly
# name, but it is not possible. The pkcs12 command
# does not store the friendly name in the cert's
diff --git a/OpenSSL/test/test_rand.py b/OpenSSL/test/test_rand.py
index a785168..00fc6d1 100644
--- a/OpenSSL/test/test_rand.py
+++ b/OpenSSL/test/test_rand.py
@@ -1,4 +1,5 @@
-# Copyright (C) Frederick Dean 2009, All rights reserved
+# Copyright (c) Frederick Dean
+# See LICENSE for details.
"""
Unit tests for L{OpenSSL.rand}.
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 6c8579b..2ab67fd 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -1,32 +1,41 @@
-# Copyright (C) Jean-Paul Calderone 2008-2010, All rights reserved
+# Copyright (C) Jean-Paul Calderone
+# See LICENSE for details.
"""
Unit tests for L{OpenSSL.SSL}.
"""
+from gc import collect
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK
-from sys import platform
+from sys import platform, version_info
from socket import error, socket
from os import makedirs
from os.path import join
from unittest import main
+from weakref import ref
-from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, FILETYPE_ASN1
+from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
from OpenSSL.crypto import PKey, X509, X509Extension
from OpenSSL.crypto import dump_privatekey, load_privatekey
from OpenSSL.crypto import dump_certificate, load_certificate
+from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS
+from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON
from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
from OpenSSL.SSL import OP_NO_SSLv2, OP_NO_SSLv3, OP_SINGLE_DH_USE
-from OpenSSL.SSL import VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE
-from OpenSSL.SSL import Error, SysCallError, WantReadError, ZeroReturnError
+from OpenSSL.SSL import (
+ VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE, VERIFY_NONE)
+from OpenSSL.SSL import (
+ Error, SysCallError, WantReadError, ZeroReturnError, SSLeay_version)
from OpenSSL.SSL import Context, ContextType, Connection, ConnectionType
from OpenSSL.test.util import TestCase, bytes, b
-from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM
-from OpenSSL.test.test_crypto import client_cert_pem, client_key_pem
-from OpenSSL.test.test_crypto import server_cert_pem, server_key_pem, root_cert_pem
+from OpenSSL.test.test_crypto import (
+ cleartextCertificatePEM, cleartextPrivateKeyPEM)
+from OpenSSL.test.test_crypto import (
+ client_cert_pem, client_key_pem, server_cert_pem, server_key_pem,
+ root_cert_pem)
try:
from OpenSSL.SSL import OP_NO_QUERY_MTU
@@ -41,6 +50,13 @@ try:
except ImportError:
OP_NO_TICKET = None
+from OpenSSL.SSL import (
+ SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, SSL_ST_INIT, SSL_ST_BEFORE,
+ SSL_ST_OK, SSL_ST_RENEGOTIATE,
+ SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT,
+ SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP,
+ SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT,
+ SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE)
# openssl dhparam 128 -out dh-128.pem (note that 128 is a small number of bits
# to use)
@@ -54,6 +70,7 @@ MBYCEQCobsg29c9WZP/54oAPcwiDAgEC
def verify_cb(conn, cert, errnum, depth, ok):
return ok
+
def socket_pair():
"""
Establish and return a pair of network sockets connected to each other.
@@ -96,6 +113,60 @@ def handshake(client, server):
conns.remove(conn)
+def _create_certificate_chain():
+ """
+ Construct and return a chain of certificates.
+
+ 1. A new self-signed certificate authority certificate (cacert)
+ 2. A new intermediate certificate signed by cacert (icert)
+ 3. A new server certificate signed by icert (scert)
+ """
+ caext = X509Extension(b('basicConstraints'), False, b('CA:true'))
+
+ # Step 1
+ cakey = PKey()
+ cakey.generate_key(TYPE_RSA, 512)
+ cacert = X509()
+ cacert.get_subject().commonName = "Authority Certificate"
+ cacert.set_issuer(cacert.get_subject())
+ cacert.set_pubkey(cakey)
+ cacert.set_notBefore(b("20000101000000Z"))
+ cacert.set_notAfter(b("20200101000000Z"))
+ cacert.add_extensions([caext])
+ cacert.set_serial_number(0)
+ cacert.sign(cakey, "sha1")
+
+ # Step 2
+ ikey = PKey()
+ ikey.generate_key(TYPE_RSA, 512)
+ icert = X509()
+ icert.get_subject().commonName = "Intermediate Certificate"
+ icert.set_issuer(cacert.get_subject())
+ icert.set_pubkey(ikey)
+ icert.set_notBefore(b("20000101000000Z"))
+ icert.set_notAfter(b("20200101000000Z"))
+ icert.add_extensions([caext])
+ icert.set_serial_number(0)
+ icert.sign(cakey, "sha1")
+
+ # Step 3
+ skey = PKey()
+ skey.generate_key(TYPE_RSA, 512)
+ scert = X509()
+ scert.get_subject().commonName = "Server Certificate"
+ scert.set_issuer(icert.get_subject())
+ scert.set_pubkey(skey)
+ scert.set_notBefore(b("20000101000000Z"))
+ scert.set_notAfter(b("20200101000000Z"))
+ scert.add_extensions([
+ X509Extension(b('basicConstraints'), True, b('CA:false'))])
+ scert.set_serial_number(0)
+ scert.sign(ikey, "sha1")
+
+ return [(cakey, cacert), (ikey, icert), (skey, scert)]
+
+
+
class _LoopbackMixin:
"""
Helper mixin which defines methods for creating a connected socket pair and
@@ -141,7 +212,7 @@ class _LoopbackMixin:
# Give the side a chance to generate some more bytes, or
# succeed.
try:
- bytes = read.recv(2 ** 16)
+ data = read.recv(2 ** 16)
except WantReadError:
# It didn't succeed, so we'll hope it generated some
# output.
@@ -149,7 +220,7 @@ class _LoopbackMixin:
else:
# It did succeed, so we'll stop now and let the caller deal
# with it.
- return (read, bytes)
+ return (read, data)
while True:
# Keep copying as long as there's more stuff there.
@@ -167,6 +238,36 @@ class _LoopbackMixin:
+class VersionTests(TestCase):
+ """
+ Tests for version information exposed by
+ L{OpenSSL.SSL.SSLeay_version} and
+ L{OpenSSL.SSL.OPENSSL_VERSION_NUMBER}.
+ """
+ def test_OPENSSL_VERSION_NUMBER(self):
+ """
+ L{OPENSSL_VERSION_NUMBER} is an integer with status in the low
+ byte and the patch, fix, minor, and major versions in the
+ nibbles above that.
+ """
+ self.assertTrue(isinstance(OPENSSL_VERSION_NUMBER, int))
+
+
+ def test_SSLeay_version(self):
+ """
+ L{SSLeay_version} takes a version type indicator and returns
+ one of a number of version strings based on that indicator.
+ """
+ versions = {}
+ for t in [SSLEAY_VERSION, SSLEAY_CFLAGS, SSLEAY_BUILT_ON,
+ SSLEAY_PLATFORM, SSLEAY_DIR]:
+ version = SSLeay_version(t)
+ versions[version] = t
+ self.assertTrue(isinstance(version, bytes))
+ self.assertEqual(len(versions), 5)
+
+
+
class ContextTests(TestCase, _LoopbackMixin):
"""
Unit tests for L{OpenSSL.SSL.Context}.
@@ -176,8 +277,16 @@ class ContextTests(TestCase, _LoopbackMixin):
L{Context} can be instantiated with one of L{SSLv2_METHOD},
L{SSLv3_METHOD}, L{SSLv23_METHOD}, or L{TLSv1_METHOD}.
"""
- for meth in [SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD]:
+ for meth in [SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD]:
Context(meth)
+
+ try:
+ Context(SSLv2_METHOD)
+ except ValueError:
+ # Some versions of OpenSSL have SSLv2, some don't.
+ # Difficult to say in advance.
+ pass
+
self.assertRaises(TypeError, Context, "")
self.assertRaises(ValueError, Context, 10)
@@ -512,12 +621,14 @@ class ContextTests(TestCase, _LoopbackMixin):
"""
capath = self.mktemp()
makedirs(capath)
- # Hash value computed manually with c_rehash to avoid depending on
- # c_rehash in the test suite.
- cafile = join(capath, 'c7adac82.0')
- fObj = open(cafile, 'w')
- fObj.write(cleartextCertificatePEM.decode('ascii'))
- fObj.close()
+ # Hash values computed manually with c_rehash to avoid depending on
+ # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
+ # from OpenSSL 1.0.0.
+ for name in ['c7adac82.0', 'c3705638.0']:
+ cafile = join(capath, name)
+ fObj = open(cafile, 'w')
+ fObj.write(cleartextCertificatePEM.decode('ascii'))
+ fObj.close()
self._load_verify_locations_test(None, capath)
@@ -590,59 +701,6 @@ class ContextTests(TestCase, _LoopbackMixin):
self.assertRaises(TypeError, context.add_extra_chain_cert, object(), object())
- def _create_certificate_chain(self):
- """
- Construct and return a chain of certificates.
-
- 1. A new self-signed certificate authority certificate (cacert)
- 2. A new intermediate certificate signed by cacert (icert)
- 3. A new server certificate signed by icert (scert)
- """
- caext = X509Extension(b('basicConstraints'), False, b('CA:true'))
-
- # Step 1
- cakey = PKey()
- cakey.generate_key(TYPE_RSA, 512)
- cacert = X509()
- cacert.get_subject().commonName = "Authority Certificate"
- cacert.set_issuer(cacert.get_subject())
- cacert.set_pubkey(cakey)
- cacert.set_notBefore(b("20000101000000Z"))
- cacert.set_notAfter(b("20200101000000Z"))
- cacert.add_extensions([caext])
- cacert.set_serial_number(0)
- cacert.sign(cakey, "sha1")
-
- # Step 2
- ikey = PKey()
- ikey.generate_key(TYPE_RSA, 512)
- icert = X509()
- icert.get_subject().commonName = "Intermediate Certificate"
- icert.set_issuer(cacert.get_subject())
- icert.set_pubkey(ikey)
- icert.set_notBefore(b("20000101000000Z"))
- icert.set_notAfter(b("20200101000000Z"))
- icert.add_extensions([caext])
- icert.set_serial_number(0)
- icert.sign(cakey, "sha1")
-
- # Step 3
- skey = PKey()
- skey.generate_key(TYPE_RSA, 512)
- scert = X509()
- scert.get_subject().commonName = "Server Certificate"
- scert.set_issuer(icert.get_subject())
- scert.set_pubkey(skey)
- scert.set_notBefore(b("20000101000000Z"))
- scert.set_notAfter(b("20200101000000Z"))
- scert.add_extensions([
- X509Extension(b('basicConstraints'), True, b('CA:false'))])
- scert.set_serial_number(0)
- scert.sign(ikey, "sha1")
-
- return [(cakey, cacert), (ikey, icert), (skey, scert)]
-
-
def _handshake_test(self, serverContext, clientContext):
"""
Verify that a client and server created with the given contexts can
@@ -678,7 +736,7 @@ class ContextTests(TestCase, _LoopbackMixin):
to it with a client which trusts cacert and requires verification to
succeed.
"""
- chain = self._create_certificate_chain()
+ chain = _create_certificate_chain()
[(cakey, cacert), (ikey, icert), (skey, scert)] = chain
# Dump the CA certificate to a file because that's the only way to load
@@ -719,7 +777,7 @@ class ContextTests(TestCase, _LoopbackMixin):
to it with a client which trusts cacert and requires verification to
succeed.
"""
- chain = self._create_certificate_chain()
+ chain = _create_certificate_chain()
[(cakey, cacert), (ikey, icert), (skey, scert)] = chain
# Write out the chain file.
@@ -817,6 +875,106 @@ class ContextTests(TestCase, _LoopbackMixin):
+class ServerNameCallbackTests(TestCase, _LoopbackMixin):
+ """
+ Tests for L{Context.set_tlsext_servername_callback} and its interaction with
+ L{Connection}.
+ """
+ def test_wrong_args(self):
+ """
+ L{Context.set_tlsext_servername_callback} raises L{TypeError} if called
+ with other than one argument.
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.set_tlsext_servername_callback)
+ self.assertRaises(
+ TypeError, context.set_tlsext_servername_callback, 1, 2)
+
+ def test_old_callback_forgotten(self):
+ """
+ If L{Context.set_tlsext_servername_callback} is used to specify a new
+ callback, the one it replaces is dereferenced.
+ """
+ def callback(connection):
+ pass
+
+ def replacement(connection):
+ pass
+
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(callback)
+
+ tracker = ref(callback)
+ del callback
+
+ context.set_tlsext_servername_callback(replacement)
+ collect()
+ self.assertIdentical(None, tracker())
+
+
+ def test_no_servername(self):
+ """
+ When a client specifies no server name, the callback passed to
+ L{Context.set_tlsext_servername_callback} is invoked and the result of
+ L{Connection.get_servername} is C{None}.
+ """
+ args = []
+ def servername(conn):
+ args.append((conn, conn.get_servername()))
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(servername)
+
+ # Lose our reference to it. The Context is responsible for keeping it
+ # alive now.
+ del servername
+ collect()
+
+ # Necessary to actually accept the connection
+ context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(context, None)
+ server.set_accept_state()
+
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, None)], args)
+
+
+ def test_servername(self):
+ """
+ When a client specifies a server name in its hello message, the callback
+ passed to L{Contexts.set_tlsext_servername_callback} is invoked and the
+ result of L{Connection.get_servername} is that server name.
+ """
+ args = []
+ def servername(conn):
+ args.append((conn, conn.get_servername()))
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(servername)
+
+ # Necessary to actually accept the connection
+ context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(context, None)
+ server.set_accept_state()
+
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+ client.set_tlsext_host_name(b("foo1.example.com"))
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, b("foo1.example.com"))], args)
+
+
+
class ConnectionTests(TestCase, _LoopbackMixin):
"""
Unit tests for L{OpenSSL.SSL.Connection}.
@@ -868,6 +1026,71 @@ class ConnectionTests(TestCase, _LoopbackMixin):
self.assertRaises(TypeError, connection.get_context, None)
+ def test_set_context_wrong_args(self):
+ """
+ L{Connection.set_context} raises L{TypeError} if called with a
+ non-L{Context} instance argument or with any number of arguments other
+ than 1.
+ """
+ ctx = Context(TLSv1_METHOD)
+ connection = Connection(ctx, None)
+ self.assertRaises(TypeError, connection.set_context)
+ self.assertRaises(TypeError, connection.set_context, object())
+ self.assertRaises(TypeError, connection.set_context, "hello")
+ self.assertRaises(TypeError, connection.set_context, 1)
+ self.assertRaises(TypeError, connection.set_context, 1, 2)
+ self.assertRaises(
+ TypeError, connection.set_context, Context(TLSv1_METHOD), 2)
+ self.assertIdentical(ctx, connection.get_context())
+
+
+ def test_set_context(self):
+ """
+ L{Connection.set_context} specifies a new L{Context} instance to be used
+ for the connection.
+ """
+ original = Context(SSLv23_METHOD)
+ replacement = Context(TLSv1_METHOD)
+ connection = Connection(original, None)
+ connection.set_context(replacement)
+ self.assertIdentical(replacement, connection.get_context())
+ # Lose our references to the contexts, just in case the Connection isn't
+ # properly managing its own contributions to their reference counts.
+ del original, replacement
+ collect()
+
+
+ def test_set_tlsext_host_name_wrong_args(self):
+ """
+ If L{Connection.set_tlsext_host_name} is called with a non-byte string
+ argument or a byte string with an embedded NUL or other than one
+ argument, L{TypeError} is raised.
+ """
+ conn = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, conn.set_tlsext_host_name)
+ self.assertRaises(TypeError, conn.set_tlsext_host_name, object())
+ self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456)
+ self.assertRaises(
+ TypeError, conn.set_tlsext_host_name, b("with\0null"))
+
+ if version_info >= (3,):
+ # On Python 3.x, don't accidentally implicitly convert from text.
+ self.assertRaises(
+ TypeError,
+ conn.set_tlsext_host_name, b("example.com").decode("ascii"))
+
+
+ def test_get_servername_wrong_args(self):
+ """
+ L{Connection.get_servername} raises L{TypeError} if called with any
+ arguments.
+ """
+ connection = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, connection.get_servername, object())
+ self.assertRaises(TypeError, connection.get_servername, 1)
+ self.assertRaises(TypeError, connection.get_servername, "hello")
+
+
def test_pending(self):
"""
L{Connection.pending} returns the number of bytes available for
@@ -1047,6 +1270,68 @@ class ConnectionTests(TestCase, _LoopbackMixin):
self.assertRaises(NotImplementedError, conn.makefile)
+ def test_get_peer_cert_chain_wrong_args(self):
+ """
+ L{Connection.get_peer_cert_chain} raises L{TypeError} if called with any
+ arguments.
+ """
+ conn = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, conn.get_peer_cert_chain, 1)
+ self.assertRaises(TypeError, conn.get_peer_cert_chain, "foo")
+ self.assertRaises(TypeError, conn.get_peer_cert_chain, object())
+ self.assertRaises(TypeError, conn.get_peer_cert_chain, [])
+
+
+ def test_get_peer_cert_chain(self):
+ """
+ L{Connection.get_peer_cert_chain} returns a list of certificates which
+ the connected server returned for the certification verification.
+ """
+ chain = _create_certificate_chain()
+ [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
+
+ serverContext = Context(TLSv1_METHOD)
+ serverContext.use_privatekey(skey)
+ serverContext.use_certificate(scert)
+ serverContext.add_extra_chain_cert(icert)
+ serverContext.add_extra_chain_cert(cacert)
+ server = Connection(serverContext, None)
+ server.set_accept_state()
+
+ # Create the client
+ clientContext = Context(TLSv1_METHOD)
+ clientContext.set_verify(VERIFY_NONE, verify_cb)
+ client = Connection(clientContext, None)
+ client.set_connect_state()
+
+ self._interactInMemory(client, server)
+
+ chain = client.get_peer_cert_chain()
+ self.assertEqual(len(chain), 3)
+ self.assertEqual(
+ "Server Certificate", chain[0].get_subject().CN)
+ self.assertEqual(
+ "Intermediate Certificate", chain[1].get_subject().CN)
+ self.assertEqual(
+ "Authority Certificate", chain[2].get_subject().CN)
+
+
+ def test_get_peer_cert_chain_none(self):
+ """
+ L{Connection.get_peer_cert_chain} returns C{None} if the peer sends no
+ certificate chain.
+ """
+ ctx = Context(TLSv1_METHOD)
+ ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+ server = Connection(ctx, None)
+ server.set_accept_state()
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+ self._interactInMemory(client, server)
+ self.assertIdentical(None, server.get_peer_cert_chain())
+
+
class ConnectionGetCipherListTests(TestCase):
"""
@@ -1074,6 +1359,49 @@ class ConnectionGetCipherListTests(TestCase):
+class ConnectionSendTests(TestCase, _LoopbackMixin):
+ """
+ Tests for L{Connection.send}
+ """
+ def test_wrong_args(self):
+ """
+ When called with arguments other than a single string,
+ L{Connection.send} raises L{TypeError}.
+ """
+ connection = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, connection.send)
+ self.assertRaises(TypeError, connection.send, object())
+ self.assertRaises(TypeError, connection.send, "foo", "bar")
+
+
+ def test_short_bytes(self):
+ """
+ When passed a short byte string, L{Connection.send} transmits all of it
+ and returns the number of bytes sent.
+ """
+ server, client = self._loopback()
+ count = server.send(b('xy'))
+ self.assertEquals(count, 2)
+ self.assertEquals(client.recv(2), b('xy'))
+
+ try:
+ memoryview
+ except NameError:
+ "cannot test sending memoryview without memoryview"
+ else:
+ def test_short_memoryview(self):
+ """
+ When passed a memoryview onto a small number of bytes,
+ L{Connection.send} transmits all of them and returns the number of
+ bytes sent.
+ """
+ server, client = self._loopback()
+ count = server.send(memoryview(b('xy')))
+ self.assertEquals(count, 2)
+ self.assertEquals(client.recv(2), b('xy'))
+
+
+
class ConnectionSendallTests(TestCase, _LoopbackMixin):
"""
Tests for L{Connection.sendall}.
@@ -1099,6 +1427,21 @@ class ConnectionSendallTests(TestCase, _LoopbackMixin):
self.assertEquals(client.recv(1), b('x'))
+ try:
+ memoryview
+ except NameError:
+ "cannot test sending memoryview without memoryview"
+ else:
+ def test_short_memoryview(self):
+ """
+ When passed a memoryview onto a small number of bytes,
+ L{Connection.sendall} transmits all of them.
+ """
+ server, client = self._loopback()
+ server.sendall(memoryview(b('x')))
+ self.assertEquals(client.recv(1), b('x'))
+
+
def test_long(self):
"""
L{Connection.sendall} transmits all of the bytes in the string passed to
@@ -1617,6 +1960,28 @@ class MemoryBIOTests(TestCase, _LoopbackMixin):
self._check_client_ca_list(set_replaces_add_ca)
+class InfoConstantTests(TestCase):
+ """
+ Tests for assorted constants exposed for use in info callbacks.
+ """
+ def test_integers(self):
+ """
+ All of the info constants are integers.
+
+ This is a very weak test. It would be nice to have one that actually
+ verifies that as certain info events happen, the value passed to the
+ info callback matches up with the constant exposed by OpenSSL.SSL.
+ """
+ for const in [
+ SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, SSL_ST_INIT,
+ SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE,
+ SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT,
+ SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP,
+ SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT,
+ SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE]:
+
+ self.assertTrue(isinstance(const, int))
+
if __name__ == '__main__':
main()
diff --git a/OpenSSL/test/util.py b/OpenSSL/test/util.py
index 61246a6..643fb91 100644
--- a/OpenSSL/test/util.py
+++ b/OpenSSL/test/util.py
@@ -1,5 +1,5 @@
-# Copyright (C) Jean-Paul Calderone 2009, All rights reserved
-# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
+# Copyright (C) Jean-Paul Calderone
+# Copyright (C) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
@@ -15,16 +15,14 @@ import sys
from OpenSSL.crypto import Error, _exception_from_error_queue
-
-try:
- bytes = bytes
-except NameError:
+if sys.version_info < (3, 0):
def b(s):
return s
bytes = str
else:
def b(s):
- return s.encode("ascii")
+ return s.encode("charmap")
+ bytes = bytes
class TestCase(TestCase):
@@ -52,6 +50,22 @@ class TestCase(TestCase):
self.fail("Left over errors in OpenSSL error queue: " + repr(e))
+ def failUnlessIn(self, containee, container, msg=None):
+ """
+ Fail the test if C{containee} is not found in C{container}.
+
+ @param containee: the value that should be in C{container}
+ @param container: a sequence type, or in the case of a mapping type,
+ will follow semantics of 'if key in dict.keys()'
+ @param msg: if msg is None, then the failure message will be
+ '%r not in %r' % (first, second)
+ """
+ if containee not in container:
+ raise self.failureException(msg or "%r not in %r"
+ % (containee, container))
+ return containee
+ assertIn = failUnlessIn
+
def failUnlessIdentical(self, first, second, msg=None):
"""
Fail the test if C{first} is not C{second}. This is an