summaryrefslogtreecommitdiff
path: root/src/M2Crypto/SSL/Checker.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/M2Crypto/SSL/Checker.py')
-rw-r--r--src/M2Crypto/SSL/Checker.py296
1 files changed, 296 insertions, 0 deletions
diff --git a/src/M2Crypto/SSL/Checker.py b/src/M2Crypto/SSL/Checker.py
new file mode 100644
index 0000000..46d397a
--- /dev/null
+++ b/src/M2Crypto/SSL/Checker.py
@@ -0,0 +1,296 @@
+"""
+SSL peer certificate checking routines
+
+Copyright (c) 2004-2007 Open Source Applications Foundation.
+All rights reserved.
+
+Copyright 2008 Heikki Toivonen. All rights reserved.
+"""
+
+__all__ = ['SSLVerificationError', 'NoCertificate', 'WrongCertificate',
+ 'WrongHost', 'Checker']
+
+import re
+import socket
+
+from M2Crypto import X509, m2, six # noqa
+from typing import AnyStr, Optional # noqa
+
+
+class SSLVerificationError(Exception):
+ pass
+
+
+class NoCertificate(SSLVerificationError):
+ pass
+
+
+class WrongCertificate(SSLVerificationError):
+ pass
+
+
+class WrongHost(SSLVerificationError):
+ def __init__(self, expectedHost, actualHost, fieldName='commonName'):
+ # type: (str, AnyStr, str) -> None
+ """
+ This exception will be raised if the certificate returned by the
+ peer was issued for a different host than we tried to connect to.
+ This could be due to a server misconfiguration or an active attack.
+
+ :param expectedHost: The name of the host we expected to find in the
+ certificate.
+ :param actualHost: The name of the host we actually found in the
+ certificate.
+ :param fieldName: The field name where we noticed the error. This
+ should be either 'commonName' or 'subjectAltName'.
+ """
+ if fieldName not in ('commonName', 'subjectAltName'):
+ raise ValueError(
+ 'Unknown fieldName, should be either commonName ' +
+ 'or subjectAltName')
+
+ SSLVerificationError.__init__(self)
+ self.expectedHost = expectedHost
+ self.actualHost = actualHost
+ self.fieldName = fieldName
+
+ def __str__(self):
+ # type: () -> str
+ s = 'Peer certificate %s does not match host, expected %s, got %s' \
+ % (self.fieldName, self.expectedHost, self.actualHost)
+ return six.ensure_text(s)
+
+
+class Checker(object):
+
+ numericIpMatch = re.compile('^[0-9]+(\.[0-9]+)*$')
+
+ def __init__(self, host=None, peerCertHash=None, peerCertDigest='sha1'):
+ # type: (Optional[str], Optional[bytes], str) -> None
+ self.host = host
+ if peerCertHash is not None:
+ peerCertHash = six.ensure_binary(peerCertHash)
+ self.fingerprint = peerCertHash
+ self.digest = peerCertDigest # type: str
+
+ def __call__(self, peerCert, host=None):
+ # type: (X509.X509, Optional[str]) -> bool
+ if peerCert is None:
+ raise NoCertificate('peer did not return certificate')
+
+ if host is not None:
+ self.host = host # type: str
+
+ if self.fingerprint:
+ if self.digest not in ('sha1', 'md5'):
+ raise ValueError('unsupported digest "%s"' % self.digest)
+
+ if self.digest == 'sha1':
+ expected_len = 40
+ elif self.digest == 'md5':
+ expected_len = 32
+ else:
+ raise ValueError('Unexpected digest {0}'.format(self.digest))
+
+ if len(self.fingerprint) != expected_len:
+ raise WrongCertificate(
+ ('peer certificate fingerprint length does not match\n' +
+ 'fingerprint: {0}\nexpected = {1}\n' +
+ 'observed = {2}').format(self.fingerprint,
+ expected_len,
+ len(self.fingerprint)))
+
+ expected_fingerprint = six.ensure_text(self.fingerprint)
+ observed_fingerprint = peerCert.get_fingerprint(md=self.digest)
+ if observed_fingerprint != expected_fingerprint:
+ raise WrongCertificate(
+ ('peer certificate fingerprint does not match\n' +
+ 'expected = {0},\n' +
+ 'observed = {1}').format(expected_fingerprint,
+ observed_fingerprint))
+
+ if self.host:
+ hostValidationPassed = False
+ self.useSubjectAltNameOnly = False
+
+ # subjectAltName=DNS:somehost[, ...]*
+ try:
+ subjectAltName = peerCert.get_ext('subjectAltName').get_value()
+ if self._splitSubjectAltName(self.host, subjectAltName):
+ hostValidationPassed = True
+ elif self.useSubjectAltNameOnly:
+ raise WrongHost(expectedHost=self.host,
+ actualHost=subjectAltName,
+ fieldName='subjectAltName')
+ except LookupError:
+ pass
+
+ # commonName=somehost[, ...]*
+ if not hostValidationPassed:
+ hasCommonName = False
+ commonNames = ''
+ for entry in peerCert.get_subject().get_entries_by_nid(
+ m2.NID_commonName):
+ hasCommonName = True
+ commonName = entry.get_data().as_text()
+ if not commonNames:
+ commonNames = commonName
+ else:
+ commonNames += ',' + commonName
+ if self._match(self.host, commonName):
+ hostValidationPassed = True
+ break
+
+ if not hasCommonName:
+ raise WrongCertificate('no commonName in peer certificate')
+
+ if not hostValidationPassed:
+ raise WrongHost(expectedHost=self.host,
+ actualHost=commonNames,
+ fieldName='commonName')
+
+ return True
+
+ def _splitSubjectAltName(self, host, subjectAltName):
+ # type: (AnyStr, AnyStr) -> bool
+ """
+ >>> check = Checker()
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:my.example.com')
+ True
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:*.example.com')
+ True
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:m*.example.com')
+ True
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:m*ample.com')
+ False
+ >>> check.useSubjectAltNameOnly
+ True
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:m*ample.com, othername:<unsupported>')
+ False
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:m*ample.com, DNS:my.example.org')
+ False
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:m*ample.com, DNS:my.example.com')
+ True
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='DNS:my.example.com, DNS:my.example.org')
+ True
+ >>> check.useSubjectAltNameOnly
+ True
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='')
+ False
+ >>> check._splitSubjectAltName(host='my.example.com',
+ ... subjectAltName='othername:<unsupported>')
+ False
+ >>> check.useSubjectAltNameOnly
+ False
+ """
+ self.useSubjectAltNameOnly = False
+ for certHost in subjectAltName.split(','):
+ certHost = certHost.lower().strip()
+ if certHost[:4] == 'dns:':
+ self.useSubjectAltNameOnly = True
+ if self._match(host, certHost[4:]):
+ return True
+ elif certHost[:11] == 'ip address:':
+ self.useSubjectAltNameOnly = True
+ if self._matchIPAddress(host, certHost[11:]):
+ return True
+ return False
+
+ def _match(self, host, certHost):
+ # type: (str, str) -> bool
+ """
+ >>> check = Checker()
+ >>> check._match(host='my.example.com', certHost='my.example.com')
+ True
+ >>> check._match(host='my.example.com', certHost='*.example.com')
+ True
+ >>> check._match(host='my.example.com', certHost='m*.example.com')
+ True
+ >>> check._match(host='my.example.com', certHost='m*.EXAMPLE.com')
+ True
+ >>> check._match(host='my.example.com', certHost='m*ample.com')
+ False
+ >>> check._match(host='my.example.com', certHost='*.*.com')
+ False
+ >>> check._match(host='1.2.3.4', certHost='1.2.3.4')
+ True
+ >>> check._match(host='1.2.3.4', certHost='*.2.3.4')
+ False
+ >>> check._match(host='1234', certHost='1234')
+ True
+ """
+ # XXX See RFC 2818 and 3280 for matching rules, this is may not
+ # XXX yet be complete.
+
+ host = host.lower()
+ certHost = certHost.lower()
+
+ if host == certHost:
+ return True
+
+ if certHost.count('*') > 1:
+ # Not sure about this, but being conservative
+ return False
+
+ if self.numericIpMatch.match(host) or \
+ self.numericIpMatch.match(certHost.replace('*', '')):
+ # Not sure if * allowed in numeric IP, but think not.
+ return False
+
+ if certHost.find('\\') > -1:
+ # Not sure about this, maybe some encoding might have these.
+ # But being conservative for now, because regex below relies
+ # on this.
+ return False
+
+ # Massage certHost so that it can be used in regex
+ certHost = certHost.replace('.', '\.')
+ certHost = certHost.replace('*', '[^\.]*')
+ if re.compile('^%s$' % certHost).match(host):
+ return True
+
+ return False
+
+ def _matchIPAddress(self, host, certHost):
+ # type: (AnyStr, AnyStr) -> bool
+ """
+ >>> check = Checker()
+ >>> check._matchIPAddress(host='my.example.com',
+ ... certHost='my.example.com')
+ False
+ >>> check._matchIPAddress(host='1.2.3.4', certHost='1.2.3.4')
+ True
+ >>> check._matchIPAddress(host='1.2.3.4', certHost='*.2.3.4')
+ False
+ >>> check._matchIPAddress(host='1.2.3.4', certHost='1.2.3.40')
+ False
+ >>> check._matchIPAddress(host='::1', certHost='::1')
+ True
+ >>> check._matchIPAddress(host='::1', certHost='0:0:0:0:0:0:0:1')
+ True
+ >>> check._matchIPAddress(host='::1', certHost='::2')
+ False
+ """
+ try:
+ canonical = socket.getaddrinfo(host, 0, 0, socket.SOCK_STREAM, 0,
+ socket.AI_NUMERICHOST)
+ certCanonical = socket.getaddrinfo(certHost, 0, 0,
+ socket.SOCK_STREAM, 0,
+ socket.AI_NUMERICHOST)
+ except:
+ return False
+ return canonical == certCanonical
+
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()