summaryrefslogtreecommitdiff
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py179
1 files changed, 122 insertions, 57 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index e1913904f3..8ad4a339a9 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -52,6 +52,8 @@ PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLS
+PROTOCOL_TLS_CLIENT
+PROTOCOL_TLS_SERVER
PROTOCOL_TLSv1
PROTOCOL_TLSv1_1
PROTOCOL_TLSv1_2
@@ -94,17 +96,16 @@ import re
import sys
import os
from collections import namedtuple
-from enum import Enum as _Enum, IntEnum as _IntEnum
+from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
-from _ssl import _SSLContext, MemoryBIO
+from _ssl import _SSLContext, MemoryBIO, SSLSession
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError,
)
-from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes
try:
@@ -113,32 +114,47 @@ except ImportError:
# LibreSSL does not provide RAND_egd
pass
-def _import_symbols(prefix):
- for n in dir(_ssl):
- if n.startswith(prefix):
- globals()[n] = getattr(_ssl, n)
-
-_import_symbols('OP_')
-_import_symbols('ALERT_DESCRIPTION_')
-_import_symbols('SSL_ERROR_')
-_import_symbols('VERIFY_')
from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
-
from _ssl import _OPENSSL_API_VERSION
+
+_IntEnum._convert(
+ '_SSLMethod', __name__,
+ lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
+ source=_ssl)
+
+_IntFlag._convert(
+ 'Options', __name__,
+ lambda name: name.startswith('OP_'),
+ source=_ssl)
+
+_IntEnum._convert(
+ 'AlertDescription', __name__,
+ lambda name: name.startswith('ALERT_DESCRIPTION_'),
+ source=_ssl)
+
+_IntEnum._convert(
+ 'SSLErrorNumber', __name__,
+ lambda name: name.startswith('SSL_ERROR_'),
+ source=_ssl)
+
+_IntFlag._convert(
+ 'VerifyFlags', __name__,
+ lambda name: name.startswith('VERIFY_'),
+ source=_ssl)
+
_IntEnum._convert(
- '_SSLMethod', __name__,
- lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
- source=_ssl)
+ 'VerifyMode', __name__,
+ lambda name: name.startswith('CERT_'),
+ source=_ssl)
+
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
-try:
- _SSLv2_IF_EXISTS = PROTOCOL_SSLv2
-except NameError:
- _SSLv2_IF_EXISTS = None
+_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
+
if sys.platform == "win32":
from _ssl import enum_certificates, enum_crls
@@ -377,18 +393,18 @@ class SSLContext(_SSLContext):
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
- server_hostname=None):
+ server_hostname=None, session=None):
return SSLSocket(sock=sock, server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
server_hostname=server_hostname,
- _context=self)
+ _context=self, _session=session)
def wrap_bio(self, incoming, outgoing, server_side=False,
- server_hostname=None):
+ server_hostname=None, session=None):
sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
server_hostname=server_hostname)
- return SSLObject(sslobj)
+ return SSLObject(sslobj, session=session)
def set_npn_protocols(self, npn_protocols):
protos = bytearray()
@@ -434,6 +450,34 @@ class SSLContext(_SSLContext):
self._load_windows_store_certs(storename, purpose)
self.set_default_verify_paths()
+ @property
+ def options(self):
+ return Options(super().options)
+
+ @options.setter
+ def options(self, value):
+ super(SSLContext, SSLContext).options.__set__(self, value)
+
+ @property
+ def verify_flags(self):
+ return VerifyFlags(super().verify_flags)
+
+ @verify_flags.setter
+ def verify_flags(self, value):
+ super(SSLContext, SSLContext).verify_flags.__set__(self, value)
+
+ @property
+ def verify_mode(self):
+ value = super().verify_mode
+ try:
+ return VerifyMode(value)
+ except ValueError:
+ return value
+
+ @verify_mode.setter
+ def verify_mode(self, value):
+ super(SSLContext, SSLContext).verify_mode.__set__(self, value)
+
def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
capath=None, cadata=None):
@@ -446,32 +490,16 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
if not isinstance(purpose, _ASN1Object):
raise TypeError(purpose)
+ # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
+ # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
+ # by default.
context = SSLContext(PROTOCOL_TLS)
- # SSLv2 considered harmful.
- context.options |= OP_NO_SSLv2
-
- # SSLv3 has problematic security and is only required for really old
- # clients such as IE6 on Windows XP
- context.options |= OP_NO_SSLv3
-
- # disable compression to prevent CRIME attacks (OpenSSL 1.0+)
- context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0)
-
if purpose == Purpose.SERVER_AUTH:
# verify certs and host name in client mode
context.verify_mode = CERT_REQUIRED
context.check_hostname = True
elif purpose == Purpose.CLIENT_AUTH:
- # Prefer the server's ciphers by default so that we get stronger
- # encryption
- context.options |= getattr(_ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
-
- # Use single use keys in order to improve forward secrecy
- context.options |= getattr(_ssl, "OP_SINGLE_DH_USE", 0)
- context.options |= getattr(_ssl, "OP_SINGLE_ECDH_USE", 0)
-
- # disallow ciphers with known vulnerabilities
context.set_ciphers(_RESTRICTED_SERVER_CIPHERS)
if cafile or capath or cadata:
@@ -497,12 +525,10 @@ def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None,
if not isinstance(purpose, _ASN1Object):
raise TypeError(purpose)
+ # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
+ # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
+ # by default.
context = SSLContext(protocol)
- # SSLv2 considered harmful.
- context.options |= OP_NO_SSLv2
- # SSLv3 has problematic security and is only required for really old
- # clients such as IE6 on Windows XP
- context.options |= OP_NO_SSLv3
if cert_reqs is not None:
context.verify_mode = cert_reqs
@@ -548,10 +574,12 @@ class SSLObject:
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""
- def __init__(self, sslobj, owner=None):
+ def __init__(self, sslobj, owner=None, session=None):
self._sslobj = sslobj
# Note: _sslobj takes a weak reference to owner
self._sslobj.owner = owner or self
+ if session is not None:
+ self._sslobj.session = session
@property
def context(self):
@@ -563,6 +591,20 @@ class SSLObject:
self._sslobj.context = ctx
@property
+ def session(self):
+ """The SSLSession for client socket."""
+ return self._sslobj.session
+
+ @session.setter
+ def session(self, session):
+ self._sslobj.session = session
+
+ @property
+ def session_reused(self):
+ """Was the client session reused during handshake"""
+ return self._sslobj.session_reused
+
+ @property
def server_side(self):
"""Whether this is a server-side socket."""
return self._sslobj.server_side
@@ -679,7 +721,7 @@ class SSLSocket(socket):
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None,
- _context=None):
+ _context=None, _session=None):
if _context:
self._context = _context
@@ -711,11 +753,16 @@ class SSLSocket(socket):
# mixed in.
if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
raise NotImplementedError("only stream sockets are supported")
- if server_side and server_hostname:
- raise ValueError("server_hostname can only be specified "
- "in client mode")
+ if server_side:
+ if server_hostname:
+ raise ValueError("server_hostname can only be specified "
+ "in client mode")
+ if _session is not None:
+ raise ValueError("session can only be specified in "
+ "client mode")
if self._context.check_hostname and not server_hostname:
raise ValueError("check_hostname requires server_hostname")
+ self._session = _session
self.server_side = server_side
self.server_hostname = server_hostname
self.do_handshake_on_connect = do_handshake_on_connect
@@ -751,7 +798,8 @@ class SSLSocket(socket):
try:
sslobj = self._context._wrap_socket(self, server_side,
server_hostname)
- self._sslobj = SSLObject(sslobj, owner=self)
+ self._sslobj = SSLObject(sslobj, owner=self,
+ session=self._session)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
@@ -772,6 +820,24 @@ class SSLSocket(socket):
self._context = ctx
self._sslobj.context = ctx
+ @property
+ def session(self):
+ """The SSLSession for client socket."""
+ if self._sslobj is not None:
+ return self._sslobj.session
+
+ @session.setter
+ def session(self, session):
+ self._session = session
+ if self._sslobj is not None:
+ self._sslobj.session = session
+
+ @property
+ def session_reused(self):
+ """Was the client session reused during handshake"""
+ if self._sslobj is not None:
+ return self._sslobj.session_reused
+
def dup(self):
raise NotImplemented("Can't dup() %s instances" %
self.__class__.__name__)
@@ -898,7 +964,6 @@ class SSLSocket(socket):
while (count < amount):
v = self.send(data[count:])
count += v
- return amount
else:
return socket.sendall(self, data, flags)
@@ -1005,7 +1070,8 @@ class SSLSocket(socket):
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
- self._sslobj = SSLObject(sslobj, owner=self)
+ self._sslobj = SSLObject(sslobj, owner=self,
+ session=self._session)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
@@ -1068,7 +1134,6 @@ def wrap_socket(sock, keyfile=None, certfile=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None):
-
return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,