diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 310 |
1 files changed, 253 insertions, 57 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index ec42e38d08..ab7a49b576 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -87,17 +87,18 @@ ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY """ +import ipaddress import textwrap import re import sys import os from collections import namedtuple -from enum import Enum as _Enum +from enum import Enum as _Enum, IntEnum as _IntEnum import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext +from _ssl import _SSLContext, MemoryBIO from _ssl import ( SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLSyscallError, SSLEOFError, @@ -119,30 +120,23 @@ def _import_symbols(prefix): _import_symbols('OP_') _import_symbols('ALERT_DESCRIPTION_') _import_symbols('SSL_ERROR_') -_import_symbols('PROTOCOL_') _import_symbols('VERIFY_') -from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN from _ssl import _OPENSSL_API_VERSION +_IntEnum._convert( + '_SSLMethod', __name__, + lambda name: name.startswith('PROTOCOL_'), + source=_ssl) + +_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} -_PROTOCOL_NAMES = {value: name for name, value in globals().items() if name.startswith('PROTOCOL_')} try: - from _ssl import PROTOCOL_SSLv2 _SSLv2_IF_EXISTS = PROTOCOL_SSLv2 -except ImportError: +except NameError: _SSLv2_IF_EXISTS = None -else: - _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" - -try: - from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2 -except ImportError: - pass -else: - _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1" - _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2" if sys.platform == "win32": from _ssl import enum_certificates, enum_crls @@ -246,6 +240,17 @@ def _dnsname_match(dn, hostname, max_wildcards=1): return pat.match(hostname) +def _ipaddress_match(ipname, host_ip): + """Exact matching of IP addresses. + + RFC 6125 explicitly doesn't define an algorithm for this + (section 1.7.2 - "Out of Scope"). + """ + # OpenSSL may add a trailing newline to a subjectAltName's IP address + ip = ipaddress.ip_address(ipname.rstrip()) + return ip == host_ip + + def match_hostname(cert, hostname): """Verify that *cert* (in decoded format as returned by SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 @@ -258,11 +263,20 @@ def match_hostname(cert, hostname): raise ValueError("empty or no certificate, match_hostname needs a " "SSL socket or SSL context with either " "CERT_OPTIONAL or CERT_REQUIRED") + try: + host_ip = ipaddress.ip_address(hostname) + except ValueError: + # Not an IP address (common case) + host_ip = None dnsnames = [] san = cert.get('subjectAltName', ()) for key, value in san: if key == 'DNS': - if _dnsname_match(value, hostname): + if host_ip is None and _dnsname_match(value, hostname): + return + dnsnames.append(value) + elif key == 'IP Address': + if host_ip is not None and _ipaddress_match(value, host_ip): return dnsnames.append(value) if not dnsnames: @@ -361,6 +375,12 @@ class SSLContext(_SSLContext): server_hostname=server_hostname, _context=self) + def wrap_bio(self, incoming, outgoing, server_side=False, + server_hostname=None): + sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side, + server_hostname=server_hostname) + return SSLObject(sslobj) + def set_npn_protocols(self, npn_protocols): protos = bytearray() for protocol in npn_protocols: @@ -372,6 +392,17 @@ class SSLContext(_SSLContext): self._set_npn_protocols(protos) + def set_alpn_protocols(self, alpn_protocols): + protos = bytearray() + for protocol in alpn_protocols: + b = bytes(protocol, 'ascii') + if len(b) == 0 or len(b) > 255: + raise SSLError('ALPN protocols must be 1 to 255 in length') + protos.append(len(b)) + protos.extend(b) + + self._set_alpn_protocols(protos) + def _load_windows_store_certs(self, storename, purpose): certs = bytearray() for cert, encoding, trust in enum_certificates(storename): @@ -488,6 +519,141 @@ _create_default_https_context = create_default_context _create_stdlib_context = _create_unverified_context +class SSLObject: + """This class implements an interface on top of a low-level SSL object as + implemented by OpenSSL. This object captures the state of an SSL connection + but does not provide any network IO itself. IO needs to be performed + through separate "BIO" objects which are OpenSSL's IO abstraction layer. + + This class does not have a public constructor. Instances are returned by + ``SSLContext.wrap_bio``. This class is typically used by framework authors + that want to implement asynchronous IO for SSL through memory buffers. + + When compared to ``SSLSocket``, this object lacks the following features: + + * Any form of network IO incluging methods such as ``recv`` and ``send``. + * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. + """ + + def __init__(self, sslobj, owner=None): + self._sslobj = sslobj + # Note: _sslobj takes a weak reference to owner + self._sslobj.owner = owner or self + + @property + def context(self): + """The SSLContext that is currently in use.""" + return self._sslobj.context + + @context.setter + def context(self, ctx): + self._sslobj.context = ctx + + @property + def server_side(self): + """Whether this is a server-side socket.""" + return self._sslobj.server_side + + @property + def server_hostname(self): + """The currently set server hostname (for SNI), or ``None`` if no + server hostame is set.""" + return self._sslobj.server_hostname + + def read(self, len=0, buffer=None): + """Read up to 'len' bytes from the SSL object and return them. + + If 'buffer' is provided, read into this buffer and return the number of + bytes read. + """ + if buffer is not None: + v = self._sslobj.read(len, buffer) + else: + v = self._sslobj.read(len or 1024) + return v + + def write(self, data): + """Write 'data' to the SSL object and return the number of bytes + written. + + The 'data' argument must support the buffer interface. + """ + return self._sslobj.write(data) + + def getpeercert(self, binary_form=False): + """Returns a formatted version of the data in the certificate provided + by the other end of the SSL channel. + + Return None if no certificate was provided, {} if a certificate was + provided, but not validated. + """ + return self._sslobj.peer_certificate(binary_form) + + def selected_npn_protocol(self): + """Return the currently selected NPN protocol as a string, or ``None`` + if a next protocol was not negotiated or if NPN is not supported by one + of the peers.""" + if _ssl.HAS_NPN: + return self._sslobj.selected_npn_protocol() + + def selected_alpn_protocol(self): + """Return the currently selected ALPN protocol as a string, or ``None`` + if a next protocol was not negotiated or if ALPN is not supported by one + of the peers.""" + if _ssl.HAS_ALPN: + return self._sslobj.selected_alpn_protocol() + + def cipher(self): + """Return the currently selected cipher as a 3-tuple ``(name, + ssl_version, secret_bits)``.""" + return self._sslobj.cipher() + + def shared_ciphers(self): + """Return a list of ciphers shared by the client during the handshake or + None if this is not a valid server connection. + """ + return self._sslobj.shared_ciphers() + + def compression(self): + """Return the current compression algorithm in use, or ``None`` if + compression was not negotiated or not supported by one of the peers.""" + return self._sslobj.compression() + + def pending(self): + """Return the number of bytes that can be read immediately.""" + return self._sslobj.pending() + + def do_handshake(self): + """Start the SSL/TLS handshake.""" + self._sslobj.do_handshake() + if self.context.check_hostname: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) + + def unwrap(self): + """Start the SSL shutdown handshake.""" + return self._sslobj.shutdown() + + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake).""" + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError("Unsupported channel binding type") + if cb_type != "tls-unique": + raise NotImplementedError( + "{0} channel binding type not implemented" + .format(cb_type)) + return self._sslobj.tls_unique_cb() + + def version(self): + """Return a string identifying the protocol version used by the + current SSL channel. """ + return self._sslobj.version() + + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and @@ -570,8 +736,9 @@ class SSLSocket(socket): if connected: # create the SSL object try: - self._sslobj = self._context._wrap_socket(self, server_side, - server_hostname) + sslobj = self._context._wrap_socket(self, server_side, + server_hostname) + self._sslobj = SSLObject(sslobj, owner=self) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: @@ -616,11 +783,7 @@ class SSLSocket(socket): if not self._sslobj: raise ValueError("Read on closed or unwrapped SSL socket.") try: - if buffer is not None: - v = self._sslobj.read(len, buffer) - else: - v = self._sslobj.read(len or 1024) - return v + return self._sslobj.read(len, buffer) except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: if buffer is not None: @@ -647,7 +810,7 @@ class SSLSocket(socket): self._checkClosed() self._check_connected() - return self._sslobj.peer_certificate(binary_form) + return self._sslobj.getpeercert(binary_form) def selected_npn_protocol(self): self._checkClosed() @@ -656,6 +819,13 @@ class SSLSocket(socket): else: return self._sslobj.selected_npn_protocol() + def selected_alpn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_ALPN: + return None + else: + return self._sslobj.selected_alpn_protocol() + def cipher(self): self._checkClosed() if not self._sslobj: @@ -663,6 +833,12 @@ class SSLSocket(socket): else: return self._sslobj.cipher() + def shared_ciphers(self): + self._checkClosed() + if not self._sslobj: + return None + return self._sslobj.shared_ciphers() + def compression(self): self._checkClosed() if not self._sslobj: @@ -677,17 +853,7 @@ class SSLSocket(socket): raise ValueError( "non-zero flags not allowed in calls to send() on %s" % self.__class__) - try: - v = self._sslobj.write(data) - except SSLError as x: - if x.args[0] == SSL_ERROR_WANT_READ: - return 0 - elif x.args[0] == SSL_ERROR_WANT_WRITE: - return 0 - else: - raise - else: - return v + return self._sslobj.write(data) else: return socket.send(self, data, flags) @@ -723,6 +889,16 @@ class SSLSocket(socket): else: return socket.sendall(self, data, flags) + def sendfile(self, file, offset=0, count=None): + """Send a file, possibly by using os.sendfile() if this is a + clear-text socket. Return the total number of bytes sent. + """ + if self._sslobj is None: + # os.sendfile() works with plain sockets only + return super().sendfile(file, offset, count) + else: + return self._sendfile_use_send(file, offset, count) + def recv(self, buflen=1024, flags=0): self._checkClosed() if self._sslobj: @@ -787,7 +963,7 @@ class SSLSocket(socket): def unwrap(self): if self._sslobj: - s = self._sslobj.shutdown() + s = self._sslobj.unwrap() self._sslobj = None return s else: @@ -808,12 +984,6 @@ class SSLSocket(socket): finally: self.settimeout(timeout) - if self.context.check_hostname: - if not self.server_hostname: - raise ValueError("check_hostname needs server_hostname " - "argument") - match_hostname(self.getpeercert(), self.server_hostname) - def _real_connect(self, addr, connect_ex): if self.server_side: raise ValueError("can't connect in server-side mode") @@ -821,7 +991,8 @@ class SSLSocket(socket): # connected at the time of the call. We connect it, then wrap it. if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") - self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) + sslobj = self.context._wrap_socket(self, False, self.server_hostname) + self._sslobj = SSLObject(sslobj, owner=self) try: if connect_ex: rc = socket.connect_ex(self, addr) @@ -864,15 +1035,18 @@ class SSLSocket(socket): if the requested `cb_type` is not supported. Return bytes of the data or None if the data is not available (e.g. before the handshake). """ - if cb_type not in CHANNEL_BINDING_TYPES: - raise ValueError("Unsupported channel binding type") - if cb_type != "tls-unique": - raise NotImplementedError( - "{0} channel binding type not implemented" - .format(cb_type)) if self._sslobj is None: return None - return self._sslobj.tls_unique_cb() + return self._sslobj.get_channel_binding(cb_type) + + def version(self): + """ + Return a string identifying the protocol version used by the + current SSL channel, or None if there is no established channel. + """ + if self._sslobj is None: + return None + return self._sslobj.version() def wrap_socket(sock, keyfile=None, certfile=None, @@ -892,12 +1066,34 @@ def wrap_socket(sock, keyfile=None, certfile=None, # some utility functions def cert_time_to_seconds(cert_time): - """Takes a date-time string in standard ASN1_print form - ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return - a Python time value in seconds past the epoch.""" + """Return the time in seconds since the Epoch, given the timestring + representing the "notBefore" or "notAfter" date from a certificate + in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale). + + "notBefore" or "notAfter" dates must use UTC (RFC 5280). - import time - return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) + Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec + UTC should be specified as GMT (see ASN1_TIME_print()) + """ + from time import strptime + from calendar import timegm + + months = ( + "Jan","Feb","Mar","Apr","May","Jun", + "Jul","Aug","Sep","Oct","Nov","Dec" + ) + time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT + try: + month_number = months.index(cert_time[:3].title()) + 1 + except ValueError: + raise ValueError('time data %r does not match ' + 'format "%%b%s"' % (cert_time, time_format)) + else: + # found valid month + tt = strptime(cert_time[3:], time_format) + # return an integer, the previous mktime()-based implementation + # returned a float (fractional seconds are always zero here). + return timegm((tt[0], month_number) + tt[2:6]) PEM_HEADER = "-----BEGIN CERTIFICATE-----" PEM_FOOTER = "-----END CERTIFICATE-----" |