diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 179 |
1 files changed, 122 insertions, 57 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index e1913904f3..8ad4a339a9 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -52,6 +52,8 @@ PROTOCOL_SSLv2 PROTOCOL_SSLv3 PROTOCOL_SSLv23 PROTOCOL_TLS +PROTOCOL_TLS_CLIENT +PROTOCOL_TLS_SERVER PROTOCOL_TLSv1 PROTOCOL_TLSv1_1 PROTOCOL_TLSv1_2 @@ -94,17 +96,16 @@ import re import sys import os from collections import namedtuple -from enum import Enum as _Enum, IntEnum as _IntEnum +from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, MemoryBIO +from _ssl import _SSLContext, MemoryBIO, SSLSession from _ssl import ( SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLSyscallError, SSLEOFError, ) -from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes try: @@ -113,32 +114,47 @@ except ImportError: # LibreSSL does not provide RAND_egd pass -def _import_symbols(prefix): - for n in dir(_ssl): - if n.startswith(prefix): - globals()[n] = getattr(_ssl, n) - -_import_symbols('OP_') -_import_symbols('ALERT_DESCRIPTION_') -_import_symbols('SSL_ERROR_') -_import_symbols('VERIFY_') from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN - from _ssl import _OPENSSL_API_VERSION + +_IntEnum._convert( + '_SSLMethod', __name__, + lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23', + source=_ssl) + +_IntFlag._convert( + 'Options', __name__, + lambda name: name.startswith('OP_'), + source=_ssl) + +_IntEnum._convert( + 'AlertDescription', __name__, + lambda name: name.startswith('ALERT_DESCRIPTION_'), + source=_ssl) + +_IntEnum._convert( + 'SSLErrorNumber', __name__, + lambda name: name.startswith('SSL_ERROR_'), + source=_ssl) + +_IntFlag._convert( + 'VerifyFlags', __name__, + lambda name: name.startswith('VERIFY_'), + source=_ssl) + _IntEnum._convert( - '_SSLMethod', __name__, - lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23', - source=_ssl) + 'VerifyMode', __name__, + lambda name: name.startswith('CERT_'), + source=_ssl) + PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} -try: - _SSLv2_IF_EXISTS = PROTOCOL_SSLv2 -except NameError: - _SSLv2_IF_EXISTS = None +_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None) + if sys.platform == "win32": from _ssl import enum_certificates, enum_crls @@ -377,18 +393,18 @@ class SSLContext(_SSLContext): def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, suppress_ragged_eofs=True, - server_hostname=None): + server_hostname=None, session=None): return SSLSocket(sock=sock, server_side=server_side, do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs, server_hostname=server_hostname, - _context=self) + _context=self, _session=session) def wrap_bio(self, incoming, outgoing, server_side=False, - server_hostname=None): + server_hostname=None, session=None): sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side, server_hostname=server_hostname) - return SSLObject(sslobj) + return SSLObject(sslobj, session=session) def set_npn_protocols(self, npn_protocols): protos = bytearray() @@ -434,6 +450,34 @@ class SSLContext(_SSLContext): self._load_windows_store_certs(storename, purpose) self.set_default_verify_paths() + @property + def options(self): + return Options(super().options) + + @options.setter + def options(self, value): + super(SSLContext, SSLContext).options.__set__(self, value) + + @property + def verify_flags(self): + return VerifyFlags(super().verify_flags) + + @verify_flags.setter + def verify_flags(self, value): + super(SSLContext, SSLContext).verify_flags.__set__(self, value) + + @property + def verify_mode(self): + value = super().verify_mode + try: + return VerifyMode(value) + except ValueError: + return value + + @verify_mode.setter + def verify_mode(self, value): + super(SSLContext, SSLContext).verify_mode.__set__(self, value) + def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, capath=None, cadata=None): @@ -446,32 +490,16 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, if not isinstance(purpose, _ASN1Object): raise TypeError(purpose) + # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION, + # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE + # by default. context = SSLContext(PROTOCOL_TLS) - # SSLv2 considered harmful. - context.options |= OP_NO_SSLv2 - - # SSLv3 has problematic security and is only required for really old - # clients such as IE6 on Windows XP - context.options |= OP_NO_SSLv3 - - # disable compression to prevent CRIME attacks (OpenSSL 1.0+) - context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0) - if purpose == Purpose.SERVER_AUTH: # verify certs and host name in client mode context.verify_mode = CERT_REQUIRED context.check_hostname = True elif purpose == Purpose.CLIENT_AUTH: - # Prefer the server's ciphers by default so that we get stronger - # encryption - context.options |= getattr(_ssl, "OP_CIPHER_SERVER_PREFERENCE", 0) - - # Use single use keys in order to improve forward secrecy - context.options |= getattr(_ssl, "OP_SINGLE_DH_USE", 0) - context.options |= getattr(_ssl, "OP_SINGLE_ECDH_USE", 0) - - # disallow ciphers with known vulnerabilities context.set_ciphers(_RESTRICTED_SERVER_CIPHERS) if cafile or capath or cadata: @@ -497,12 +525,10 @@ def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, if not isinstance(purpose, _ASN1Object): raise TypeError(purpose) + # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION, + # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE + # by default. context = SSLContext(protocol) - # SSLv2 considered harmful. - context.options |= OP_NO_SSLv2 - # SSLv3 has problematic security and is only required for really old - # clients such as IE6 on Windows XP - context.options |= OP_NO_SSLv3 if cert_reqs is not None: context.verify_mode = cert_reqs @@ -548,10 +574,12 @@ class SSLObject: * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. """ - def __init__(self, sslobj, owner=None): + def __init__(self, sslobj, owner=None, session=None): self._sslobj = sslobj # Note: _sslobj takes a weak reference to owner self._sslobj.owner = owner or self + if session is not None: + self._sslobj.session = session @property def context(self): @@ -563,6 +591,20 @@ class SSLObject: self._sslobj.context = ctx @property + def session(self): + """The SSLSession for client socket.""" + return self._sslobj.session + + @session.setter + def session(self, session): + self._sslobj.session = session + + @property + def session_reused(self): + """Was the client session reused during handshake""" + return self._sslobj.session_reused + + @property def server_side(self): """Whether this is a server-side socket.""" return self._sslobj.server_side @@ -679,7 +721,7 @@ class SSLSocket(socket): family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, server_hostname=None, - _context=None): + _context=None, _session=None): if _context: self._context = _context @@ -711,11 +753,16 @@ class SSLSocket(socket): # mixed in. if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: raise NotImplementedError("only stream sockets are supported") - if server_side and server_hostname: - raise ValueError("server_hostname can only be specified " - "in client mode") + if server_side: + if server_hostname: + raise ValueError("server_hostname can only be specified " + "in client mode") + if _session is not None: + raise ValueError("session can only be specified in " + "client mode") if self._context.check_hostname and not server_hostname: raise ValueError("check_hostname requires server_hostname") + self._session = _session self.server_side = server_side self.server_hostname = server_hostname self.do_handshake_on_connect = do_handshake_on_connect @@ -751,7 +798,8 @@ class SSLSocket(socket): try: sslobj = self._context._wrap_socket(self, server_side, server_hostname) - self._sslobj = SSLObject(sslobj, owner=self) + self._sslobj = SSLObject(sslobj, owner=self, + session=self._session) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: @@ -772,6 +820,24 @@ class SSLSocket(socket): self._context = ctx self._sslobj.context = ctx + @property + def session(self): + """The SSLSession for client socket.""" + if self._sslobj is not None: + return self._sslobj.session + + @session.setter + def session(self, session): + self._session = session + if self._sslobj is not None: + self._sslobj.session = session + + @property + def session_reused(self): + """Was the client session reused during handshake""" + if self._sslobj is not None: + return self._sslobj.session_reused + def dup(self): raise NotImplemented("Can't dup() %s instances" % self.__class__.__name__) @@ -898,7 +964,6 @@ class SSLSocket(socket): while (count < amount): v = self.send(data[count:]) count += v - return amount else: return socket.sendall(self, data, flags) @@ -1005,7 +1070,8 @@ class SSLSocket(socket): if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") sslobj = self.context._wrap_socket(self, False, self.server_hostname) - self._sslobj = SSLObject(sslobj, owner=self) + self._sslobj = SSLObject(sslobj, owner=self, + session=self._session) try: if connect_ex: rc = socket.connect_ex(self, addr) @@ -1068,7 +1134,6 @@ def wrap_socket(sock, keyfile=None, certfile=None, do_handshake_on_connect=True, suppress_ragged_eofs=True, ciphers=None): - return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, |