summaryrefslogtreecommitdiff
path: root/src/M2Crypto/SSL/TwistedProtocolWrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/M2Crypto/SSL/TwistedProtocolWrapper.py')
-rw-r--r--src/M2Crypto/SSL/TwistedProtocolWrapper.py490
1 files changed, 490 insertions, 0 deletions
diff --git a/src/M2Crypto/SSL/TwistedProtocolWrapper.py b/src/M2Crypto/SSL/TwistedProtocolWrapper.py
new file mode 100644
index 0000000..5e94503
--- /dev/null
+++ b/src/M2Crypto/SSL/TwistedProtocolWrapper.py
@@ -0,0 +1,490 @@
+"""
+Make Twisted use M2Crypto for SSL
+
+Copyright (c) 2004-2007 Open Source Applications Foundation.
+All rights reserved.
+
+FIXME THIS HAS NOT BEEN FINISHED. NEITHER PEP484 NOR PORT PYTHON3 HAS
+BEEN FINISHED. THE FURTHER WORK WILL BE DONE WHEN THE STATUS OF TWISTED
+IN THE PYTHON 3 (AND ASYNCIO) WORLD WILL BE CLEAR.
+"""
+
+__all__ = ['connectSSL', 'connectTCP', 'listenSSL', 'listenTCP',
+ 'TLSProtocolWrapper']
+
+import logging
+
+from functools import partial
+
+import twisted.internet.reactor
+import twisted.protocols.policies as policies
+
+from M2Crypto import BIO, X509, m2, util
+from M2Crypto.SSL.Checker import Checker, SSLVerificationError
+
+from twisted.internet.interfaces import ITLSTransport
+from twisted.protocols.policies import ProtocolWrapper
+from typing import AnyStr, Callable, Iterable, Optional # noqa
+from zope.interface import implementer
+
+log = logging.getLogger(__name__)
+
+
+def _alwaysSucceedsPostConnectionCheck(peerX509, expectedHost):
+ return 1
+
+
+def connectSSL(host, port, factory, contextFactory, timeout=30,
+ bindAddress=None,
+ reactor=twisted.internet.reactor,
+ postConnectionCheck=Checker()):
+ # type: (str, int, object, object, int, Optional[str], twisted.internet.reactor, Checker) -> reactor.connectTCP
+ """
+ A convenience function to start an SSL/TLS connection using Twisted.
+
+ See IReactorSSL interface in Twisted.
+ """
+ wrappingFactory = policies.WrappingFactory(factory)
+ wrappingFactory.protocol = lambda factory, wrappedProtocol: \
+ TLSProtocolWrapper(factory,
+ wrappedProtocol,
+ startPassThrough=0,
+ client=1,
+ contextFactory=contextFactory,
+ postConnectionCheck=postConnectionCheck)
+ return reactor.connectTCP(host, port, wrappingFactory, timeout, bindAddress)
+
+
+def connectTCP(host, port, factory, timeout=30, bindAddress=None,
+ reactor=twisted.internet.reactor,
+ postConnectionCheck=Checker()):
+ # type: (str, int, object, int, Optional[util.AddrType], object, Callable) -> object
+ """
+ A convenience function to start a TCP connection using Twisted.
+
+ NOTE: You must call startTLS(ctx) to go into SSL/TLS mode.
+
+ See IReactorTCP interface in Twisted.
+ """
+ wrappingFactory = policies.WrappingFactory(factory)
+ wrappingFactory.protocol = lambda factory, wrappedProtocol: \
+ TLSProtocolWrapper(factory,
+ wrappedProtocol,
+ startPassThrough=1,
+ client=1,
+ contextFactory=None,
+ postConnectionCheck=postConnectionCheck)
+ return reactor.connectTCP(host, port, wrappingFactory, timeout, bindAddress)
+
+
+def listenSSL(port, factory, contextFactory, backlog=5, interface='',
+ reactor=twisted.internet.reactor,
+ postConnectionCheck=_alwaysSucceedsPostConnectionCheck):
+ """
+ A convenience function to listen for SSL/TLS connections using Twisted.
+
+ See IReactorSSL interface in Twisted.
+ """
+ wrappingFactory = policies.WrappingFactory(factory)
+ wrappingFactory.protocol = lambda factory, wrappedProtocol: \
+ TLSProtocolWrapper(factory,
+ wrappedProtocol,
+ startPassThrough=0,
+ client=0,
+ contextFactory=contextFactory,
+ postConnectionCheck=postConnectionCheck)
+ return reactor.listenTCP(port, wrappingFactory, backlog, interface)
+
+
+def listenTCP(port, factory, backlog=5, interface='',
+ reactor=twisted.internet.reactor,
+ postConnectionCheck=None):
+ """
+ A convenience function to listen for TCP connections using Twisted.
+
+ NOTE: You must call startTLS(ctx) to go into SSL/TLS mode.
+
+ See IReactorTCP interface in Twisted.
+ """
+ wrappingFactory = policies.WrappingFactory(factory)
+ wrappingFactory.protocol = lambda factory, wrappedProtocol: \
+ TLSProtocolWrapper(factory,
+ wrappedProtocol,
+ startPassThrough=1,
+ client=0,
+ contextFactory=None,
+ postConnectionCheck=postConnectionCheck)
+ return reactor.listenTCP(port, wrappingFactory, backlog, interface)
+
+
+class _BioProxy(object):
+ """
+ The purpose of this class is to eliminate the __del__ method from
+ TLSProtocolWrapper, and thus letting it be garbage collected.
+ """
+
+ m2_bio_free_all = m2.bio_free_all
+
+ def __init__(self, bio):
+ self.bio = bio
+
+ def _ptr(self):
+ return self.bio
+
+ def __del__(self):
+ if self.bio is not None:
+ self.m2_bio_free_all(self.bio)
+
+
+class _SSLProxy(object):
+ """
+ The purpose of this class is to eliminate the __del__ method from
+ TLSProtocolWrapper, and thus letting it be garbage collected.
+ """
+
+ m2_ssl_free = m2.ssl_free
+
+ def __init__(self, ssl):
+ self.ssl = ssl
+
+ def _ptr(self):
+ return self.ssl
+
+ def __del__(self):
+ if self.ssl is not None:
+ self.m2_ssl_free(self.ssl)
+
+
+@implementer(ITLSTransport)
+class TLSProtocolWrapper(ProtocolWrapper):
+ """
+ A SSL/TLS protocol wrapper to be used with Twisted. Typically
+ you would not use this class directly. Use connectTCP,
+ connectSSL, listenTCP, listenSSL functions defined above,
+ which will hook in this class.
+ """
+
+ def __init__(self, factory, wrappedProtocol, startPassThrough, client,
+ contextFactory, postConnectionCheck):
+ # type: (policies.WrappingFactory, object, int, int, object, Checker) -> None
+ """
+ :param factory:
+ :param wrappedProtocol:
+ :param startPassThrough: If true we won't encrypt at all. Need to
+ call startTLS() later to switch to SSL/TLS.
+ :param client: True if this should be a client protocol.
+ :param contextFactory: Factory that creates SSL.Context objects.
+ The called function is getContext().
+ :param postConnectionCheck: The post connection check callback that
+ will be called just after connection has
+ been established but before any real data
+ has been exchanged. The first argument to
+ this function is an X509 object, the second
+ is the expected host name string.
+ """
+ # ProtocolWrapper.__init__(self, factory, wrappedProtocol)
+ # XXX: Twisted 2.0 has a new addition where the wrappingFactory is
+ # set as the factory of the wrappedProtocol. This is an issue
+ # as the wrap should be transparent. What we want is
+ # the factory of the wrappedProtocol to be the wrappedFactory and
+ # not the outer wrappingFactory. This is how it was implemented in
+ # Twisted 1.3
+ self.factory = factory
+ self.wrappedProtocol = wrappedProtocol
+
+ # wrappedProtocol == client/server instance
+ # factory.wrappedFactory == client/server factory
+
+ self.data = b'' # Clear text to encrypt and send
+ self.encrypted = b'' # Encrypted data we need to decrypt and pass on
+ self.tlsStarted = 0 # SSL/TLS mode or pass through
+ self.checked = 0 # Post connection check done or not
+ self.isClient = client
+ self.helloDone = 0 # True when hello has been sent
+ if postConnectionCheck is None:
+ self.postConnectionCheck = _alwaysSucceedsPostConnectionCheck
+ else:
+ self.postConnectionCheck = postConnectionCheck
+
+ if not startPassThrough:
+ self.startTLS(contextFactory.getContext())
+
+ def clear(self):
+ """
+ Clear this instance, after which it is ready for reuse.
+ """
+ if getattr(self, 'tlsStarted', 0):
+ self.sslBio = None
+ self.ssl = None
+ self.internalBio = None
+ self.networkBio = None
+ self.data = b''
+ self.encrypted = b''
+ self.tlsStarted = 0
+ self.checked = 0
+ self.isClient = 1
+ self.helloDone = 0
+ # We can reuse self.ctx and it will be deleted automatically
+ # when this instance dies
+
+ def startTLS(self, ctx):
+ """
+ Start SSL/TLS. If this is not called, this instance just passes data
+ through untouched.
+ """
+ # NOTE: This method signature must match the startTLS() method Twisted
+ # expects transports to have. This will be called automatically
+ # by Twisted in STARTTLS situations, for example with SMTP.
+ if self.tlsStarted:
+ raise Exception('TLS already started')
+
+ self.ctx = ctx
+
+ self.internalBio = m2.bio_new(m2.bio_s_bio())
+ m2.bio_set_write_buf_size(self.internalBio, 0)
+ self.networkBio = _BioProxy(m2.bio_new(m2.bio_s_bio()))
+ m2.bio_set_write_buf_size(self.networkBio._ptr(), 0)
+ m2.bio_make_bio_pair(self.internalBio, self.networkBio._ptr())
+
+ self.sslBio = _BioProxy(m2.bio_new(m2.bio_f_ssl()))
+
+ self.ssl = _SSLProxy(m2.ssl_new(self.ctx.ctx))
+
+ if self.isClient:
+ m2.ssl_set_connect_state(self.ssl._ptr())
+ else:
+ m2.ssl_set_accept_state(self.ssl._ptr())
+
+ m2.ssl_set_bio(self.ssl._ptr(), self.internalBio, self.internalBio)
+ m2.bio_set_ssl(self.sslBio._ptr(), self.ssl._ptr(), m2.bio_noclose)
+
+ # Need this for writes that are larger than BIO pair buffers
+ mode = m2.ssl_get_mode(self.ssl._ptr())
+ m2.ssl_set_mode(self.ssl._ptr(),
+ mode |
+ m2.SSL_MODE_ENABLE_PARTIAL_WRITE |
+ m2.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER)
+
+ self.tlsStarted = 1
+
+ def write(self, data):
+ # type: (bytes) -> None
+ if not self.tlsStarted:
+ ProtocolWrapper.write(self, data)
+ return
+
+ try:
+ encryptedData = self._encrypt(data)
+ ProtocolWrapper.write(self, encryptedData)
+ self.helloDone = 1
+ except BIO.BIOError as e:
+ # See http://www.openssl.org/docs/apps/verify.html#DIAGNOSTICS
+ # for the error codes returned by SSL_get_verify_result.
+ e.args = (m2.ssl_get_verify_result(self.ssl._ptr()), e.args[0])
+ raise e
+
+ def writeSequence(self, data):
+ # type: (Iterable[bytes]) -> None
+ if not self.tlsStarted:
+ ProtocolWrapper.writeSequence(self, b''.join(data))
+ return
+
+ self.write(b''.join(data))
+
+ def loseConnection(self):
+ # XXX Do we need to do m2.ssl_shutdown(self.ssl._ptr())?
+ ProtocolWrapper.loseConnection(self)
+
+ def connectionMade(self):
+ ProtocolWrapper.connectionMade(self)
+ if self.tlsStarted and self.isClient and not self.helloDone:
+ self._clientHello()
+
+ def dataReceived(self, data):
+ # type: (bytes) -> None
+ if not self.tlsStarted:
+ ProtocolWrapper.dataReceived(self, data)
+ return
+
+ self.encrypted += data
+
+ try:
+ while 1:
+ decryptedData = self._decrypt()
+
+ self._check()
+
+ encryptedData = self._encrypt()
+ ProtocolWrapper.write(self, encryptedData)
+
+ ProtocolWrapper.dataReceived(self, decryptedData)
+
+ if decryptedData == b'' and encryptedData == b'':
+ break
+ except BIO.BIOError as e:
+ # See http://www.openssl.org/docs/apps/verify.html#DIAGNOSTICS
+ # for the error codes returned by SSL_get_verify_result.
+ e.args = (m2.ssl_get_verify_result(self.ssl._ptr()), e.args[0])
+ raise e
+
+ def connectionLost(self, reason):
+ # type: (AnyStr) -> None
+ self.clear()
+ ProtocolWrapper.connectionLost(self, reason)
+
+ def _check(self):
+ if not self.checked and m2.ssl_is_init_finished(self.ssl._ptr()):
+ x509 = m2.ssl_get_peer_cert(self.ssl._ptr())
+ if x509 is not None:
+ x509 = X509.X509(x509, 1)
+ if self.isClient:
+ host = self.transport.addr[0]
+ else:
+ host = self.transport.getPeer().host
+ if not self.postConnectionCheck(x509, host):
+ raise SSLVerificationError('post connection check')
+ self.checked = 1
+
+ def _clientHello(self):
+ try:
+ # We rely on OpenSSL implicitly starting with client hello
+ # when we haven't yet established an SSL connection
+ encryptedData = self._encrypt(clientHello=1)
+ ProtocolWrapper.write(self, encryptedData)
+ self.helloDone = 1
+ except BIO.BIOError as e:
+ # See http://www.openssl.org/docs/apps/verify.html#DIAGNOSTICS
+ # for the error codes returned by SSL_get_verify_result.
+ e.args = (m2.ssl_get_verify_result(self.ssl._ptr()), e.args[0])
+ raise e
+
+ # Optimizations to reduce attribute accesses
+
+ @property
+ def _get_wr_guar_ssl(self):
+ # type: () -> Callable[[], int]
+ """Return max. length of data can be written to the BIO.
+
+ Writes larger than this value will return a value from
+ BIO_write() less than the amount requested or if the buffer is
+ full request a retry.
+ """
+ return partial(m2.bio_ctrl_get_write_guarantee,
+ self.sslBio._ptr())
+
+ @property
+ def _get_wr_guar_net(self):
+ # type: () -> Callable[[], int]
+ return partial(m2.bio_ctrl_get_write_guarantee,
+ self.networkBio._ptr())
+
+ @property
+ def _shoud_retry_ssl(self):
+ # type: () -> Callable[[], int]
+ # BIO_should_retry() is true if the call that produced this
+ # condition should then be retried at a later time.
+ return partial(m2.bio_should_retry, self.sslBio._ptr())
+
+ @property
+ def _shoud_retry_net(self):
+ # type: () -> Callable[[], int]
+ return partial(m2.bio_should_retry, self.networkBio._ptr())
+
+ @property
+ def _ctrl_pend_ssl(self):
+ # type: () -> Callable[[], int]
+ # size_t BIO_ctrl_pending(BIO *b);
+ # BIO_ctrl_pending() return the number of pending characters in
+ # the BIOs read and write buffers.
+ return partial(m2.bio_ctrl_pending, self.sslBio._ptr())
+
+ @property
+ def _ctrl_pend_net(self):
+ # type: () -> Callable[[], int]
+ return partial(m2.bio_ctrl_pending, self.networkBio._ptr())
+
+ @property
+ def _write_ssl(self):
+ # type: () -> Callable[[bytes], int]
+ # All these functions return either the amount of data
+ # successfully read or written (if the return value is
+ # positive) or that no data was successfully read or written
+ # if the result is 0 or -1. If the return value is -2 then
+ # the operation is not implemented in the specific BIO type.
+ return partial(m2.bio_write, self.sslBio._ptr())
+
+ @property
+ def _write_net(self):
+ # type: () -> Callable[[bytes], int]
+ return partial(m2.bio_write, self.networkBio._ptr())
+
+ @property
+ def _read_ssl(self):
+ # type: () -> Callable[[int], Optional[bytes]]
+ return partial(m2.bio_read, self.sslBio._ptr())
+
+ @property
+ def _read_net(self):
+ # type: () -> Callable[[int], Optional[bytes]]
+ return partial(m2.bio_read, self.networkBio._ptr())
+
+ def _encrypt(self, data=b'', clientHello=0):
+ # type: (bytes, int) -> bytes
+ """
+ :param data:
+ :param clientHello:
+ :return:
+ """
+ encryptedData = b''
+ self.data += data
+
+ while 1:
+ if (self._get_wr_guar_ssl() > 0 and self.data != b'') or clientHello:
+ r = self._write_ssl(self.data)
+ if r <= 0:
+ if not self._shoud_retry_ssl():
+ raise IOError(
+ ('Data left to be written to {}, ' +
+ 'but cannot retry SSL connection!').format(self.sslBio))
+ else:
+ assert self.checked
+ self.data = self.data[r:]
+
+ pending = self._ctrl_pend_net()
+ if pending:
+ d = self._read_net(pending)
+ if d is not None: # This is strange, but d can be None
+ encryptedData += d
+ else:
+ assert(self._shoud_retry_net())
+ else:
+ break
+ return encryptedData
+
+ def _decrypt(self, data=b''):
+ # type: (bytes) -> bytes
+ self.encrypted += data
+ decryptedData = b''
+
+ while 1:
+ if self._get_wr_guar_ssl() > 0 and self.encrypted != b'':
+ r = self._write_net(self.encrypted)
+ if r <= 0:
+ if not self._shoud_retry_net():
+ raise IOError(
+ ('Data left to be written to {}, ' +
+ 'but cannot retry SSL connection!').format(self.networkBio))
+ else:
+ self.encrypted = self.encrypted[r:]
+
+ pending = self._ctrl_pend_ssl()
+ if pending:
+ d = self._read_ssl(pending)
+ if d is not None: # This is strange, but d can be None
+ decryptedData += d
+ else:
+ assert(self._shoud_retry_ssl())
+ else:
+ break
+
+ return decryptedData