diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 261 |
1 files changed, 200 insertions, 61 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index 3d4997caf0..9264699886 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -92,12 +92,12 @@ import re import sys import os from collections import namedtuple -from enum import Enum as _Enum +from enum import Enum as _Enum, IntEnum as _IntEnum import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext +from _ssl import _SSLContext, MemoryBIO from _ssl import ( SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLSyscallError, SSLEOFError, @@ -106,7 +106,12 @@ from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import (VERIFY_DEFAULT, VERIFY_CRL_CHECK_LEAF, VERIFY_CRL_CHECK_CHAIN, VERIFY_X509_STRICT) from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj -from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes +from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes +try: + from _ssl import RAND_egd +except ImportError: + # LibreSSL does not provide RAND_egd + pass def _import_symbols(prefix): for n in dir(_ssl): @@ -119,30 +124,19 @@ _import_symbols('SSL_ERROR_') from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN -from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 from _ssl import _OPENSSL_API_VERSION +_SSLMethod = _IntEnum('_SSLMethod', + {name: value for name, value in vars(_ssl).items() + if name.startswith('PROTOCOL_')}) +globals().update(_SSLMethod.__members__) + +_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} -_PROTOCOL_NAMES = { - PROTOCOL_TLSv1: "TLSv1", - PROTOCOL_SSLv23: "SSLv23", - PROTOCOL_SSLv3: "SSLv3", -} try: - from _ssl import PROTOCOL_SSLv2 _SSLv2_IF_EXISTS = PROTOCOL_SSLv2 -except ImportError: +except NameError: _SSLv2_IF_EXISTS = None -else: - _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" - -try: - from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2 -except ImportError: - pass -else: - _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1" - _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2" if sys.platform == "win32": from _ssl import enum_certificates, enum_crls @@ -363,6 +357,12 @@ class SSLContext(_SSLContext): server_hostname=server_hostname, _context=self) + def wrap_bio(self, incoming, outgoing, server_side=False, + server_hostname=None): + sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side, + server_hostname=server_hostname) + return SSLObject(sslobj) + def set_npn_protocols(self, npn_protocols): protos = bytearray() for protocol in npn_protocols: @@ -490,6 +490,128 @@ _create_default_https_context = create_default_context _create_stdlib_context = _create_unverified_context +class SSLObject: + """This class implements an interface on top of a low-level SSL object as + implemented by OpenSSL. This object captures the state of an SSL connection + but does not provide any network IO itself. IO needs to be performed + through separate "BIO" objects which are OpenSSL's IO abstraction layer. + + This class does not have a public constructor. Instances are returned by + ``SSLContext.wrap_bio``. This class is typically used by framework authors + that want to implement asynchronous IO for SSL through memory buffers. + + When compared to ``SSLSocket``, this object lacks the following features: + + * Any form of network IO incluging methods such as ``recv`` and ``send``. + * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. + """ + + def __init__(self, sslobj, owner=None): + self._sslobj = sslobj + # Note: _sslobj takes a weak reference to owner + self._sslobj.owner = owner or self + + @property + def context(self): + """The SSLContext that is currently in use.""" + return self._sslobj.context + + @context.setter + def context(self, ctx): + self._sslobj.context = ctx + + @property + def server_side(self): + """Whether this is a server-side socket.""" + return self._sslobj.server_side + + @property + def server_hostname(self): + """The currently set server hostname (for SNI), or ``None`` if no + server hostame is set.""" + return self._sslobj.server_hostname + + def read(self, len=0, buffer=None): + """Read up to 'len' bytes from the SSL object and return them. + + If 'buffer' is provided, read into this buffer and return the number of + bytes read. + """ + if buffer is not None: + v = self._sslobj.read(len, buffer) + else: + v = self._sslobj.read(len or 1024) + return v + + def write(self, data): + """Write 'data' to the SSL object and return the number of bytes + written. + + The 'data' argument must support the buffer interface. + """ + return self._sslobj.write(data) + + def getpeercert(self, binary_form=False): + """Returns a formatted version of the data in the certificate provided + by the other end of the SSL channel. + + Return None if no certificate was provided, {} if a certificate was + provided, but not validated. + """ + return self._sslobj.peer_certificate(binary_form) + + def selected_npn_protocol(self): + """Return the currently selected NPN protocol as a string, or ``None`` + if a next protocol was not negotiated or if NPN is not supported by one + of the peers.""" + if _ssl.HAS_NPN: + return self._sslobj.selected_npn_protocol() + + def cipher(self): + """Return the currently selected cipher as a 3-tuple ``(name, + ssl_version, secret_bits)``.""" + return self._sslobj.cipher() + + def compression(self): + """Return the current compression algorithm in use, or ``None`` if + compression was not negotiated or not supported by one of the peers.""" + return self._sslobj.compression() + + def pending(self): + """Return the number of bytes that can be read immediately.""" + return self._sslobj.pending() + + def do_handshake(self): + """Start the SSL/TLS handshake.""" + self._sslobj.do_handshake() + if self.context.check_hostname: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) + + def unwrap(self): + """Start the SSL shutdown handshake.""" + return self._sslobj.shutdown() + + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake).""" + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError("Unsupported channel binding type") + if cb_type != "tls-unique": + raise NotImplementedError( + "{0} channel binding type not implemented" + .format(cb_type)) + return self._sslobj.tls_unique_cb() + + def version(self): + """Return a string identifying the protocol version used by the + current SSL channel. """ + return self._sslobj.version() + + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and @@ -572,8 +694,9 @@ class SSLSocket(socket): if connected: # create the SSL object try: - self._sslobj = self._context._wrap_socket(self, server_side, - server_hostname) + sslobj = self._context._wrap_socket(self, server_side, + server_hostname) + self._sslobj = SSLObject(sslobj, owner=self) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: @@ -618,11 +741,7 @@ class SSLSocket(socket): if not self._sslobj: raise ValueError("Read on closed or unwrapped SSL socket.") try: - if buffer is not None: - v = self._sslobj.read(len, buffer) - else: - v = self._sslobj.read(len or 1024) - return v + return self._sslobj.read(len, buffer) except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: if buffer is not None: @@ -649,7 +768,7 @@ class SSLSocket(socket): self._checkClosed() self._check_connected() - return self._sslobj.peer_certificate(binary_form) + return self._sslobj.getpeercert(binary_form) def selected_npn_protocol(self): self._checkClosed() @@ -679,17 +798,7 @@ class SSLSocket(socket): raise ValueError( "non-zero flags not allowed in calls to send() on %s" % self.__class__) - try: - v = self._sslobj.write(data) - except SSLError as x: - if x.args[0] == SSL_ERROR_WANT_READ: - return 0 - elif x.args[0] == SSL_ERROR_WANT_WRITE: - return 0 - else: - raise - else: - return v + return self._sslobj.write(data) else: return socket.send(self, data, flags) @@ -725,6 +834,16 @@ class SSLSocket(socket): else: return socket.sendall(self, data, flags) + def sendfile(self, file, offset=0, count=None): + """Send a file, possibly by using os.sendfile() if this is a + clear-text socket. Return the total number of bytes sent. + """ + if self._sslobj is None: + # os.sendfile() works with plain sockets only + return super().sendfile(file, offset, count) + else: + return self._sendfile_use_send(file, offset, count) + def recv(self, buflen=1024, flags=0): self._checkClosed() if self._sslobj: @@ -789,7 +908,7 @@ class SSLSocket(socket): def unwrap(self): if self._sslobj: - s = self._sslobj.shutdown() + s = self._sslobj.unwrap() self._sslobj = None return s else: @@ -810,12 +929,6 @@ class SSLSocket(socket): finally: self.settimeout(timeout) - if self.context.check_hostname: - if not self.server_hostname: - raise ValueError("check_hostname needs server_hostname " - "argument") - match_hostname(self.getpeercert(), self.server_hostname) - def _real_connect(self, addr, connect_ex): if self.server_side: raise ValueError("can't connect in server-side mode") @@ -823,7 +936,8 @@ class SSLSocket(socket): # connected at the time of the call. We connect it, then wrap it. if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") - self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) + sslobj = self.context._wrap_socket(self, False, self.server_hostname) + self._sslobj = SSLObject(sslobj, owner=self) try: if connect_ex: rc = socket.connect_ex(self, addr) @@ -866,15 +980,18 @@ class SSLSocket(socket): if the requested `cb_type` is not supported. Return bytes of the data or None if the data is not available (e.g. before the handshake). """ - if cb_type not in CHANNEL_BINDING_TYPES: - raise ValueError("Unsupported channel binding type") - if cb_type != "tls-unique": - raise NotImplementedError( - "{0} channel binding type not implemented" - .format(cb_type)) if self._sslobj is None: return None - return self._sslobj.tls_unique_cb() + return self._sslobj.get_channel_binding(cb_type) + + def version(self): + """ + Return a string identifying the protocol version used by the + current SSL channel, or None if there is no established channel. + """ + if self._sslobj is None: + return None + return self._sslobj.version() def wrap_socket(sock, keyfile=None, certfile=None, @@ -894,12 +1011,34 @@ def wrap_socket(sock, keyfile=None, certfile=None, # some utility functions def cert_time_to_seconds(cert_time): - """Takes a date-time string in standard ASN1_print form - ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return - a Python time value in seconds past the epoch.""" + """Return the time in seconds since the Epoch, given the timestring + representing the "notBefore" or "notAfter" date from a certificate + in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale). + + "notBefore" or "notAfter" dates must use UTC (RFC 5280). + + Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec + UTC should be specified as GMT (see ASN1_TIME_print()) + """ + from time import strptime + from calendar import timegm - import time - return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) + months = ( + "Jan","Feb","Mar","Apr","May","Jun", + "Jul","Aug","Sep","Oct","Nov","Dec" + ) + time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT + try: + month_number = months.index(cert_time[:3].title()) + 1 + except ValueError: + raise ValueError('time data %r does not match ' + 'format "%%b%s"' % (cert_time, time_format)) + else: + # found valid month + tt = strptime(cert_time[3:], time_format) + # return an integer, the previous mktime()-based implementation + # returned a float (fractional seconds are always zero here). + return timegm((tt[0], month_number) + tt[2:6]) PEM_HEADER = "-----BEGIN CERTIFICATE-----" PEM_FOOTER = "-----END CERTIFICATE-----" @@ -926,7 +1065,7 @@ def PEM_cert_to_DER_cert(pem_cert_string): d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] return base64.decodebytes(d.encode('ASCII', 'strict')) -def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): +def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): """Retrieve the certificate from the server at the specified address, and return it as a PEM-encoded string. If 'ca_certs' is specified, validate the server cert against it. |