diff options
Diffstat (limited to 'src/M2Crypto/SSL')
-rw-r--r-- | src/M2Crypto/SSL/Checker.py | 296 | ||||
-rw-r--r-- | src/M2Crypto/SSL/Cipher.py | 59 | ||||
-rw-r--r-- | src/M2Crypto/SSL/Connection.py | 712 | ||||
-rw-r--r-- | src/M2Crypto/SSL/Context.py | 445 | ||||
-rw-r--r-- | src/M2Crypto/SSL/SSLServer.py | 65 | ||||
-rw-r--r-- | src/M2Crypto/SSL/Session.py | 69 | ||||
-rw-r--r-- | src/M2Crypto/SSL/TwistedProtocolWrapper.py | 490 | ||||
-rw-r--r-- | src/M2Crypto/SSL/__init__.py | 44 | ||||
-rw-r--r-- | src/M2Crypto/SSL/cb.py | 96 | ||||
-rw-r--r-- | src/M2Crypto/SSL/ssl_dispatcher.py | 43 | ||||
-rw-r--r-- | src/M2Crypto/SSL/timeout.py | 50 |
11 files changed, 2369 insertions, 0 deletions
diff --git a/src/M2Crypto/SSL/Checker.py b/src/M2Crypto/SSL/Checker.py new file mode 100644 index 0000000..46d397a --- /dev/null +++ b/src/M2Crypto/SSL/Checker.py @@ -0,0 +1,296 @@ +""" +SSL peer certificate checking routines + +Copyright (c) 2004-2007 Open Source Applications Foundation. +All rights reserved. + +Copyright 2008 Heikki Toivonen. All rights reserved. +""" + +__all__ = ['SSLVerificationError', 'NoCertificate', 'WrongCertificate', + 'WrongHost', 'Checker'] + +import re +import socket + +from M2Crypto import X509, m2, six # noqa +from typing import AnyStr, Optional # noqa + + +class SSLVerificationError(Exception): + pass + + +class NoCertificate(SSLVerificationError): + pass + + +class WrongCertificate(SSLVerificationError): + pass + + +class WrongHost(SSLVerificationError): + def __init__(self, expectedHost, actualHost, fieldName='commonName'): + # type: (str, AnyStr, str) -> None + """ + This exception will be raised if the certificate returned by the + peer was issued for a different host than we tried to connect to. + This could be due to a server misconfiguration or an active attack. + + :param expectedHost: The name of the host we expected to find in the + certificate. + :param actualHost: The name of the host we actually found in the + certificate. + :param fieldName: The field name where we noticed the error. This + should be either 'commonName' or 'subjectAltName'. + """ + if fieldName not in ('commonName', 'subjectAltName'): + raise ValueError( + 'Unknown fieldName, should be either commonName ' + + 'or subjectAltName') + + SSLVerificationError.__init__(self) + self.expectedHost = expectedHost + self.actualHost = actualHost + self.fieldName = fieldName + + def __str__(self): + # type: () -> str + s = 'Peer certificate %s does not match host, expected %s, got %s' \ + % (self.fieldName, self.expectedHost, self.actualHost) + return six.ensure_text(s) + + +class Checker(object): + + numericIpMatch = re.compile('^[0-9]+(\.[0-9]+)*$') + + def __init__(self, host=None, peerCertHash=None, peerCertDigest='sha1'): + # type: (Optional[str], Optional[bytes], str) -> None + self.host = host + if peerCertHash is not None: + peerCertHash = six.ensure_binary(peerCertHash) + self.fingerprint = peerCertHash + self.digest = peerCertDigest # type: str + + def __call__(self, peerCert, host=None): + # type: (X509.X509, Optional[str]) -> bool + if peerCert is None: + raise NoCertificate('peer did not return certificate') + + if host is not None: + self.host = host # type: str + + if self.fingerprint: + if self.digest not in ('sha1', 'md5'): + raise ValueError('unsupported digest "%s"' % self.digest) + + if self.digest == 'sha1': + expected_len = 40 + elif self.digest == 'md5': + expected_len = 32 + else: + raise ValueError('Unexpected digest {0}'.format(self.digest)) + + if len(self.fingerprint) != expected_len: + raise WrongCertificate( + ('peer certificate fingerprint length does not match\n' + + 'fingerprint: {0}\nexpected = {1}\n' + + 'observed = {2}').format(self.fingerprint, + expected_len, + len(self.fingerprint))) + + expected_fingerprint = six.ensure_text(self.fingerprint) + observed_fingerprint = peerCert.get_fingerprint(md=self.digest) + if observed_fingerprint != expected_fingerprint: + raise WrongCertificate( + ('peer certificate fingerprint does not match\n' + + 'expected = {0},\n' + + 'observed = {1}').format(expected_fingerprint, + observed_fingerprint)) + + if self.host: + hostValidationPassed = False + self.useSubjectAltNameOnly = False + + # subjectAltName=DNS:somehost[, ...]* + try: + subjectAltName = peerCert.get_ext('subjectAltName').get_value() + if self._splitSubjectAltName(self.host, subjectAltName): + hostValidationPassed = True + elif self.useSubjectAltNameOnly: + raise WrongHost(expectedHost=self.host, + actualHost=subjectAltName, + fieldName='subjectAltName') + except LookupError: + pass + + # commonName=somehost[, ...]* + if not hostValidationPassed: + hasCommonName = False + commonNames = '' + for entry in peerCert.get_subject().get_entries_by_nid( + m2.NID_commonName): + hasCommonName = True + commonName = entry.get_data().as_text() + if not commonNames: + commonNames = commonName + else: + commonNames += ',' + commonName + if self._match(self.host, commonName): + hostValidationPassed = True + break + + if not hasCommonName: + raise WrongCertificate('no commonName in peer certificate') + + if not hostValidationPassed: + raise WrongHost(expectedHost=self.host, + actualHost=commonNames, + fieldName='commonName') + + return True + + def _splitSubjectAltName(self, host, subjectAltName): + # type: (AnyStr, AnyStr) -> bool + """ + >>> check = Checker() + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:my.example.com') + True + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:*.example.com') + True + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:m*.example.com') + True + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:m*ample.com') + False + >>> check.useSubjectAltNameOnly + True + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:m*ample.com, othername:<unsupported>') + False + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:m*ample.com, DNS:my.example.org') + False + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:m*ample.com, DNS:my.example.com') + True + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='DNS:my.example.com, DNS:my.example.org') + True + >>> check.useSubjectAltNameOnly + True + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='') + False + >>> check._splitSubjectAltName(host='my.example.com', + ... subjectAltName='othername:<unsupported>') + False + >>> check.useSubjectAltNameOnly + False + """ + self.useSubjectAltNameOnly = False + for certHost in subjectAltName.split(','): + certHost = certHost.lower().strip() + if certHost[:4] == 'dns:': + self.useSubjectAltNameOnly = True + if self._match(host, certHost[4:]): + return True + elif certHost[:11] == 'ip address:': + self.useSubjectAltNameOnly = True + if self._matchIPAddress(host, certHost[11:]): + return True + return False + + def _match(self, host, certHost): + # type: (str, str) -> bool + """ + >>> check = Checker() + >>> check._match(host='my.example.com', certHost='my.example.com') + True + >>> check._match(host='my.example.com', certHost='*.example.com') + True + >>> check._match(host='my.example.com', certHost='m*.example.com') + True + >>> check._match(host='my.example.com', certHost='m*.EXAMPLE.com') + True + >>> check._match(host='my.example.com', certHost='m*ample.com') + False + >>> check._match(host='my.example.com', certHost='*.*.com') + False + >>> check._match(host='1.2.3.4', certHost='1.2.3.4') + True + >>> check._match(host='1.2.3.4', certHost='*.2.3.4') + False + >>> check._match(host='1234', certHost='1234') + True + """ + # XXX See RFC 2818 and 3280 for matching rules, this is may not + # XXX yet be complete. + + host = host.lower() + certHost = certHost.lower() + + if host == certHost: + return True + + if certHost.count('*') > 1: + # Not sure about this, but being conservative + return False + + if self.numericIpMatch.match(host) or \ + self.numericIpMatch.match(certHost.replace('*', '')): + # Not sure if * allowed in numeric IP, but think not. + return False + + if certHost.find('\\') > -1: + # Not sure about this, maybe some encoding might have these. + # But being conservative for now, because regex below relies + # on this. + return False + + # Massage certHost so that it can be used in regex + certHost = certHost.replace('.', '\.') + certHost = certHost.replace('*', '[^\.]*') + if re.compile('^%s$' % certHost).match(host): + return True + + return False + + def _matchIPAddress(self, host, certHost): + # type: (AnyStr, AnyStr) -> bool + """ + >>> check = Checker() + >>> check._matchIPAddress(host='my.example.com', + ... certHost='my.example.com') + False + >>> check._matchIPAddress(host='1.2.3.4', certHost='1.2.3.4') + True + >>> check._matchIPAddress(host='1.2.3.4', certHost='*.2.3.4') + False + >>> check._matchIPAddress(host='1.2.3.4', certHost='1.2.3.40') + False + >>> check._matchIPAddress(host='::1', certHost='::1') + True + >>> check._matchIPAddress(host='::1', certHost='0:0:0:0:0:0:0:1') + True + >>> check._matchIPAddress(host='::1', certHost='::2') + False + """ + try: + canonical = socket.getaddrinfo(host, 0, 0, socket.SOCK_STREAM, 0, + socket.AI_NUMERICHOST) + certCanonical = socket.getaddrinfo(certHost, 0, 0, + socket.SOCK_STREAM, 0, + socket.AI_NUMERICHOST) + except: + return False + return canonical == certCanonical + + +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/src/M2Crypto/SSL/Cipher.py b/src/M2Crypto/SSL/Cipher.py new file mode 100644 index 0000000..06f2dbf --- /dev/null +++ b/src/M2Crypto/SSL/Cipher.py @@ -0,0 +1,59 @@ +"""SSL Ciphers + +Copyright (c) 1999-2003 Ng Pheng Siong. All rights reserved.""" + +__all__ = ['Cipher', 'Cipher_Stack'] + +from M2Crypto import m2, six +from typing import Iterable # noqa + + +class Cipher(object): + def __init__(self, cipher): + # type: (str) -> None + self.cipher = cipher + + def __len__(self): + # type: () -> int + return m2.ssl_cipher_get_bits(self.cipher) + + def __repr__(self): + # type: () -> str + return "%s-%s" % (self.name(), len(self)) + + def __str__(self): + # type: () -> str + return "%s-%s" % (self.name(), len(self)) + + def version(self): + # type: () -> int + return m2.ssl_cipher_get_version(self.cipher) + + def name(self): + # type: () -> str + return six.ensure_text(m2.ssl_cipher_get_name(self.cipher)) + + +class Cipher_Stack(object): + def __init__(self, stack): + # type: (bytes) -> None + """ + :param stack: binary of the C-type STACK_OF(SSL_CIPHER) + """ + self.stack = stack + + def __len__(self): + # type: () -> int + return m2.sk_ssl_cipher_num(self.stack) + + def __getitem__(self, idx): + # type: (int) -> Cipher + if not 0 <= idx < m2.sk_ssl_cipher_num(self.stack): + raise IndexError('index out of range') + v = m2.sk_ssl_cipher_value(self.stack, idx) + return Cipher(v) + + def __iter__(self): + # type: () -> Iterable + for i in six.moves.range(m2.sk_ssl_cipher_num(self.stack)): + yield self[i] diff --git a/src/M2Crypto/SSL/Connection.py b/src/M2Crypto/SSL/Connection.py new file mode 100644 index 0000000..c5693c9 --- /dev/null +++ b/src/M2Crypto/SSL/Connection.py @@ -0,0 +1,712 @@ +from __future__ import absolute_import + +"""SSL Connection aka socket + +Copyright (c) 1999-2004 Ng Pheng Siong. All rights reserved. + +Portions created by Open Source Applications Foundation (OSAF) are +Copyright (C) 2004-2007 OSAF. All Rights Reserved. + +Copyright 2008 Heikki Toivonen. All rights reserved. +""" + +import logging +import socket +import io + +from M2Crypto import BIO, Err, X509, m2, six, util # noqa +from M2Crypto.SSL import Checker, Context, timeout # noqa +from M2Crypto.SSL import SSLError +from M2Crypto.SSL.Cipher import Cipher, Cipher_Stack +from M2Crypto.SSL.Session import Session +from typing import Any, AnyStr, Callable, Optional, Tuple, Union # noqa + +__all__ = ['Connection', + 'timeout', # XXX Not really, but for documentation purposes + ] + +log = logging.getLogger(__name__) + + +def _serverPostConnectionCheck(*args, **kw): + # type: (*Any, **Any) -> int + return 1 + + +class Connection(object): + """An SSL connection.""" + + serverPostConnectionCheck = _serverPostConnectionCheck + + m2_bio_free = m2.bio_free + m2_ssl_free = m2.ssl_free + m2_bio_noclose = m2.bio_noclose + + def __init__(self, ctx, sock=None, family=socket.AF_INET): + # type: (Context, socket.socket, int) -> None + """ + + :param ctx: SSL.Context + :param sock: socket to be used + :param family: socket family + """ + # The Checker needs to be an instance attribute + # and not a class attribute for thread safety reason + self.clientPostConnectionCheck = Checker.Checker() + + self._bio_freed = False + self.ctx = ctx + self.ssl = m2.ssl_new(self.ctx.ctx) # type: bytes + if sock is not None: + self.socket = sock + else: + self.socket = socket.socket(family, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._fileno = self.socket.fileno() + + self._timeout = self.socket.gettimeout() + if self._timeout is None: + self._timeout = -1.0 + + self.ssl_close_flag = m2.bio_noclose + + if self.ctx.post_connection_check is not None: + self.set_post_connection_check_callback( + self.ctx.post_connection_check) + + self.host = None + + def _free_bio(self): + """ + Free the sslbio and sockbio, and close the socket. + """ + # Do not do it twice + if not self._bio_freed: + if getattr(self, 'sslbio', None): + self.m2_bio_free(self.sslbio) + if getattr(self, 'sockbio', None): + self.m2_bio_free(self.sockbio) + if self.ssl_close_flag == self.m2_bio_noclose and \ + getattr(self, 'ssl', None): + self.m2_ssl_free(self.ssl) + self.socket.close() + self._bio_freed = True + + + def __del__(self): + # type: () -> None + # Notice that M2Crypto doesn't automatically shuts down the + # connection here. You have to call self.close() in your + # program, M2Crypto won't do it automatically for you. + self._free_bio() + + def close(self, freeBio=False): + """ + if freeBio is true, call _free_bio + """ + # type: () -> None + m2.ssl_shutdown(self.ssl) + if freeBio: + self._free_bio() + + def clear(self): + # type: () -> int + """ + If there were errors in this connection, call clear() rather + than close() to end it, so that bad sessions will be cleared + from cache. + """ + return m2.ssl_clear(self.ssl) + + def set_shutdown(self, mode): + # type: (int) -> None + """Sets the shutdown state of the Connection to mode. + + The shutdown state of an ssl connection is a bitmask of (use + m2.SSL_* constants): + + 0 No shutdown setting, yet. + + SSL_SENT_SHUTDOWN + A "close notify" shutdown alert was sent to the peer, the + connection is being considered closed and the session is + closed and correct. + + SSL_RECEIVED_SHUTDOWN + A shutdown alert was received form the peer, either a normal + "close notify" or a fatal error. + + SSL_SENT_SHUTDOWN and SSL_RECEIVED_SHUTDOWN can be set at the + same time. + + :param mode: set the mode bitmask. + """ + m2.ssl_set_shutdown1(self.ssl, mode) + + def get_shutdown(self): + # type: () -> None + """Get the current shutdown mode of the Connection.""" + return m2.ssl_get_shutdown(self.ssl) + + def bind(self, addr): + # type: (util.AddrType) -> None + self.socket.bind(addr) + + def listen(self, qlen=5): + # type: (int) -> None + self.socket.listen(qlen) + + def ssl_get_error(self, ret): + # type: (int) -> int + return m2.ssl_get_error(self.ssl, ret) + + def set_bio(self, readbio, writebio): + # type: (BIO.BIO, BIO.BIO) -> None + """Explicitly set read and write bios + + Connects the BIOs for the read and write operations of the + TLS/SSL (encrypted) side of ssl. + + The SSL engine inherits the behaviour of both BIO objects, + respectively. If a BIO is non-blocking, the Connection will also + have non-blocking behaviour. + + If there was already a BIO connected to Connection, BIO_free() + will be called (for both the reading and writing side, if + different). + + :param readbio: BIO for reading + :param writebio: BIO for writing. + """ + m2.ssl_set_bio(self.ssl, readbio._ptr(), writebio._ptr()) + + def set_client_CA_list_from_file(self, cafile): + # type: (AnyStr) -> None + """Set the acceptable client CA list. + + If the client returns a certificate, it must have been issued by + one of the CAs listed in cafile. + + Makes sense only for servers. + + :param cafile: Filename from which to load the CA list. + + :return: 0 A failure while manipulating the STACK_OF(X509_NAME) + object occurred or the X509_NAME could not be + extracted from cacert. Check the error stack to find + out the reason. + + 1 The operation succeeded. + """ + m2.ssl_set_client_CA_list_from_file(self.ssl, cafile) + + def set_client_CA_list_from_context(self): + # type: () -> None + """ + Set the acceptable client CA list. If the client + returns a certificate, it must have been issued by + one of the CAs listed in context. + + Makes sense only for servers. + """ + m2.ssl_set_client_CA_list_from_context(self.ssl, self.ctx.ctx) + + def setup_addr(self, addr): + # type: (util.AddrType) -> None + self.addr = addr + + def set_ssl_close_flag(self, flag): + # type: (int) -> None + """ + By default, SSL struct will be freed in __del__. Call with + m2.bio_close to override this default. + + :param flag: either m2.bio_close or m2.bio_noclose + """ + if flag not in (m2.bio_close, m2.bio_noclose): + raise ValueError("flag must be m2.bio_close or m2.bio_noclose") + self.ssl_close_flag = flag + + def setup_ssl(self): + # type: () -> None + # Make a BIO_s_socket. + self.sockbio = m2.bio_new_socket(self.socket.fileno(), 0) + # Link SSL struct with the BIO_socket. + m2.ssl_set_bio(self.ssl, self.sockbio, self.sockbio) + # Make a BIO_f_ssl. + self.sslbio = m2.bio_new(m2.bio_f_ssl()) + # Link BIO_f_ssl with the SSL struct. + m2.bio_set_ssl(self.sslbio, self.ssl, m2.bio_noclose) + + def _setup_ssl(self, addr): + # type: (util.AddrType) -> None + """Deprecated""" + self.setup_addr(addr) + self.setup_ssl() + + def set_accept_state(self): + # type: () -> None + """Sets Connection to work in the server mode.""" + m2.ssl_set_accept_state(self.ssl) + + def accept_ssl(self): + # type: () -> Optional[int] + """Waits for a TLS/SSL client to initiate the TLS/SSL handshake. + + The communication channel must already have been set and + assigned to the ssl by setting an underlying BIO. + + :return: 0 The TLS/SSL handshake was not successful but was shut + down controlled and by the specifications of the + TLS/SSL protocol. Call get_error() with the return + value ret to find out the reason. + + 1 The TLS/SSL handshake was successfully completed, + a TLS/SSL connection has been established. + + <0 The TLS/SSL handshake was not successful because + a fatal error occurred either at the protocol level + or a connection failure occurred. The shutdown was + not clean. It can also occur of action is need to + continue the operation for non-blocking BIOs. Call + get_error() with the return value ret to find + out the reason. + """ + return m2.ssl_accept(self.ssl, self._timeout) + + def accept(self): + # type: () -> Tuple[Connection, util.AddrType] + """Accept an SSL connection. + + The return value is a pair (ssl, addr) where ssl is a new SSL + connection object and addr is the address bound to the other end + of the SSL connection. + + :return: tuple of Connection and addr. Address can take very + various forms (see socket documentation), for IPv4 it + is tuple(str, int), for IPv6 a tuple of four (host, + port, flowinfo, scopeid), where the last two are + optional ints. + """ + sock, addr = self.socket.accept() + ssl = Connection(self.ctx, sock) + ssl.addr = addr + ssl.setup_ssl() + ssl.set_accept_state() + ssl.accept_ssl() + check = getattr(self, 'postConnectionCheck', + self.serverPostConnectionCheck) + if check is not None: + if not check(ssl.get_peer_cert(), ssl.addr[0]): + raise Checker.SSLVerificationError( + 'post connection check failed') + return ssl, addr + + def set_connect_state(self): + # type: () -> None + """Sets Connection to work in the client mode.""" + m2.ssl_set_connect_state(self.ssl) + + def connect_ssl(self): + # type: () -> Optional[int] + return m2.ssl_connect(self.ssl, self._timeout) + + def connect(self, addr): + # type: (util.AddrType) -> int + """Overloading socket.connect() + + :param addr: addresses have various depending on their type + + :return:status of ssl_connect() + """ + self.socket.connect(addr) + self.addr = addr + self.setup_ssl() + self.set_connect_state() + ret = self.connect_ssl() + check = getattr(self, 'postConnectionCheck', + self.clientPostConnectionCheck) + if check is not None: + if not check(self.get_peer_cert(), + self.host if self.host else self.addr[0]): + raise Checker.SSLVerificationError( + 'post connection check failed') + return ret + + def shutdown(self, how): + # type: (int) -> None + m2.ssl_set_shutdown(self.ssl, how) + + def renegotiate(self): + # type: () -> int + """Renegotiate this connection's SSL parameters.""" + return m2.ssl_renegotiate(self.ssl) + + def pending(self): + # type: () -> int + """Return the numbers of octets that can be read from the connection.""" + return m2.ssl_pending(self.ssl) + + def _write_bio(self, data): + # type: (bytes) -> int + return m2.ssl_write(self.ssl, data, self._timeout) + + def _write_nbio(self, data): + # type: (bytes) -> int + return m2.ssl_write_nbio(self.ssl, data) + + def _read_bio(self, size=1024): + # type: (int) -> bytes + if size <= 0: + raise ValueError('size <= 0') + return m2.ssl_read(self.ssl, size, self._timeout) + + def _read_nbio(self, size=1024): + # type: (int) -> bytes + if size <= 0: + raise ValueError('size <= 0') + return m2.ssl_read_nbio(self.ssl, size) + + def write(self, data): + # type: (bytes) -> int + if self._timeout != 0.0: + return self._write_bio(data) + return self._write_nbio(data) + sendall = send = write + + def _decref_socketios(self): + pass + + def recv_into(self, buff, nbytes=0): + # type: (Union[bytearray, memoryview], int) -> int + """ + A version of recv() that stores its data into a buffer rather + than creating a new string. Receive up to buffersize bytes from + the socket. If buffersize is not specified (or 0), receive up + to the size available in the given buffer. + + If buff is bytearray, it will have after return length of the + actually returned number of bytes. If buff is memoryview, then + the size of buff won't change (it cannot), but all bytes after + the number of returned bytes will be NULL. + + :param buffer: a buffer for the received bytes + :param nbytes: maximum number of bytes to read + :return: number of bytes read + + See recv() for documentation about the flags. + """ + n = len(buff) if nbytes == 0 else nbytes + + if n <= 0: + raise ValueError('size <= 0') + + # buff_bytes are actual bytes returned + buff_bytes = m2.ssl_read(self.ssl, n, self._timeout) + buflen = len(buff_bytes) + + # memoryview type has been added in 2.7 + if isinstance(buff, memoryview): + buff[:buflen] = buff_bytes + buff[buflen:] = b'\x00' * (len(buff) - buflen) + else: + buff[:] = buff_bytes + + return buflen + + def read(self, size=1024): + # type: (int) -> bytes + if self._timeout != 0.0: + return self._read_bio(size) + return self._read_nbio(size) + recv = read + + def setblocking(self, mode): + # type: (int) -> None + """Set this connection's underlying socket to _mode_. + + Set blocking or non-blocking mode of the socket: if flag is 0, + the socket is set to non-blocking, else to blocking mode. + Initially all sockets are in blocking mode. In non-blocking mode, + if a recv() call doesn't find any data, or if a send() call can't + immediately dispose of the data, a error exception is raised; + in blocking mode, the calls block until they can proceed. + s.setblocking(0) is equivalent to s.settimeout(0.0); + s.setblocking(1) is equivalent to s.settimeout(None). + + :param mode: new mode to be set + """ + self.socket.setblocking(mode) + if mode: + self._timeout = -1.0 + else: + self._timeout = 0.0 + + def settimeout(self, timeout): + # type: (float) -> None + """Set this connection's underlying socket's timeout to _timeout_.""" + self.socket.settimeout(timeout) + self._timeout = timeout + if self._timeout is None: + self._timeout = -1.0 + + def fileno(self): + # type: () -> int + return self.socket.fileno() + + def getsockopt(self, level, optname, buflen=None): + # type: (int, int, Optional[int]) -> Union[int, bytes] + """Get the value of the given socket option. + + :param level: level at which the option resides. + To manipulate options at the sockets API level, level is + specified as socket.SOL_SOCKET. To manipulate options at + any other level the protocol number of the appropriate + protocol controlling the option is supplied. For example, + to indicate that an option is to be interpreted by the + TCP protocol, level should be set to the protocol number + of socket.SOL_TCP; see getprotoent(3). + + :param optname: The value of the given socket option is + described in the Unix man page getsockopt(2)). The needed + symbolic constants (SO_* etc.) are defined in the socket + module. + + :param buflen: If it is absent, an integer option is assumed + and its integer value is returned by the function. If + buflen is present, it specifies the maximum length of the + buffer used to receive the option in, and this buffer is + returned as a bytes object. + + :return: Either integer or bytes value of the option. It is up + to the caller to decode the contents of the buffer (see + the optional built-in module struct for a way to decode + C structures encoded as byte strings). + """ + return self.socket.getsockopt(level, optname, buflen) + + def setsockopt(self, level, optname, value=None): + # type: (int, int, Union[int, bytes, None]) -> Optional[bytes] + """Set the value of the given socket option. + + :param level: same as with getsockopt() above + + :param optname: same as with getsockopt() above + + :param value: an integer or a string representing a buffer. In + the latter case it is up to the caller to ensure + that the string contains the proper bits (see the + optional built-in module struct for a way to + encode C structures as strings). + + :return: None for success or the error handler for failure. + """ + return self.socket.setsockopt(level, optname, value) + + def get_context(self): + # type: () -> Context + """Return the Context object associated with this connection.""" + return m2.ssl_get_ssl_ctx(self.ssl) + + def get_state(self): + # type: () -> bytes + """Return the SSL state of this connection. + + During its use, an SSL objects passes several states. The state + is internally maintained. Querying the state information is not + very informative before or when a connection has been + established. It however can be of significant interest during + the handshake. + + :return: 6 letter string indicating the current state of the SSL + object ssl. + """ + return m2.ssl_get_state(self.ssl) + + def verify_ok(self): + # type: () -> bool + return (m2.ssl_get_verify_result(self.ssl) == m2.X509_V_OK) + + def get_verify_mode(self): + # type: () -> int + """Return the peer certificate verification mode.""" + return m2.ssl_get_verify_mode(self.ssl) + + def get_verify_depth(self): + # type: () -> int + """Return the peer certificate verification depth.""" + return m2.ssl_get_verify_depth(self.ssl) + + def get_verify_result(self): + # type: () -> int + """Return the peer certificate verification result.""" + return m2.ssl_get_verify_result(self.ssl) + + def get_peer_cert(self): + # type: () -> X509.X509 + """Return the peer certificate. + + If the peer did not provide a certificate, return None. + """ + c = m2.ssl_get_peer_cert(self.ssl) + if c is None: + return None + # Need to free the pointer coz OpenSSL doesn't. + return X509.X509(c, 1) + + def get_peer_cert_chain(self): + # type: () -> Optional[X509.X509_Stack] + """Return the peer certificate chain; if the peer did not provide + a certificate chain, return None. + + :warning: The returned chain will be valid only for as long as the + connection object is alive. Once the connection object + gets freed, the chain will be freed as well. + """ + c = m2.ssl_get_peer_cert_chain(self.ssl) + if c is None: + return None + # No need to free the pointer coz OpenSSL does. + return X509.X509_Stack(c) + + def get_cipher(self): + # type: () -> Optional[Cipher] + """Return an M2Crypto.SSL.Cipher object for this connection; if the + connection has not been initialised with a cipher suite, return None. + """ + c = m2.ssl_get_current_cipher(self.ssl) + if c is None: + return None + return Cipher(c) + + def get_ciphers(self): + # type: () -> Optional[Cipher_Stack] + """Return an M2Crypto.SSL.Cipher_Stack object for this + connection; if the connection has not been initialised with + cipher suites, return None. + """ + c = m2.ssl_get_ciphers(self.ssl) + if c is None: + return None + return Cipher_Stack(c) + + def get_cipher_list(self, idx=0): + # type: (int) -> str + """Return the cipher suites for this connection as a string object.""" + return six.ensure_text(m2.ssl_get_cipher_list(self.ssl, idx)) + + def set_cipher_list(self, cipher_list): + # type: (str) -> int + """Set the cipher suites for this connection.""" + return m2.ssl_set_cipher_list(self.ssl, cipher_list) + + def makefile(self, mode='rb', bufsize=-1): + # type: (AnyStr, int) -> Union[io.BufferedRWPair,io.BufferedReader] + if six.PY3: + raw = socket.SocketIO(self, mode) + if 'rw' in mode: + return io.BufferedRWPair(raw, raw) + return io.BufferedReader(raw, io.DEFAULT_BUFFER_SIZE) + else: + return socket._fileobject(self, mode, bufsize) + + def getsockname(self): + # type: () -> util.AddrType + """Return the socket's own address. + + This is useful to find out the port number of an IPv4/v6 socket, + for instance. (The format of the address returned depends + on the address family -- see above.) + + :return:socket's address as addr type + """ + return self.socket.getsockname() + + def getpeername(self): + # type: () -> util.AddrType + """Return the remote address to which the socket is connected. + + This is useful to find out the port number of a remote IPv4/v6 socket, + for instance. + On some systems this function is not supported. + + :return: + """ + return self.socket.getpeername() + + def set_session_id_ctx(self, id): + # type: (bytes) -> int + ret = m2.ssl_set_session_id_context(self.ssl, id) + if not ret: + raise SSLError(Err.get_error_message()) + + def get_session(self): + # type: () -> Session + sess = m2.ssl_get_session(self.ssl) + return Session(sess) + + def set_session(self, session): + # type: (Session) -> None + m2.ssl_set_session(self.ssl, session._ptr()) + + def get_default_session_timeout(self): + # type: () -> int + return m2.ssl_get_default_session_timeout(self.ssl) + + def get_socket_read_timeout(self): + # type: () -> timeout + return timeout.struct_to_timeout( + self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_RCVTIMEO, + timeout.struct_size())) + + @staticmethod + def _hexdump(s): + assert isinstance(s, six.binary_type) + return ":".join("{0:02x}".format(ord(c) if six.PY2 else c) for c in s) + + def get_socket_write_timeout(self): + # type: () -> timeout + binstr = self.socket.getsockopt( + socket.SOL_SOCKET, socket.SO_SNDTIMEO, timeout.struct_size()) + timeo = timeout.struct_to_timeout(binstr) + # print("Debug: get_socket_write_timeout: " + # "get sockopt value: %s -> ret timeout(sec=%r, microsec=%r)" % + # (self._hexdump(binstr), timeo.sec, timeo.microsec)) + return timeo + + def set_socket_read_timeout(self, timeo): + # type: (timeout) -> None + assert isinstance(timeo, timeout.timeout) + self.socket.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVTIMEO, timeo.pack()) + + def set_socket_write_timeout(self, timeo): + # type: (timeout) -> None + assert isinstance(timeo, timeout.timeout) + binstr = timeo.pack() + # print("Debug: set_socket_write_timeout: " + # "input timeout(sec=%r, microsec=%r) -> set sockopt value: %s" % + # (timeo.sec, timeo.microsec, self._hexdump(binstr))) + self.socket.setsockopt( + socket.SOL_SOCKET, socket.SO_SNDTIMEO, binstr) + + def get_version(self): + # type: () -> str + """Return the TLS/SSL protocol version for this connection.""" + return six.ensure_text(m2.ssl_get_version(self.ssl)) + + def set_post_connection_check_callback(self, postConnectionCheck): # noqa + # type: (Callable) -> None + self.postConnectionCheck = postConnectionCheck + + def set_tlsext_host_name(self, name): + # type: (bytes) -> None + """Set the requested hostname for the SNI (Server Name Indication) + extension. + """ + m2.ssl_set_tlsext_host_name(self.ssl, name) + + def set1_host(self, name): + # type: (bytes) -> None + """Set the requested hostname to check in the server certificate.""" + self.host = name diff --git a/src/M2Crypto/SSL/Context.py b/src/M2Crypto/SSL/Context.py new file mode 100644 index 0000000..d7cb7bc --- /dev/null +++ b/src/M2Crypto/SSL/Context.py @@ -0,0 +1,445 @@ +from __future__ import absolute_import + +"""SSL Context + +Copyright (c) 1999-2004 Ng Pheng Siong. All rights reserved.""" + +from M2Crypto import BIO, Err, RSA, X509, m2, util # noqa +from M2Crypto.SSL import cb # noqa +from M2Crypto.SSL.Session import Session # noqa +from weakref import WeakValueDictionary +from typing import Any, AnyStr, Callable, Optional, Union # noqa + +__all__ = ['ctxmap', 'Context', 'map'] + + +class _ctxmap(object): + singleton = None # type: Optional[_ctxmap] + + def __init__(self): + # type: () -> None + """Simple WeakReffed list. + """ + self._ctxmap = WeakValueDictionary() + + def __getitem__(self, key): + # type: (int) -> Any + return self._ctxmap[key] + + def __setitem__(self, key, value): + # type: (int, Any) -> None + self._ctxmap[key] = value + + def __delitem__(self, key): + # type: (int) -> None + del self._ctxmap[key] + + +def ctxmap(): + # type: () -> _ctxmap + if _ctxmap.singleton is None: + _ctxmap.singleton = _ctxmap() + return _ctxmap.singleton +# deprecated!!! +map = ctxmap + + +class Context(object): + + """'Context' for SSL connections.""" + + m2_ssl_ctx_free = m2.ssl_ctx_free + + def __init__(self, protocol='tls', weak_crypto=None, + post_connection_check=None): + # type: (str, Optional[int], Optional[Callable]) -> None + proto = getattr(m2, protocol + '_method', None) + if proto is None: + # default is 'sslv23' for older versions of OpenSSL + if protocol == 'tls': + proto = getattr(m2, 'sslv23_method') + else: + raise ValueError("no such protocol '%s'" % protocol) + self.ctx = m2.ssl_ctx_new(proto()) + self.allow_unknown_ca = 0 # type: Union[int, bool] + self.post_connection_check = post_connection_check + ctxmap()[int(self.ctx)] = self + m2.ssl_ctx_set_cache_size(self.ctx, 128) + if weak_crypto is None and protocol in ('sslv23', 'tls'): + self.set_options(m2.SSL_OP_ALL | m2.SSL_OP_NO_SSLv2 | + m2.SSL_OP_NO_SSLv3) + + def __del__(self): + # type: () -> None + if getattr(self, 'ctx', None): + self.m2_ssl_ctx_free(self.ctx) + + def close(self): + # type: () -> None + del ctxmap()[int(self.ctx)] + + def load_cert(self, certfile, keyfile=None, + callback=util.passphrase_callback): + # type: (AnyStr, Optional[AnyStr], Callable) -> None + """Load certificate and private key into the context. + + :param certfile: File that contains the PEM-encoded certificate. + :param keyfile: File that contains the PEM-encoded private key. + Default value of None indicates that the private key + is to be found in 'certfile'. + :param callback: Callable object to be invoked if the private key is + passphrase-protected. Default callback provides a + simple terminal-style input for the passphrase. + """ + m2.ssl_ctx_passphrase_callback(self.ctx, callback) + m2.ssl_ctx_use_cert(self.ctx, certfile) + if not keyfile: + keyfile = certfile + m2.ssl_ctx_use_privkey(self.ctx, keyfile) + if not m2.ssl_ctx_check_privkey(self.ctx): + raise ValueError('public/private key mismatch') + + def load_cert_chain(self, certchainfile, keyfile=None, + callback=util.passphrase_callback): + # type: (AnyStr, Optional[AnyStr], Callable) -> None + """Load certificate chain and private key into the context. + + :param certchainfile: File object containing the PEM-encoded + certificate chain. + :param keyfile: File object containing the PEM-encoded private + key. Default value of None indicates that the + private key is to be found in 'certchainfile'. + :param callback: Callable object to be invoked if the private key + is passphrase-protected. Default callback + provides a simple terminal-style input for the + passphrase. + """ + m2.ssl_ctx_passphrase_callback(self.ctx, callback) + m2.ssl_ctx_use_cert_chain(self.ctx, certchainfile) + if not keyfile: + keyfile = certchainfile + m2.ssl_ctx_use_privkey(self.ctx, keyfile) + if not m2.ssl_ctx_check_privkey(self.ctx): + raise ValueError('public/private key mismatch') + + def set_client_CA_list_from_file(self, cafile): + # type: (AnyStr) -> None + """Load CA certs into the context. These CA certs are sent to the + peer during *SSLv3 certificate request*. + + :param cafile: File object containing one or more PEM-encoded CA + certificates concatenated together. + """ + m2.ssl_ctx_set_client_CA_list_from_file(self.ctx, cafile) + + # Deprecated. + load_client_CA = load_client_ca = set_client_CA_list_from_file + + def load_verify_locations(self, cafile=None, capath=None): + # type: (Optional[AnyStr], Optional[AnyStr]) -> int + """Load CA certs into the context. + + These CA certs are used during verification of the peer's + certificate. + + :param cafile: File containing one or more PEM-encoded CA + certificates concatenated together. + + :param capath: Directory containing PEM-encoded CA certificates + (one certificate per file). + + :return: 0 if the operation failed because CAfile and CApath are NULL + or the processing at one of the locations specified failed. + Check the error stack to find out the reason. + + 1 The operation succeeded. + """ + if cafile is None and capath is None: + raise ValueError("cafile and capath can not both be None.") + return m2.ssl_ctx_load_verify_locations(self.ctx, cafile, capath) + + # Deprecated. + load_verify_info = load_verify_locations + + def set_session_id_ctx(self, id): + # type: (bytes) -> None + """Sets the session id for the SSL.Context w/in a session can be reused. + + :param id: Sessions are generated within a certain context. When + exporting/importing sessions with + i2d_SSL_SESSION/d2i_SSL_SESSION it would be possible, + to re-import a session generated from another context + (e.g. another application), which might lead to + malfunctions. Therefore each application must set its + own session id context sid_ctx which is used to + distinguish the contexts and is stored in exported + sessions. The sid_ctx can be any kind of binary data + with a given length, it is therefore possible to use + e.g. the name of the application and/or the hostname + and/or service name. + """ + ret = m2.ssl_ctx_set_session_id_context(self.ctx, id) + if not ret: + raise Err.SSLError(Err.get_error_code(), '') + + def set_default_verify_paths(self): + # type: () -> int + """ + Specifies that the default locations from which CA certs are + loaded should be used. + + There is one default directory and one default file. The default + CA certificates directory is called "certs" in the default + OpenSSL directory. Alternatively the SSL_CERT_DIR environment + variable can be defined to override this location. The default + CA certificates file is called "cert.pem" in the default OpenSSL + directory. Alternatively the SSL_CERT_FILE environment variable + can be defined to override this location. + + @return 0 if the operation failed. A missing default location is + still treated as a success. No error code is set. + + 1 The operation succeeded. + """ + ret = m2.ssl_ctx_set_default_verify_paths(self.ctx) + if not ret: + raise ValueError('Cannot use default SSL certificate store!') + + def set_allow_unknown_ca(self, ok): + # type: (Union[int, bool]) -> None + """Set the context to accept/reject a peer certificate if the + certificate's CA is unknown. + + :param ok: True to accept, False to reject. + """ + self.allow_unknown_ca = ok + + def get_allow_unknown_ca(self): + # type: () -> Union[int, bool] + """Get the context's setting that accepts/rejects a peer + certificate if the certificate's CA is unknown. + + FIXME 2Bconverted to bool + """ + return self.allow_unknown_ca + + def set_verify(self, mode, depth, callback=None): + # type: (int, int, Optional[Callable]) -> None + """ + Set verify options. Most applications will need to call this + method with the right options to make a secure SSL connection. + + :param mode: The verification mode to use. Typically at least + SSL.verify_peer is used. Clients would also typically + add SSL.verify_fail_if_no_peer_cert. + :param depth: The maximum allowed depth of the certificate chain + returned by the peer. + :param callback: Callable that can be used to specify custom + verification checks. + """ + if callback is None: + m2.ssl_ctx_set_verify_default(self.ctx, mode) + else: + m2.ssl_ctx_set_verify(self.ctx, mode, callback) + m2.ssl_ctx_set_verify_depth(self.ctx, depth) + + def get_verify_mode(self): + # type: () -> int + return m2.ssl_ctx_get_verify_mode(self.ctx) + + def get_verify_depth(self): + # type: () -> int + """Returns the verification mode currently set in the SSL Context.""" + return m2.ssl_ctx_get_verify_depth(self.ctx) + + def set_tmp_dh(self, dhpfile): + # type: (AnyStr) -> int + """Load ephemeral DH parameters into the context. + + :param dhpfile: Filename of the file containing the PEM-encoded + DH parameters. + """ + f = BIO.openfile(dhpfile) + dhp = m2.dh_read_parameters(f.bio_ptr()) + return m2.ssl_ctx_set_tmp_dh(self.ctx, dhp) + + def set_tmp_dh_callback(self, callback=None): + # type: (Optional[Callable]) -> None + """Sets the callback function for SSL.Context. + + :param callback: Callable to be used when a DH parameters are required. + """ + if callback is not None: + m2.ssl_ctx_set_tmp_dh_callback(self.ctx, callback) + + def set_tmp_rsa(self, rsa): + # type: (RSA.RSA) -> int + """Load ephemeral RSA key into the context. + + :param rsa: RSA.RSA instance. + """ + if isinstance(rsa, RSA.RSA): + return m2.ssl_ctx_set_tmp_rsa(self.ctx, rsa.rsa) + else: + raise TypeError("Expected an instance of RSA.RSA, got %s." % rsa) + + def set_tmp_rsa_callback(self, callback=None): + # type: (Optional[Callable]) -> None + """Sets the callback function to be used when + a temporary/ephemeral RSA key is required. + """ + if callback is not None: + m2.ssl_ctx_set_tmp_rsa_callback(self.ctx, callback) + + def set_info_callback(self, callback=cb.ssl_info_callback): + # type: (Callable) -> None + """Set a callback function to get state information. + + It can be used to get state information about the SSL + connections that are created from this context. + + :param callback: Callback function. The default prints + information to stderr. + """ + m2.ssl_ctx_set_info_callback(self.ctx, callback) + + def set_cipher_list(self, cipher_list): + # type: (str) -> int + """Sets the list of available ciphers. + + :param cipher_list: The format of the string is described in + ciphers(1). + :return: 1 if any cipher could be selected and 0 on complete + failure. + """ + return m2.ssl_ctx_set_cipher_list(self.ctx, cipher_list) + + def add_session(self, session): + # type: (Session) -> int + """Add the session to the context. + + :param session: the session to be added. + + :return: 0 The operation failed. It was tried to add the same + (identical) session twice. + + 1 The operation succeeded. + """ + return m2.ssl_ctx_add_session(self.ctx, session._ptr()) + + def remove_session(self, session): + # type: (Session) -> int + """Remove the session from the context. + + :param session: the session to be removed. + + :return: 0 The operation failed. The session was not found in + the cache. + + 1 The operation succeeded. + """ + return m2.ssl_ctx_remove_session(self.ctx, session._ptr()) + + def get_session_timeout(self): + # type: () -> int + """Get current session timeout. + + Whenever a new session is created, it is assigned a maximum + lifetime. This lifetime is specified by storing the creation + time of the session and the timeout value valid at this time. If + the actual time is later than creation time plus timeout, the + session is not reused. + + Due to this realization, all sessions behave according to the + timeout value valid at the time of the session negotiation. + Changes of the timeout value do not affect already established + sessions. + + Expired sessions are removed from the internal session cache, + whenever SSL_CTX_flush_sessions(3) is called, either directly by + the application or automatically (see + SSL_CTX_set_session_cache_mode(3)) + + The default value for session timeout is decided on a per + protocol basis, see SSL_get_default_timeout(3). All currently + supported protocols have the same default timeout value of 300 + seconds. + + SSL_CTX_set_timeout() returns the previously set timeout value. + + :return: the currently set timeout value. + """ + return m2.ssl_ctx_get_session_timeout(self.ctx) + + def set_session_timeout(self, timeout): + # type: (int) -> int + """Set new session timeout. + + See self.get_session_timeout() for explanation of the session + timeouts. + + :param timeout: new timeout value. + + :return: the previously set timeout value. + """ + return m2.ssl_ctx_set_session_timeout(self.ctx, timeout) + + def set_session_cache_mode(self, mode): + # type: (int) -> int + """Enables/disables session caching. + + The mode is set by using m2.SSL_SESS_CACHE_* constants. + + :param mode: new mode value. + + :return: the previously set cache mode value. + """ + return m2.ssl_ctx_set_session_cache_mode(self.ctx, mode) + + def get_session_cache_mode(self): + # type: () -> int + """Gets the current session caching. + + The mode is set to m2.SSL_SESS_CACHE_* constants. + + :return: the previously set cache mode value. + """ + return m2.ssl_ctx_get_session_cache_mode(self.ctx) + + def set_options(self, op): + # type: (int) -> int + """Adds the options set via bitmask in options to the Context. + + !!! Options already set before are not cleared! + + The behaviour of the SSL library can be changed by setting + several options. The options are coded as bitmasks and can be + combined by a logical or operation (|). + + SSL.Context.set_options() and SSL.set_options() affect the + (external) protocol behaviour of the SSL library. The (internal) + behaviour of the API can be changed by using the similar + SSL.Context.set_mode() and SSL.set_mode() functions. + + During a handshake, the option settings of the SSL object are + used. When a new SSL object is created from a context using + SSL(), the current option setting is copied. Changes to ctx + do not affect already created SSL objects. SSL.clear() does not + affect the settings. + + :param op: bitmask of additional options specified in + SSL_CTX_set_options(3) manpage. + + :return: the new options bitmask after adding options. + """ + return m2.ssl_ctx_set_options(self.ctx, op) + + def get_cert_store(self): + # type: () -> X509.X509 + """ + Get the certificate store associated with this context. + + :warning: The store is NOT refcounted, and as such can not be relied + to be valid once the context goes away or is changed. + """ + return X509.X509_Store(m2.ssl_ctx_get_cert_store(self.ctx)) diff --git a/src/M2Crypto/SSL/SSLServer.py b/src/M2Crypto/SSL/SSLServer.py new file mode 100644 index 0000000..a44c9c7 --- /dev/null +++ b/src/M2Crypto/SSL/SSLServer.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import, print_function + +"""SSLServer + +Copyright (c) 1999-2002 Ng Pheng Siong. All rights reserved.""" + + +# M2Crypto +from M2Crypto.SSL import SSLError +from M2Crypto.SSL.Connection import Connection +from M2Crypto.SSL.Context import Context # noqa +# from M2Crypto import six # noqa +from M2Crypto import util # noqa +from M2Crypto.six.moves.socketserver import (BaseRequestHandler, BaseServer, + TCPServer, ThreadingMixIn) +import os +if os.name != 'nt': + from M2Crypto.six.moves.socketserver import ForkingMixIn +from socket import socket # noqa +from typing import Union # noqa + +__all__ = ['SSLServer', 'ForkingSSLServer', 'ThreadingSSLServer'] + + +class SSLServer(TCPServer): + def __init__(self, server_address, RequestHandlerClass, ssl_context, # noqa + bind_and_activate=True): + # type: (util.AddrType, BaseRequestHandler, Context, bool) -> None + """ + Superclass says: Constructor. May be extended, do not override. + This class says: Ho-hum. + """ + BaseServer.__init__(self, server_address, RequestHandlerClass) + self.ssl_ctx = ssl_context + self.socket = Connection(self.ssl_ctx) + if bind_and_activate: + self.server_bind() + self.server_activate() + + def handle_request(self): + # type: () -> None + request = None + client_address = None + try: + request, client_address = self.get_request() + if self.verify_request(request, client_address): + self.process_request(request, client_address) + except SSLError: + self.handle_error(request, client_address) + + def handle_error(self, request, client_address): + # type: (Union[socket, Connection], util.AddrType) -> None + print('-' * 40) + import traceback + traceback.print_exc() + print('-' * 40) + + +class ThreadingSSLServer(ThreadingMixIn, SSLServer): + pass + + +if os.name != 'nt': + class ForkingSSLServer(ForkingMixIn, SSLServer): + pass diff --git a/src/M2Crypto/SSL/Session.py b/src/M2Crypto/SSL/Session.py new file mode 100644 index 0000000..c364d4e --- /dev/null +++ b/src/M2Crypto/SSL/Session.py @@ -0,0 +1,69 @@ +"""SSL Session + +Copyright (c) 1999-2003 Ng Pheng Siong. All rights reserved.""" + +__all__ = ['Session', 'load_session'] + +from M2Crypto import BIO, Err, m2 +from M2Crypto.SSL import SSLError +from typing import AnyStr # noqa + + +class Session(object): + + m2_ssl_session_free = m2.ssl_session_free + + def __init__(self, session, _pyfree=0): + # type: (bytes, int) -> None + assert session is not None + self.session = session + self._pyfree = _pyfree + + def __del__(self): + # type: () -> None + if getattr(self, '_pyfree', 0): + self.m2_ssl_session_free(self.session) + + def _ptr(self): + # type: () -> bytes + return self.session + + def as_text(self): + # type: () -> bytes + buf = BIO.MemoryBuffer() + m2.ssl_session_print(buf.bio_ptr(), self.session) + return buf.read_all() + + def as_der(self): + # type: () -> bytes + buf = BIO.MemoryBuffer() + m2.i2d_ssl_session(buf.bio_ptr(), self.session) + return buf.read_all() + + def write_bio(self, bio): + # type: (BIO.BIO) -> int + return m2.ssl_session_write_bio(bio.bio_ptr(), self.session) + + def get_time(self): + # type: () -> int + return m2.ssl_session_get_time(self.session) + + def set_time(self, t): + # type: (int) -> int + return m2.ssl_session_set_time(self.session, t) + + def get_timeout(self): + # type: () -> int + return m2.ssl_session_get_timeout(self.session) + + def set_timeout(self, t): + # type: (int) -> int + return m2.ssl_session_set_timeout(self.session, t) + + +def load_session(pemfile): + # type: (AnyStr) -> Session + with BIO.openfile(pemfile) as f: + cptr = m2.ssl_session_read_pem(f.bio_ptr()) + + return Session(cptr, 1) diff --git a/src/M2Crypto/SSL/TwistedProtocolWrapper.py b/src/M2Crypto/SSL/TwistedProtocolWrapper.py new file mode 100644 index 0000000..5e94503 --- /dev/null +++ b/src/M2Crypto/SSL/TwistedProtocolWrapper.py @@ -0,0 +1,490 @@ +""" +Make Twisted use M2Crypto for SSL + +Copyright (c) 2004-2007 Open Source Applications Foundation. +All rights reserved. + +FIXME THIS HAS NOT BEEN FINISHED. NEITHER PEP484 NOR PORT PYTHON3 HAS +BEEN FINISHED. THE FURTHER WORK WILL BE DONE WHEN THE STATUS OF TWISTED +IN THE PYTHON 3 (AND ASYNCIO) WORLD WILL BE CLEAR. +""" + +__all__ = ['connectSSL', 'connectTCP', 'listenSSL', 'listenTCP', + 'TLSProtocolWrapper'] + +import logging + +from functools import partial + +import twisted.internet.reactor +import twisted.protocols.policies as policies + +from M2Crypto import BIO, X509, m2, util +from M2Crypto.SSL.Checker import Checker, SSLVerificationError + +from twisted.internet.interfaces import ITLSTransport +from twisted.protocols.policies import ProtocolWrapper +from typing import AnyStr, Callable, Iterable, Optional # noqa +from zope.interface import implementer + +log = logging.getLogger(__name__) + + +def _alwaysSucceedsPostConnectionCheck(peerX509, expectedHost): + return 1 + + +def connectSSL(host, port, factory, contextFactory, timeout=30, + bindAddress=None, + reactor=twisted.internet.reactor, + postConnectionCheck=Checker()): + # type: (str, int, object, object, int, Optional[str], twisted.internet.reactor, Checker) -> reactor.connectTCP + """ + A convenience function to start an SSL/TLS connection using Twisted. + + See IReactorSSL interface in Twisted. + """ + wrappingFactory = policies.WrappingFactory(factory) + wrappingFactory.protocol = lambda factory, wrappedProtocol: \ + TLSProtocolWrapper(factory, + wrappedProtocol, + startPassThrough=0, + client=1, + contextFactory=contextFactory, + postConnectionCheck=postConnectionCheck) + return reactor.connectTCP(host, port, wrappingFactory, timeout, bindAddress) + + +def connectTCP(host, port, factory, timeout=30, bindAddress=None, + reactor=twisted.internet.reactor, + postConnectionCheck=Checker()): + # type: (str, int, object, int, Optional[util.AddrType], object, Callable) -> object + """ + A convenience function to start a TCP connection using Twisted. + + NOTE: You must call startTLS(ctx) to go into SSL/TLS mode. + + See IReactorTCP interface in Twisted. + """ + wrappingFactory = policies.WrappingFactory(factory) + wrappingFactory.protocol = lambda factory, wrappedProtocol: \ + TLSProtocolWrapper(factory, + wrappedProtocol, + startPassThrough=1, + client=1, + contextFactory=None, + postConnectionCheck=postConnectionCheck) + return reactor.connectTCP(host, port, wrappingFactory, timeout, bindAddress) + + +def listenSSL(port, factory, contextFactory, backlog=5, interface='', + reactor=twisted.internet.reactor, + postConnectionCheck=_alwaysSucceedsPostConnectionCheck): + """ + A convenience function to listen for SSL/TLS connections using Twisted. + + See IReactorSSL interface in Twisted. + """ + wrappingFactory = policies.WrappingFactory(factory) + wrappingFactory.protocol = lambda factory, wrappedProtocol: \ + TLSProtocolWrapper(factory, + wrappedProtocol, + startPassThrough=0, + client=0, + contextFactory=contextFactory, + postConnectionCheck=postConnectionCheck) + return reactor.listenTCP(port, wrappingFactory, backlog, interface) + + +def listenTCP(port, factory, backlog=5, interface='', + reactor=twisted.internet.reactor, + postConnectionCheck=None): + """ + A convenience function to listen for TCP connections using Twisted. + + NOTE: You must call startTLS(ctx) to go into SSL/TLS mode. + + See IReactorTCP interface in Twisted. + """ + wrappingFactory = policies.WrappingFactory(factory) + wrappingFactory.protocol = lambda factory, wrappedProtocol: \ + TLSProtocolWrapper(factory, + wrappedProtocol, + startPassThrough=1, + client=0, + contextFactory=None, + postConnectionCheck=postConnectionCheck) + return reactor.listenTCP(port, wrappingFactory, backlog, interface) + + +class _BioProxy(object): + """ + The purpose of this class is to eliminate the __del__ method from + TLSProtocolWrapper, and thus letting it be garbage collected. + """ + + m2_bio_free_all = m2.bio_free_all + + def __init__(self, bio): + self.bio = bio + + def _ptr(self): + return self.bio + + def __del__(self): + if self.bio is not None: + self.m2_bio_free_all(self.bio) + + +class _SSLProxy(object): + """ + The purpose of this class is to eliminate the __del__ method from + TLSProtocolWrapper, and thus letting it be garbage collected. + """ + + m2_ssl_free = m2.ssl_free + + def __init__(self, ssl): + self.ssl = ssl + + def _ptr(self): + return self.ssl + + def __del__(self): + if self.ssl is not None: + self.m2_ssl_free(self.ssl) + + +@implementer(ITLSTransport) +class TLSProtocolWrapper(ProtocolWrapper): + """ + A SSL/TLS protocol wrapper to be used with Twisted. Typically + you would not use this class directly. Use connectTCP, + connectSSL, listenTCP, listenSSL functions defined above, + which will hook in this class. + """ + + def __init__(self, factory, wrappedProtocol, startPassThrough, client, + contextFactory, postConnectionCheck): + # type: (policies.WrappingFactory, object, int, int, object, Checker) -> None + """ + :param factory: + :param wrappedProtocol: + :param startPassThrough: If true we won't encrypt at all. Need to + call startTLS() later to switch to SSL/TLS. + :param client: True if this should be a client protocol. + :param contextFactory: Factory that creates SSL.Context objects. + The called function is getContext(). + :param postConnectionCheck: The post connection check callback that + will be called just after connection has + been established but before any real data + has been exchanged. The first argument to + this function is an X509 object, the second + is the expected host name string. + """ + # ProtocolWrapper.__init__(self, factory, wrappedProtocol) + # XXX: Twisted 2.0 has a new addition where the wrappingFactory is + # set as the factory of the wrappedProtocol. This is an issue + # as the wrap should be transparent. What we want is + # the factory of the wrappedProtocol to be the wrappedFactory and + # not the outer wrappingFactory. This is how it was implemented in + # Twisted 1.3 + self.factory = factory + self.wrappedProtocol = wrappedProtocol + + # wrappedProtocol == client/server instance + # factory.wrappedFactory == client/server factory + + self.data = b'' # Clear text to encrypt and send + self.encrypted = b'' # Encrypted data we need to decrypt and pass on + self.tlsStarted = 0 # SSL/TLS mode or pass through + self.checked = 0 # Post connection check done or not + self.isClient = client + self.helloDone = 0 # True when hello has been sent + if postConnectionCheck is None: + self.postConnectionCheck = _alwaysSucceedsPostConnectionCheck + else: + self.postConnectionCheck = postConnectionCheck + + if not startPassThrough: + self.startTLS(contextFactory.getContext()) + + def clear(self): + """ + Clear this instance, after which it is ready for reuse. + """ + if getattr(self, 'tlsStarted', 0): + self.sslBio = None + self.ssl = None + self.internalBio = None + self.networkBio = None + self.data = b'' + self.encrypted = b'' + self.tlsStarted = 0 + self.checked = 0 + self.isClient = 1 + self.helloDone = 0 + # We can reuse self.ctx and it will be deleted automatically + # when this instance dies + + def startTLS(self, ctx): + """ + Start SSL/TLS. If this is not called, this instance just passes data + through untouched. + """ + # NOTE: This method signature must match the startTLS() method Twisted + # expects transports to have. This will be called automatically + # by Twisted in STARTTLS situations, for example with SMTP. + if self.tlsStarted: + raise Exception('TLS already started') + + self.ctx = ctx + + self.internalBio = m2.bio_new(m2.bio_s_bio()) + m2.bio_set_write_buf_size(self.internalBio, 0) + self.networkBio = _BioProxy(m2.bio_new(m2.bio_s_bio())) + m2.bio_set_write_buf_size(self.networkBio._ptr(), 0) + m2.bio_make_bio_pair(self.internalBio, self.networkBio._ptr()) + + self.sslBio = _BioProxy(m2.bio_new(m2.bio_f_ssl())) + + self.ssl = _SSLProxy(m2.ssl_new(self.ctx.ctx)) + + if self.isClient: + m2.ssl_set_connect_state(self.ssl._ptr()) + else: + m2.ssl_set_accept_state(self.ssl._ptr()) + + m2.ssl_set_bio(self.ssl._ptr(), self.internalBio, self.internalBio) + m2.bio_set_ssl(self.sslBio._ptr(), self.ssl._ptr(), m2.bio_noclose) + + # Need this for writes that are larger than BIO pair buffers + mode = m2.ssl_get_mode(self.ssl._ptr()) + m2.ssl_set_mode(self.ssl._ptr(), + mode | + m2.SSL_MODE_ENABLE_PARTIAL_WRITE | + m2.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER) + + self.tlsStarted = 1 + + def write(self, data): + # type: (bytes) -> None + if not self.tlsStarted: + ProtocolWrapper.write(self, data) + return + + try: + encryptedData = self._encrypt(data) + ProtocolWrapper.write(self, encryptedData) + self.helloDone = 1 + except BIO.BIOError as e: + # See http://www.openssl.org/docs/apps/verify.html#DIAGNOSTICS + # for the error codes returned by SSL_get_verify_result. + e.args = (m2.ssl_get_verify_result(self.ssl._ptr()), e.args[0]) + raise e + + def writeSequence(self, data): + # type: (Iterable[bytes]) -> None + if not self.tlsStarted: + ProtocolWrapper.writeSequence(self, b''.join(data)) + return + + self.write(b''.join(data)) + + def loseConnection(self): + # XXX Do we need to do m2.ssl_shutdown(self.ssl._ptr())? + ProtocolWrapper.loseConnection(self) + + def connectionMade(self): + ProtocolWrapper.connectionMade(self) + if self.tlsStarted and self.isClient and not self.helloDone: + self._clientHello() + + def dataReceived(self, data): + # type: (bytes) -> None + if not self.tlsStarted: + ProtocolWrapper.dataReceived(self, data) + return + + self.encrypted += data + + try: + while 1: + decryptedData = self._decrypt() + + self._check() + + encryptedData = self._encrypt() + ProtocolWrapper.write(self, encryptedData) + + ProtocolWrapper.dataReceived(self, decryptedData) + + if decryptedData == b'' and encryptedData == b'': + break + except BIO.BIOError as e: + # See http://www.openssl.org/docs/apps/verify.html#DIAGNOSTICS + # for the error codes returned by SSL_get_verify_result. + e.args = (m2.ssl_get_verify_result(self.ssl._ptr()), e.args[0]) + raise e + + def connectionLost(self, reason): + # type: (AnyStr) -> None + self.clear() + ProtocolWrapper.connectionLost(self, reason) + + def _check(self): + if not self.checked and m2.ssl_is_init_finished(self.ssl._ptr()): + x509 = m2.ssl_get_peer_cert(self.ssl._ptr()) + if x509 is not None: + x509 = X509.X509(x509, 1) + if self.isClient: + host = self.transport.addr[0] + else: + host = self.transport.getPeer().host + if not self.postConnectionCheck(x509, host): + raise SSLVerificationError('post connection check') + self.checked = 1 + + def _clientHello(self): + try: + # We rely on OpenSSL implicitly starting with client hello + # when we haven't yet established an SSL connection + encryptedData = self._encrypt(clientHello=1) + ProtocolWrapper.write(self, encryptedData) + self.helloDone = 1 + except BIO.BIOError as e: + # See http://www.openssl.org/docs/apps/verify.html#DIAGNOSTICS + # for the error codes returned by SSL_get_verify_result. + e.args = (m2.ssl_get_verify_result(self.ssl._ptr()), e.args[0]) + raise e + + # Optimizations to reduce attribute accesses + + @property + def _get_wr_guar_ssl(self): + # type: () -> Callable[[], int] + """Return max. length of data can be written to the BIO. + + Writes larger than this value will return a value from + BIO_write() less than the amount requested or if the buffer is + full request a retry. + """ + return partial(m2.bio_ctrl_get_write_guarantee, + self.sslBio._ptr()) + + @property + def _get_wr_guar_net(self): + # type: () -> Callable[[], int] + return partial(m2.bio_ctrl_get_write_guarantee, + self.networkBio._ptr()) + + @property + def _shoud_retry_ssl(self): + # type: () -> Callable[[], int] + # BIO_should_retry() is true if the call that produced this + # condition should then be retried at a later time. + return partial(m2.bio_should_retry, self.sslBio._ptr()) + + @property + def _shoud_retry_net(self): + # type: () -> Callable[[], int] + return partial(m2.bio_should_retry, self.networkBio._ptr()) + + @property + def _ctrl_pend_ssl(self): + # type: () -> Callable[[], int] + # size_t BIO_ctrl_pending(BIO *b); + # BIO_ctrl_pending() return the number of pending characters in + # the BIOs read and write buffers. + return partial(m2.bio_ctrl_pending, self.sslBio._ptr()) + + @property + def _ctrl_pend_net(self): + # type: () -> Callable[[], int] + return partial(m2.bio_ctrl_pending, self.networkBio._ptr()) + + @property + def _write_ssl(self): + # type: () -> Callable[[bytes], int] + # All these functions return either the amount of data + # successfully read or written (if the return value is + # positive) or that no data was successfully read or written + # if the result is 0 or -1. If the return value is -2 then + # the operation is not implemented in the specific BIO type. + return partial(m2.bio_write, self.sslBio._ptr()) + + @property + def _write_net(self): + # type: () -> Callable[[bytes], int] + return partial(m2.bio_write, self.networkBio._ptr()) + + @property + def _read_ssl(self): + # type: () -> Callable[[int], Optional[bytes]] + return partial(m2.bio_read, self.sslBio._ptr()) + + @property + def _read_net(self): + # type: () -> Callable[[int], Optional[bytes]] + return partial(m2.bio_read, self.networkBio._ptr()) + + def _encrypt(self, data=b'', clientHello=0): + # type: (bytes, int) -> bytes + """ + :param data: + :param clientHello: + :return: + """ + encryptedData = b'' + self.data += data + + while 1: + if (self._get_wr_guar_ssl() > 0 and self.data != b'') or clientHello: + r = self._write_ssl(self.data) + if r <= 0: + if not self._shoud_retry_ssl(): + raise IOError( + ('Data left to be written to {}, ' + + 'but cannot retry SSL connection!').format(self.sslBio)) + else: + assert self.checked + self.data = self.data[r:] + + pending = self._ctrl_pend_net() + if pending: + d = self._read_net(pending) + if d is not None: # This is strange, but d can be None + encryptedData += d + else: + assert(self._shoud_retry_net()) + else: + break + return encryptedData + + def _decrypt(self, data=b''): + # type: (bytes) -> bytes + self.encrypted += data + decryptedData = b'' + + while 1: + if self._get_wr_guar_ssl() > 0 and self.encrypted != b'': + r = self._write_net(self.encrypted) + if r <= 0: + if not self._shoud_retry_net(): + raise IOError( + ('Data left to be written to {}, ' + + 'but cannot retry SSL connection!').format(self.networkBio)) + else: + self.encrypted = self.encrypted[r:] + + pending = self._ctrl_pend_ssl() + if pending: + d = self._read_ssl(pending) + if d is not None: # This is strange, but d can be None + decryptedData += d + else: + assert(self._shoud_retry_ssl()) + else: + break + + return decryptedData diff --git a/src/M2Crypto/SSL/__init__.py b/src/M2Crypto/SSL/__init__.py new file mode 100644 index 0000000..33bf1b9 --- /dev/null +++ b/src/M2Crypto/SSL/__init__.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import + +"""M2Crypto SSL services. + +Copyright (c) 1999-2004 Ng Pheng Siong. All rights reserved.""" + +import socket, os + +# M2Crypto +from M2Crypto import _m2crypto as m2 + + +class SSLError(Exception): + pass + + +class SSLTimeoutError(SSLError, socket.timeout): + pass + +m2.ssl_init(SSLError, SSLTimeoutError) + +# M2Crypto.SSL +from M2Crypto.SSL.Cipher import Cipher, Cipher_Stack +from M2Crypto.SSL.Connection import Connection +from M2Crypto.SSL.Context import Context +from M2Crypto.SSL.SSLServer import SSLServer, ThreadingSSLServer +if os.name != 'nt': + from M2Crypto.SSL.SSLServer import ForkingSSLServer +from M2Crypto.SSL.ssl_dispatcher import ssl_dispatcher +from M2Crypto.SSL.timeout import timeout, struct_to_timeout, struct_size + +verify_none = m2.SSL_VERIFY_NONE # type: int +verify_peer = m2.SSL_VERIFY_PEER # type: int +verify_fail_if_no_peer_cert = m2.SSL_VERIFY_FAIL_IF_NO_PEER_CERT # type: int +verify_client_once = m2.SSL_VERIFY_CLIENT_ONCE # type: int +verify_crl_check_chain = m2.VERIFY_CRL_CHECK_CHAIN # type: int +verify_crl_check_leaf = m2.VERIFY_CRL_CHECK_LEAF # type: int + +SSL_SENT_SHUTDOWN = m2.SSL_SENT_SHUTDOWN # type: int +SSL_RECEIVED_SHUTDOWN = m2.SSL_RECEIVED_SHUTDOWN # type: int + +op_all = m2.SSL_OP_ALL # type: int +op_no_sslv2 = m2.SSL_OP_NO_SSLv2 # type: int + diff --git a/src/M2Crypto/SSL/cb.py b/src/M2Crypto/SSL/cb.py new file mode 100644 index 0000000..47f102e --- /dev/null +++ b/src/M2Crypto/SSL/cb.py @@ -0,0 +1,96 @@ +from __future__ import absolute_import + +"""SSL callbacks + +Copyright (c) 1999-2003 Ng Pheng Siong. All rights reserved.""" + +import sys + +from M2Crypto import m2 +from typing import Any # noqa + +__all__ = ['unknown_issuer', 'ssl_verify_callback_stub', 'ssl_verify_callback', + 'ssl_verify_callback_allow_unknown_ca', 'ssl_info_callback'] + + +def ssl_verify_callback_stub(ssl_ctx_ptr, x509_ptr, errnum, errdepth, ok): + # Deprecated + return ok + + +unknown_issuer = [ + m2.X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT, + m2.X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY, + m2.X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE, + m2.X509_V_ERR_CERT_UNTRUSTED, +] + + +def ssl_verify_callback(ssl_ctx_ptr, x509_ptr, errnum, errdepth, ok): + # type: (bytes, bytes, int, int, int) -> int + # Deprecated + + from M2Crypto.SSL.Context import Context + ssl_ctx = Context.ctxmap()[int(ssl_ctx_ptr)] + if errnum in unknown_issuer: + if ssl_ctx.get_allow_unknown_ca(): + sys.stderr.write("policy: %s: permitted...\n" % + (m2.x509_get_verify_error(errnum))) + sys.stderr.flush() + ok = 1 + # CRL checking goes here... + if ok: + if ssl_ctx.get_verify_depth() >= errdepth: + ok = 1 + else: + ok = 0 + return ok + + +def ssl_verify_callback_allow_unknown_ca(ok, store): + # type: (int, Any) -> int + errnum = store.get_error() + if errnum in unknown_issuer: + ok = 1 + return ok + + +# Cribbed from OpenSSL's apps/s_cb.c. +def ssl_info_callback(where, ret, ssl_ptr): + # type: (int, int, bytes) -> None + + w = where & ~m2.SSL_ST_MASK + if w & m2.SSL_ST_CONNECT: + state = "SSL connect" + elif w & m2.SSL_ST_ACCEPT: + state = "SSL accept" + else: + state = "SSL state unknown" + + if where & m2.SSL_CB_LOOP: + sys.stderr.write("LOOP: %s: %s\n" % + (state, m2.ssl_get_state_v(ssl_ptr))) + sys.stderr.flush() + return + + if where & m2.SSL_CB_EXIT: + if not ret: + sys.stderr.write("FAILED: %s: %s\n" % + (state, m2.ssl_get_state_v(ssl_ptr))) + sys.stderr.flush() + else: + sys.stderr.write("INFO: %s: %s\n" % + (state, m2.ssl_get_state_v(ssl_ptr))) + sys.stderr.flush() + return + + if where & m2.SSL_CB_ALERT: + if where & m2.SSL_CB_READ: + w = 'read' + else: + w = 'write' + sys.stderr.write("ALERT: %s: %s: %s\n" % + (w, m2.ssl_get_alert_type_v(ret), + m2.ssl_get_alert_desc_v(ret))) + sys.stderr.flush() + return diff --git a/src/M2Crypto/SSL/ssl_dispatcher.py b/src/M2Crypto/SSL/ssl_dispatcher.py new file mode 100644 index 0000000..73a4b82 --- /dev/null +++ b/src/M2Crypto/SSL/ssl_dispatcher.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +"""SSL dispatcher + +Copyright (c) 1999-2002 Ng Pheng Siong. All rights reserved.""" + +# Python +import asyncore +import socket + +# M2Crypto +from M2Crypto import util # noqa +from M2Crypto.SSL.Connection import Connection +from M2Crypto.SSL.Context import Context # noqa + +__all__ = ['ssl_dispatcher'] + + +class ssl_dispatcher(asyncore.dispatcher): + + def create_socket(self, ssl_context): + # type: (Context) -> None + self.family_and_type = socket.AF_INET, socket.SOCK_STREAM + self.ssl_ctx = ssl_context + self.socket = Connection(self.ssl_ctx) + # self.socket.setblocking(0) + self.add_channel() + + def connect(self, addr): + # type: (util.AddrType) -> None + self.socket.setblocking(1) + self.socket.connect(addr) + self.socket.setblocking(0) + + def recv(self, buffer_size=4096): + # type: (int) -> bytes + """Receive data over SSL.""" + return self.socket.recv(buffer_size) + + def send(self, buffer): + # type: (bytes) -> int + """Send data over SSL.""" + return self.socket.send(buffer) diff --git a/src/M2Crypto/SSL/timeout.py b/src/M2Crypto/SSL/timeout.py new file mode 100644 index 0000000..42b0291 --- /dev/null +++ b/src/M2Crypto/SSL/timeout.py @@ -0,0 +1,50 @@ +"""Support for SSL socket timeouts. + +Copyright (c) 1999-2003 Ng Pheng Siong. All rights reserved. + +Copyright 2008 Heikki Toivonen. All rights reserved. +""" + +__all__ = ['DEFAULT_TIMEOUT', 'timeout', 'struct_to_timeout', 'struct_size'] + +import sys +import struct + +DEFAULT_TIMEOUT = 600 # type: int + + +class timeout(object): + + def __init__(self, sec=DEFAULT_TIMEOUT, microsec=0): + # type: (int, int) -> None + self.sec = sec + self.microsec = microsec + + def pack(self): + if sys.platform == 'win32': + millisec = int(self.sec * 1000 + round(float(self.microsec) / 1000)) + binstr = struct.pack('l', millisec) + else: + binstr = struct.pack('ll', self.sec, self.microsec) + return binstr + + +def struct_to_timeout(binstr): + # type: (bytes) -> timeout + if sys.platform == 'win32': + millisec = struct.unpack('l', binstr)[0] + # On py3, int/int performs exact division and returns float. We want + # the whole number portion of the exact division result: + sec = int(millisec / 1000) + microsec = (millisec % 1000) * 1000 + else: + (sec, microsec) = struct.unpack('ll', binstr) + return timeout(sec, microsec) + + +def struct_size(): + # type: () -> int + if sys.platform == 'win32': + return struct.calcsize('l') + else: + return struct.calcsize('ll') |