From 0ad6ee95e002f41dd628d4044f901468f43ffc32 Mon Sep 17 00:00:00 2001 From: Martin Haimberger Date: Fri, 13 Nov 2015 03:18:50 -0800 Subject: THRIFT-3420 C++: TSSLSockets are not interruptable Client: C++ Patch: Martin Haimberger This closes #690 --- lib/cpp/src/thrift/transport/TSSLSocket.cpp | 259 ++++++++++++++++++++++++++-- 1 file changed, 242 insertions(+), 17 deletions(-) (limited to 'lib/cpp/src/thrift/transport/TSSLSocket.cpp') diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp index 98c532676..6e9a4de0f 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp @@ -28,6 +28,14 @@ #ifdef HAVE_SYS_SOCKET_H #include #endif +#ifdef HAVE_SYS_POLL_H +#include +#endif +#ifdef HAVE_FCNTL_H +#include +#endif + + #include #include #include @@ -189,14 +197,28 @@ TSSLSocket::TSSLSocket(boost::shared_ptr ctx) : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) { } +TSSLSocket::TSSLSocket(boost::shared_ptr ctx, boost::shared_ptr interruptListener) + : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) { + interruptListener_ = interruptListener; +} + TSSLSocket::TSSLSocket(boost::shared_ptr ctx, THRIFT_SOCKET socket) : TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) { } +TSSLSocket::TSSLSocket(boost::shared_ptr ctx, THRIFT_SOCKET socket, boost::shared_ptr interruptListener) + : TSocket(socket, interruptListener), server_(false), ssl_(NULL), ctx_(ctx) { +} + TSSLSocket::TSSLSocket(boost::shared_ptr ctx, string host, int port) : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) { } +TSSLSocket::TSSLSocket(boost::shared_ptr ctx, string host, int port, boost::shared_ptr interruptListener) + : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) { + interruptListener_ = interruptListener; +} + TSSLSocket::~TSSLSocket() { close(); } @@ -222,16 +244,32 @@ bool TSSLSocket::peek() { checkHandshake(); int rc; uint8_t byte; - rc = SSL_peek(ssl_, &byte, 1); - if (rc < 0) { - int errno_copy = THRIFT_GET_SOCKET_ERROR; - string errors; - buildErrors(errors, errno_copy); - throw TSSLException("SSL_peek: " + errors); - } - if (rc == 0) { - ERR_clear_error(); - } + do { + rc = SSL_peek(ssl_, &byte, 1); + if (rc < 0) { + + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + continue; + default:;// do nothing + } + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_peek: " + errors); + } else if (rc == 0) { + ERR_clear_error(); + break; + } + } while (true); return (rc > 0); } @@ -244,7 +282,28 @@ void TSSLSocket::open() { void TSSLSocket::close() { if (ssl_ != NULL) { - int rc = SSL_shutdown(ssl_); + int rc; + + do { + rc = SSL_shutdown(ssl_); + if (rc <= 0) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + rc = 2; + default:;// do nothing + } + } + } while (rc == 2); + if (rc < 0) { int errno_copy = THRIFT_GET_SOCKET_ERROR; string errors; @@ -262,14 +321,36 @@ uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) { checkHandshake(); int32_t bytes = 0; for (int32_t retries = 0; retries < maxRecvRetries_; retries++) { + ERR_clear_error(); bytes = SSL_read(ssl_, buf, len); if (bytes >= 0) break; - int errno_copy = THRIFT_GET_SOCKET_ERROR; - if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) { - if (ERR_get_error() == 0 && errno_copy == THRIFT_EINTR) { + int32_t errno_copy = THRIFT_GET_SOCKET_ERROR; + int32_t error = SSL_get_error(ssl_, bytes); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + if (retries++ >= maxRecvRetries_) { + // THRIFT_EINTR needs to be handled manually and we can tolerate + // a certain number + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + if (waitForEvent(error == SSL_ERROR_WANT_READ) == TSSL_EINTR ) { + // repeat operation + if (retries++ < maxRecvRetries_) { + // THRIFT_EINTR needs to be handled manually and we can tolerate + // a certain number + continue; + } + throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries"); + } continue; - } + default:;// do nothing } string errors; buildErrors(errors, errno_copy); @@ -283,9 +364,23 @@ void TSSLSocket::write(const uint8_t* buf, uint32_t len) { // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX. uint32_t written = 0; while (written < len) { + ERR_clear_error(); int32_t bytes = SSL_write(ssl_, &buf[written], len - written); if (bytes <= 0) { int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, bytes); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + continue; + default:;// do nothing + } string errors; buildErrors(errors, errno_copy); throw TSSLException("SSL_write: " + errors); @@ -319,13 +414,76 @@ void TSSLSocket::checkHandshake() { if (ssl_ != NULL) { return; } + + // set underlying socket to non-blocking + int flags; + if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0 + || THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) { + GlobalOutput.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ", + THRIFT_GET_SOCKET_ERROR); + ::THRIFT_CLOSESOCKET(socket_); + return; + } + ssl_ = ctx_->createSSL(); + + //set read and write bios to non-blocking + BIO* wbio = BIO_new(BIO_s_mem()); + if (wbio == NULL) { + throw TSSLException("SSL_get_wbio returns NULL"); + } + BIO_set_nbio(wbio, 1); + + BIO* rbio = BIO_new(BIO_s_mem()); + if (rbio == NULL) { + throw TSSLException("SSL_get_rbio returns NULL"); + } + BIO_set_nbio(rbio, 1); + + SSL_set_bio(ssl_, rbio, wbio); + SSL_set_fd(ssl_, static_cast(socket_)); int rc; if (server()) { - rc = SSL_accept(ssl_); + do { + rc = SSL_accept(ssl_); + if (rc <= 0) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + rc = 2; + default:;// do nothing + } + } + } while (rc == 2); } else { - rc = SSL_connect(ssl_); + do { + rc = SSL_connect(ssl_); + if (rc <= 0) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + rc = 2; + default:;// do nothing + } + } + } while (rc == 2); } if (rc <= 0) { int errno_copy = THRIFT_GET_SOCKET_ERROR; @@ -443,6 +601,54 @@ void TSSLSocket::authorize() { } } +unsigned int TSSLSocket::waitForEvent(bool wantRead) { + int fdSocket; + BIO* bio; + + if (wantRead) { + bio = SSL_get_rbio(ssl_); + } else { + bio = SSL_get_wbio(ssl_); + } + + if (bio == NULL) { + throw TSSLException("SSL_get_?bio returned NULL"); + } + + if (BIO_get_fd(bio, &fdSocket) <= 0) { + throw TSSLException("BIO_get_fd failed"); + } + + struct THRIFT_POLLFD fds[2]; + std::memset(fds, 0, sizeof(fds)); + fds[0].fd = fdSocket; + fds[0].events = wantRead ? THRIFT_POLLIN : THRIFT_POLLOUT; + + if (interruptListener_) { + fds[1].fd = *(interruptListener_.get()); + fds[1].events = THRIFT_POLLIN; + } + + int ret = THRIFT_POLL(fds, interruptListener_ ? 2 : 1, -1); + + if (ret < 0) { + // error cases + if (THRIFT_GET_SOCKET_ERROR == THRIFT_EINTR) { + return TSSL_EINTR; // repeat operation + } + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSSLSocket::read THRIFT_POLL() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } else if (ret > 0){ + if (fds[1].revents & THRIFT_POLLIN) { + throw TTransportException(TTransportException::INTERRUPTED, "Interrupted"); + } + return TSSL_DATA; + } else { + throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_POLL (timed out)"); + } +} + // TSSLSocketFactory implementation uint64_t TSSLSocketFactory::count_ = 0; Mutex TSSLSocketFactory::mutex_; @@ -475,18 +681,37 @@ boost::shared_ptr TSSLSocketFactory::createSocket() { return ssl; } +boost::shared_ptr TSSLSocketFactory::createSocket(boost::shared_ptr interruptListener) { + boost::shared_ptr ssl(new TSSLSocket(ctx_, interruptListener)); + setup(ssl); + return ssl; +} + boost::shared_ptr TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) { boost::shared_ptr ssl(new TSSLSocket(ctx_, socket)); setup(ssl); return ssl; } +boost::shared_ptr TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, boost::shared_ptr interruptListener) { + boost::shared_ptr ssl(new TSSLSocket(ctx_, socket, interruptListener)); + setup(ssl); + return ssl; +} + boost::shared_ptr TSSLSocketFactory::createSocket(const string& host, int port) { boost::shared_ptr ssl(new TSSLSocket(ctx_, host, port)); setup(ssl); return ssl; } +boost::shared_ptr TSSLSocketFactory::createSocket(const string& host, int port, boost::shared_ptr interruptListener) { + boost::shared_ptr ssl(new TSSLSocket(ctx_, host, port, interruptListener)); + setup(ssl); + return ssl; +} + + void TSSLSocketFactory::setup(boost::shared_ptr ssl) { ssl->server(server()); if (access_ == NULL && !server()) { -- cgit v1.2.1