diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 177 |
1 files changed, 142 insertions, 35 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index 90c21ceff7..cd8d6b4c9e 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -60,10 +60,25 @@ import re import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, SSLError +from _ssl import _SSLContext +from _ssl import ( + SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, + SSLSyscallError, SSLEOFError, + ) from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED -from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1 -from _ssl import RAND_status, RAND_egd, RAND_add +from _ssl import ( + OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1, + OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE + ) +try: + from _ssl import OP_NO_COMPRESSION +except ImportError: + pass +try: + from _ssl import OP_SINGLE_ECDH_USE +except ImportError: + pass +from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes from _ssl import ( SSL_ERROR_ZERO_RETURN, SSL_ERROR_WANT_READ, @@ -75,8 +90,9 @@ from _ssl import ( SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, ) -from _ssl import HAS_SNI -from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN +from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23, + PROTOCOL_TLSv1) from _ssl import _OPENSSL_API_VERSION _PROTOCOL_NAMES = { @@ -94,11 +110,17 @@ else: from socket import getnameinfo as _getnameinfo from socket import error as socket_error -from socket import socket, AF_INET, SOCK_STREAM +from socket import socket, AF_INET, SOCK_STREAM, create_connection +from socket import SOL_SOCKET, SO_TYPE import base64 # for DER-to-PEM translation import traceback import errno +if _ssl.HAS_TLS_UNIQUE: + CHANNEL_BINDING_TYPES = ['tls-unique'] +else: + CHANNEL_BINDING_TYPES = [] + # Disable weak or insecure ciphers by default # (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL') _DEFAULT_CIPHERS = 'DEFAULT:!aNULL:!eNULL:!LOW:!EXPORT:!SSLv2' @@ -108,31 +130,59 @@ class CertificateError(ValueError): pass -def _dnsname_to_pat(dn, max_wildcards=1): +def _dnsname_match(dn, hostname, max_wildcards=1): + """Matching according to RFC 6125, section 6.4.3 + + http://tools.ietf.org/html/rfc6125#section-6.4.3 + """ pats = [] - for frag in dn.split(r'.'): - if frag.count('*') > max_wildcards: - # Issue #17980: avoid denials of service by refusing more - # than one wildcard per fragment. A survery of established - # policy among SSL implementations showed it to be a - # reasonable choice. - raise CertificateError( - "too many wildcards in certificate DNS name: " + repr(dn)) - if frag == '*': - # When '*' is a fragment by itself, it matches a non-empty dotless - # fragment. - pats.append('[^.]+') - else: - # Otherwise, '*' matches any dotless fragment. - frag = re.escape(frag) - pats.append(frag.replace(r'\*', '[^.]*')) - return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + if not dn: + return False + + leftmost, *remainder = dn.split(r'.') + + wildcards = leftmost.count('*') + if wildcards > max_wildcards: + # Issue #17980: avoid denials of service by refusing more + # than one wildcard per fragment. A survery of established + # policy among SSL implementations showed it to be a + # reasonable choice. + raise CertificateError( + "too many wildcards in certificate DNS name: " + repr(dn)) + + # speed up common case w/o wildcards + if not wildcards: + return dn.lower() == hostname.lower() + + # RFC 6125, section 6.4.3, subitem 1. + # The client SHOULD NOT attempt to match a presented identifier in which + # the wildcard character comprises a label other than the left-most label. + if leftmost == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + elif leftmost.startswith('xn--') or hostname.startswith('xn--'): + # RFC 6125, section 6.4.3, subitem 3. + # The client SHOULD NOT attempt to match a presented identifier + # where the wildcard character is embedded within an A-label or + # U-label of an internationalized domain name. + pats.append(re.escape(leftmost)) + else: + # Otherwise, '*' matches any dotless string, e.g. www* + pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) + + # add the remaining fragments, ignore any wildcards + for frag in remainder: + pats.append(re.escape(frag)) + + pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + return pat.match(hostname) def match_hostname(cert, hostname): """Verify that *cert* (in decoded format as returned by - SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules - are mostly followed, but IP addresses are not accepted for *hostname*. + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 + rules are followed, but IP addresses are not accepted for *hostname*. CertificateError is raised on failure. On success, the function returns nothing. @@ -143,7 +193,7 @@ def match_hostname(cert, hostname): san = cert.get('subjectAltName', ()) for key, value in san: if key == 'DNS': - if _dnsname_to_pat(value).match(hostname): + if _dnsname_match(value, hostname): return dnsnames.append(value) if not dnsnames: @@ -154,7 +204,7 @@ def match_hostname(cert, hostname): # XXX according to RFC 2818, the most specific Common Name # must be used. if key == 'commonName': - if _dnsname_to_pat(value).match(hostname): + if _dnsname_match(value, hostname): return dnsnames.append(value) if len(dnsnames) > 1: @@ -195,6 +245,17 @@ class SSLContext(_SSLContext): server_hostname=server_hostname, _context=self) + def set_npn_protocols(self, npn_protocols): + protos = bytearray() + for protocol in npn_protocols: + b = bytes(protocol, 'ascii') + if len(b) == 0 or len(b) > 255: + raise SSLError('NPN protocols must be 1 to 255 in length') + protos.append(len(b)) + protos.extend(b) + + self._set_npn_protocols(protos) + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps @@ -206,7 +267,7 @@ class SSLSocket(socket): ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, - suppress_ragged_eofs=True, ciphers=None, + suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, server_hostname=None, _context=None): @@ -226,6 +287,8 @@ class SSLSocket(socket): self.context.load_verify_locations(ca_certs) if certfile: self.context.load_cert_chain(certfile, keyfile) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.keyfile = keyfile @@ -234,6 +297,10 @@ class SSLSocket(socket): self.ssl_version = ssl_version self.ca_certs = ca_certs self.ciphers = ciphers + # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get + # mixed in. + if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: + raise NotImplementedError("only stream sockets are supported") if server_side and server_hostname: raise ValueError("server_hostname can only be specified " "in client mode") @@ -326,6 +393,13 @@ class SSLSocket(socket): self._checkClosed() return self._sslobj.peer_certificate(binary_form) + def selected_npn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_NPN: + return None + else: + return self._sslobj.selected_npn_protocol() + def cipher(self): self._checkClosed() if not self._sslobj: @@ -333,6 +407,13 @@ class SSLSocket(socket): else: return self._sslobj.cipher() + def compression(self): + self._checkClosed() + if not self._sslobj: + return None + else: + return self._sslobj.compression() + def send(self, data, flags=0): self._checkClosed() if self._sslobj: @@ -365,6 +446,12 @@ class SSLSocket(socket): else: return socket.sendto(self, data, flags_or_addr, addr) + def sendmsg(self, *args, **kwargs): + # Ensure programs don't send data unencrypted if they try to + # use this method. + raise NotImplementedError("sendmsg not allowed on instances of %s" % + self.__class__) + def sendall(self, data, flags=0): self._checkClosed() if self._sslobj: @@ -423,6 +510,14 @@ class SSLSocket(socket): else: return socket.recvfrom_into(self, buffer, nbytes, flags) + def recvmsg(self, *args, **kwargs): + raise NotImplementedError("recvmsg not allowed on instances of %s" % + self.__class__) + + def recvmsg_into(self, *args, **kwargs): + raise NotImplementedError("recvmsg_into not allowed on instances of " + "%s" % self.__class__) + def pending(self): self._checkClosed() if self._sslobj: @@ -504,16 +599,28 @@ class SSLSocket(socket): server_side=True) return newsock, addr - def __del__(self): - # sys.stderr.write("__del__ on %s\n" % repr(self)) - self._real_close() + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake). + """ + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError("Unsupported channel binding type") + if cb_type != "tls-unique": + raise NotImplementedError( + "{0} channel binding type not implemented" + .format(cb_type)) + if self._sslobj is None: + return None + return self._sslobj.tls_unique_cb() def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): + suppress_ragged_eofs=True, + ciphers=None): return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, @@ -568,9 +675,9 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): cert_reqs = CERT_REQUIRED else: cert_reqs = CERT_NONE - s = wrap_socket(socket(), ssl_version=ssl_version, + s = create_connection(addr) + s = wrap_socket(s, ssl_version=ssl_version, cert_reqs=cert_reqs, ca_certs=ca_certs) - s.connect(addr) dercert = s.getpeercert(True) s.close() return DER_cert_to_PEM_cert(dercert) |