summaryrefslogtreecommitdiff
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py361
1 files changed, 293 insertions, 68 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 06437b3046..4c155ea7d4 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -52,10 +52,47 @@ PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
+PROTOCOL_TLSv1_1
+PROTOCOL_TLSv1_2
+
+The following constants identify various SSL alert message descriptions as per
+http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
+
+ALERT_DESCRIPTION_CLOSE_NOTIFY
+ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
+ALERT_DESCRIPTION_BAD_RECORD_MAC
+ALERT_DESCRIPTION_RECORD_OVERFLOW
+ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
+ALERT_DESCRIPTION_HANDSHAKE_FAILURE
+ALERT_DESCRIPTION_BAD_CERTIFICATE
+ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
+ALERT_DESCRIPTION_CERTIFICATE_REVOKED
+ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
+ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
+ALERT_DESCRIPTION_ILLEGAL_PARAMETER
+ALERT_DESCRIPTION_UNKNOWN_CA
+ALERT_DESCRIPTION_ACCESS_DENIED
+ALERT_DESCRIPTION_DECODE_ERROR
+ALERT_DESCRIPTION_DECRYPT_ERROR
+ALERT_DESCRIPTION_PROTOCOL_VERSION
+ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
+ALERT_DESCRIPTION_INTERNAL_ERROR
+ALERT_DESCRIPTION_USER_CANCELLED
+ALERT_DESCRIPTION_NO_RENEGOTIATION
+ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
+ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
+ALERT_DESCRIPTION_UNRECOGNIZED_NAME
+ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
+ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
+ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
"""
import textwrap
import re
+import sys
+import os
+from collections import namedtuple
+from enum import Enum as _Enum
import _ssl # if we can't import it, let the error propagate
@@ -66,35 +103,26 @@ from _ssl import (
SSLSyscallError, SSLEOFError,
)
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
-from _ssl import (
- OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1,
- OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE
- )
-try:
- from _ssl import OP_NO_COMPRESSION
-except ImportError:
- pass
-try:
- from _ssl import OP_SINGLE_ECDH_USE
-except ImportError:
- pass
+from _ssl import (VERIFY_DEFAULT, VERIFY_CRL_CHECK_LEAF, VERIFY_CRL_CHECK_CHAIN,
+ VERIFY_X509_STRICT)
+from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes
-from _ssl import (
- SSL_ERROR_ZERO_RETURN,
- SSL_ERROR_WANT_READ,
- SSL_ERROR_WANT_WRITE,
- SSL_ERROR_WANT_X509_LOOKUP,
- SSL_ERROR_SYSCALL,
- SSL_ERROR_SSL,
- SSL_ERROR_WANT_CONNECT,
- SSL_ERROR_EOF,
- SSL_ERROR_INVALID_ERROR_CODE,
- )
+
+def _import_symbols(prefix):
+ for n in dir(_ssl):
+ if n.startswith(prefix):
+ globals()[n] = getattr(_ssl, n)
+
+_import_symbols('OP_')
+_import_symbols('ALERT_DESCRIPTION_')
+_import_symbols('SSL_ERROR_')
+
from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
-from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23,
- PROTOCOL_TLSv1)
+
+from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
from _ssl import _OPENSSL_API_VERSION
+
_PROTOCOL_NAMES = {
PROTOCOL_TLSv1: "TLSv1",
PROTOCOL_SSLv23: "SSLv23",
@@ -108,13 +136,27 @@ except ImportError:
else:
_PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"
+try:
+ from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2
+except ImportError:
+ pass
+else:
+ _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1"
+ _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2"
+
+if sys.platform == "win32":
+ from _ssl import enum_certificates, enum_crls
+
from socket import getnameinfo as _getnameinfo
-from socket import error as socket_error
+from socket import SHUT_RDWR as _SHUT_RDWR
from socket import socket, AF_INET, SOCK_STREAM, create_connection
import base64 # for DER-to-PEM translation
import traceback
import errno
+
+socket_error = OSError # keep that public name in module namespace
+
if _ssl.HAS_TLS_UNIQUE:
CHANNEL_BINDING_TYPES = ['tls-unique']
else:
@@ -124,6 +166,13 @@ else:
# (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL')
_DEFAULT_CIPHERS = 'DEFAULT:!aNULL:!eNULL:!LOW:!EXPORT:!SSLv2'
+# restricted and more secure ciphers
+# HIGH: high encryption cipher suites with key length >= 128 bits (no MD5)
+# !aNULL: only authenticated cipher suites (no anonymous DH)
+# !RC4: no RC4 streaming cipher, RC4 is broken
+# !DSS: RSA is preferred over DSA
+_RESTRICTED_CIPHERS = 'HIGH:!aNULL:!RC4:!DSS'
+
class CertificateError(ValueError):
pass
@@ -187,7 +236,9 @@ def match_hostname(cert, hostname):
returns nothing.
"""
if not cert:
- raise ValueError("empty or no certificate")
+ raise ValueError("empty or no certificate, match_hostname needs a "
+ "SSL socket or SSL context with either "
+ "CERT_OPTIONAL or CERT_REQUIRED")
dnsnames = []
san = cert.get('subjectAltName', ())
for key, value in san:
@@ -219,11 +270,58 @@ def match_hostname(cert, hostname):
"subjectAltName fields were found")
+DefaultVerifyPaths = namedtuple("DefaultVerifyPaths",
+ "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env "
+ "openssl_capath")
+
+def get_default_verify_paths():
+ """Return paths to default cafile and capath.
+ """
+ parts = _ssl.get_default_verify_paths()
+
+ # environment vars shadow paths
+ cafile = os.environ.get(parts[0], parts[1])
+ capath = os.environ.get(parts[2], parts[3])
+
+ return DefaultVerifyPaths(cafile if os.path.isfile(cafile) else None,
+ capath if os.path.isdir(capath) else None,
+ *parts)
+
+
+class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
+ """ASN.1 object identifier lookup
+ """
+ __slots__ = ()
+
+ def __new__(cls, oid):
+ return super().__new__(cls, *_txt2obj(oid, name=False))
+
+ @classmethod
+ def fromnid(cls, nid):
+ """Create _ASN1Object from OpenSSL numeric ID
+ """
+ return super().__new__(cls, *_nid2obj(nid))
+
+ @classmethod
+ def fromname(cls, name):
+ """Create _ASN1Object from short name, long name or OID
+ """
+ return super().__new__(cls, *_txt2obj(name, name=True))
+
+
+class Purpose(_ASN1Object, _Enum):
+ """SSLContext purpose flags with X509v3 Extended Key Usage objects
+ """
+ SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
+ CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
+
+
class SSLContext(_SSLContext):
"""An SSLContext holds various SSL-related configuration options and
data, such as certificates and possibly a private key."""
- __slots__ = ('protocol',)
+ __slots__ = ('protocol', '__weakref__')
+ _windows_cert_stores = ("CA", "ROOT")
def __new__(cls, protocol, *args, **kwargs):
self = _SSLContext.__new__(cls, protocol)
@@ -255,6 +353,93 @@ class SSLContext(_SSLContext):
self._set_npn_protocols(protos)
+ def _load_windows_store_certs(self, storename, purpose):
+ certs = bytearray()
+ for cert, encoding, trust in enum_certificates(storename):
+ # CA certs are never PKCS#7 encoded
+ if encoding == "x509_asn":
+ if trust is True or purpose.oid in trust:
+ certs.extend(cert)
+ self.load_verify_locations(cadata=certs)
+ return certs
+
+ def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
+ if not isinstance(purpose, _ASN1Object):
+ raise TypeError(purpose)
+ if sys.platform == "win32":
+ for storename in self._windows_cert_stores:
+ self._load_windows_store_certs(storename, purpose)
+ else:
+ self.set_default_verify_paths()
+
+
+def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
+ capath=None, cadata=None):
+ """Create a SSLContext object with default settings.
+
+ NOTE: The protocol and settings may change anytime without prior
+ deprecation. The values represent a fair balance between maximum
+ compatibility and security.
+ """
+ if not isinstance(purpose, _ASN1Object):
+ raise TypeError(purpose)
+ context = SSLContext(PROTOCOL_TLSv1)
+ # SSLv2 considered harmful.
+ context.options |= OP_NO_SSLv2
+ # disable compression to prevent CRIME attacks (OpenSSL 1.0+)
+ context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0)
+ # disallow ciphers with known vulnerabilities
+ context.set_ciphers(_RESTRICTED_CIPHERS)
+ # verify certs and host name in client mode
+ if purpose == Purpose.SERVER_AUTH:
+ context.verify_mode = CERT_REQUIRED
+ context.check_hostname = True
+ if cafile or capath or cadata:
+ context.load_verify_locations(cafile, capath, cadata)
+ elif context.verify_mode != CERT_NONE:
+ # no explicit cafile, capath or cadata but the verify mode is
+ # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
+ # root CA certificates for the given purpose. This may fail silently.
+ context.load_default_certs(purpose)
+ return context
+
+
+def _create_stdlib_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None,
+ purpose=Purpose.SERVER_AUTH,
+ certfile=None, keyfile=None,
+ cafile=None, capath=None, cadata=None):
+ """Create a SSLContext object for Python stdlib modules
+
+ All Python stdlib modules shall use this function to create SSLContext
+ objects in order to keep common settings in one place. The configuration
+ is less restrict than create_default_context()'s to increase backward
+ compatibility.
+ """
+ if not isinstance(purpose, _ASN1Object):
+ raise TypeError(purpose)
+
+ context = SSLContext(protocol)
+ # SSLv2 considered harmful.
+ context.options |= OP_NO_SSLv2
+
+ if cert_reqs is not None:
+ context.verify_mode = cert_reqs
+
+ if keyfile and not certfile:
+ raise ValueError("certfile must be specified")
+ if certfile or keyfile:
+ context.load_cert_chain(certfile, keyfile)
+
+ # load CA root certs
+ if cafile or capath or cadata:
+ context.load_verify_locations(cafile, capath, cadata)
+ elif context.verify_mode != CERT_NONE:
+ # no explicit cafile, capath or cadata but the verify mode is
+ # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
+ # root CA certificates for the given purpose. This may fail silently.
+ context.load_default_certs(purpose)
+
+ return context
class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
@@ -271,7 +456,7 @@ class SSLSocket(socket):
_context=None):
if _context:
- self.context = _context
+ self._context = _context
else:
if server_side and not certfile:
raise ValueError("certfile must be specified for server-side "
@@ -280,16 +465,16 @@ class SSLSocket(socket):
raise ValueError("certfile must be specified")
if certfile and not keyfile:
keyfile = certfile
- self.context = SSLContext(ssl_version)
- self.context.verify_mode = cert_reqs
+ self._context = SSLContext(ssl_version)
+ self._context.verify_mode = cert_reqs
if ca_certs:
- self.context.load_verify_locations(ca_certs)
+ self._context.load_verify_locations(ca_certs)
if certfile:
- self.context.load_cert_chain(certfile, keyfile)
+ self._context.load_cert_chain(certfile, keyfile)
if npn_protocols:
- self.context.set_npn_protocols(npn_protocols)
+ self._context.set_npn_protocols(npn_protocols)
if ciphers:
- self.context.set_ciphers(ciphers)
+ self._context.set_ciphers(ciphers)
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
@@ -299,11 +484,17 @@ class SSLSocket(socket):
if server_side and server_hostname:
raise ValueError("server_hostname can only be specified "
"in client mode")
+ if self._context.check_hostname and not server_hostname:
+ if HAS_SNI:
+ raise ValueError("check_hostname requires server_hostname")
+ else:
+ raise ValueError("check_hostname requires server_hostname, "
+ "but it's not supported by your OpenSSL "
+ "library")
self.server_side = server_side
self.server_hostname = server_hostname
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
- connected = False
if sock is not None:
socket.__init__(self,
family=sock.family,
@@ -311,27 +502,29 @@ class SSLSocket(socket):
proto=sock.proto,
fileno=sock.fileno())
self.settimeout(sock.gettimeout())
- # see if it's connected
- try:
- sock.getpeername()
- except socket_error as e:
- if e.errno != errno.ENOTCONN:
- raise
- else:
- connected = True
sock.detach()
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
socket.__init__(self, family=family, type=type, proto=proto)
+ # See if we are connected
+ try:
+ self.getpeername()
+ except OSError as e:
+ if e.errno != errno.ENOTCONN:
+ raise
+ connected = False
+ else:
+ connected = True
+
self._closed = False
self._sslobj = None
self._connected = connected
if connected:
# create the SSL object
try:
- self._sslobj = self.context._wrap_socket(self, server_side,
+ self._sslobj = self._context._wrap_socket(self, server_side,
server_hostname)
if do_handshake_on_connect:
timeout = self.gettimeout()
@@ -340,9 +533,18 @@ class SSLSocket(socket):
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
- except socket_error as x:
+ except (OSError, ValueError):
self.close()
- raise x
+ raise
+
+ @property
+ def context(self):
+ return self._context
+
+ @context.setter
+ def context(self, ctx):
+ self._context = ctx
+ self._sslobj.context = ctx
def dup(self):
raise NotImplemented("Can't dup() %s instances" %
@@ -352,11 +554,21 @@ class SSLSocket(socket):
# raise an exception here if you wish to check for spurious closes
pass
+ def _check_connected(self):
+ if not self._connected:
+ # getpeername() will raise ENOTCONN if the socket is really
+ # not connected; note that we can be connected even without
+ # _connected being set, e.g. if connect() first returned
+ # EAGAIN.
+ self.getpeername()
+
def read(self, len=0, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
+ if not self._sslobj:
+ raise ValueError("Read on closed or unwrapped SSL socket.")
try:
if buffer is not None:
v = self._sslobj.read(len, buffer)
@@ -377,6 +589,8 @@ class SSLSocket(socket):
number of bytes of DATA actually transmitted."""
self._checkClosed()
+ if not self._sslobj:
+ raise ValueError("Write on closed or unwrapped SSL socket.")
return self._sslobj.write(data)
def getpeercert(self, binary_form=False):
@@ -386,6 +600,7 @@ class SSLSocket(socket):
certificate was provided, but not validated."""
self._checkClosed()
+ self._check_connected()
return self._sslobj.peer_certificate(binary_form)
def selected_npn_protocol(self):
@@ -416,18 +631,17 @@ class SSLSocket(socket):
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
- while True:
- try:
- v = self._sslobj.write(data)
- except SSLError as x:
- if x.args[0] == SSL_ERROR_WANT_READ:
- return 0
- elif x.args[0] == SSL_ERROR_WANT_WRITE:
- return 0
- else:
- raise
+ try:
+ v = self._sslobj.write(data)
+ except SSLError as x:
+ if x.args[0] == SSL_ERROR_WANT_READ:
+ return 0
+ elif x.args[0] == SSL_ERROR_WANT_WRITE:
+ return 0
else:
- return v
+ raise
+ else:
+ return v
else:
return socket.send(self, data, flags)
@@ -535,12 +749,11 @@ class SSLSocket(socket):
def _real_close(self):
self._sslobj = None
- # self._closed = True
socket._real_close(self)
def do_handshake(self, block=False):
"""Perform a TLS/SSL handshake."""
-
+ self._check_connected()
timeout = self.gettimeout()
try:
if timeout == 0.0 and block:
@@ -549,6 +762,17 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
+ if self.context.check_hostname:
+ try:
+ if not self.server_hostname:
+ raise ValueError("check_hostname needs server_hostname "
+ "argument")
+ match_hostname(self.getpeercert(), self.server_hostname)
+ except Exception:
+ self.shutdown(_SHUT_RDWR)
+ self.close()
+ raise
+
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
@@ -564,11 +788,11 @@ class SSLSocket(socket):
rc = None
socket.connect(self, addr)
if not rc:
+ self._connected = True
if self.do_handshake_on_connect:
self.do_handshake()
- self._connected = True
return rc
- except socket_error:
+ except (OSError, ValueError):
self._sslobj = None
raise
@@ -666,15 +890,16 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
If 'ssl_version' is specified, use it in the connection attempt."""
host, port = addr
- if (ca_certs is not None):
+ if ca_certs is not None:
cert_reqs = CERT_REQUIRED
else:
cert_reqs = CERT_NONE
- s = create_connection(addr)
- s = wrap_socket(s, ssl_version=ssl_version,
- cert_reqs=cert_reqs, ca_certs=ca_certs)
- dercert = s.getpeercert(True)
- s.close()
+ context = _create_stdlib_context(ssl_version,
+ cert_reqs=cert_reqs,
+ cafile=ca_certs)
+ with create_connection(addr) as sock:
+ with context.wrap_socket(sock) as sslsock:
+ dercert = sslsock.getpeercert(True)
return DER_cert_to_PEM_cert(dercert)
def get_protocol_name(protocol_code):