summaryrefslogtreecommitdiff
path: root/lib/cpp/src/thrift/transport/TSSLSocket.cpp
diff options
context:
space:
mode:
authorMartin Haimberger <martin.haimberger@thincast.com>2015-11-13 03:18:50 -0800
committerNobuaki Sukegawa <nsuke@apache.org>2015-11-23 17:09:27 +0900
commit0ad6ee95e002f41dd628d4044f901468f43ffc32 (patch)
tree71331e3d041d730ddd27a97617646fa7d740ab6f /lib/cpp/src/thrift/transport/TSSLSocket.cpp
parentae971ce917bf9b60ee8ae83b834dad1eb149a82f (diff)
downloadthrift-0ad6ee95e002f41dd628d4044f901468f43ffc32.tar.gz
THRIFT-3420 C++: TSSLSockets are not interruptable
Client: C++ Patch: Martin Haimberger This closes #690
Diffstat (limited to 'lib/cpp/src/thrift/transport/TSSLSocket.cpp')
-rw-r--r--lib/cpp/src/thrift/transport/TSSLSocket.cpp259
1 files changed, 242 insertions, 17 deletions
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index 98c532676..6e9a4de0f 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -28,6 +28,14 @@
#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
+#ifdef HAVE_SYS_POLL_H
+#include <sys/poll.h>
+#endif
+#ifdef HAVE_FCNTL_H
+#include <fcntl.h>
+#endif
+
+
#include <boost/lexical_cast.hpp>
#include <boost/shared_array.hpp>
#include <openssl/err.h>
@@ -189,14 +197,28 @@ TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx)
: TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
}
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+ : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
+ interruptListener_ = interruptListener;
+}
+
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
: TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
}
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+ : TSocket(socket, interruptListener), server_(false), ssl_(NULL), ctx_(ctx) {
+}
+
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port)
: TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
}
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+ : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
+ interruptListener_ = interruptListener;
+}
+
TSSLSocket::~TSSLSocket() {
close();
}
@@ -222,16 +244,32 @@ bool TSSLSocket::peek() {
checkHandshake();
int rc;
uint8_t byte;
- rc = SSL_peek(ssl_, &byte, 1);
- if (rc < 0) {
- int errno_copy = THRIFT_GET_SOCKET_ERROR;
- string errors;
- buildErrors(errors, errno_copy);
- throw TSSLException("SSL_peek: " + errors);
- }
- if (rc == 0) {
- ERR_clear_error();
- }
+ do {
+ rc = SSL_peek(ssl_, &byte, 1);
+ if (rc < 0) {
+
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ if ((errno_copy != THRIFT_EINTR)
+ || (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ continue;
+ default:;// do nothing
+ }
+ string errors;
+ buildErrors(errors, errno_copy);
+ throw TSSLException("SSL_peek: " + errors);
+ } else if (rc == 0) {
+ ERR_clear_error();
+ break;
+ }
+ } while (true);
return (rc > 0);
}
@@ -244,7 +282,28 @@ void TSSLSocket::open() {
void TSSLSocket::close() {
if (ssl_ != NULL) {
- int rc = SSL_shutdown(ssl_);
+ int rc;
+
+ do {
+ rc = SSL_shutdown(ssl_);
+ if (rc <= 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ if ((errno_copy != THRIFT_EINTR)
+ || (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ default:;// do nothing
+ }
+ }
+ } while (rc == 2);
+
if (rc < 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
string errors;
@@ -262,14 +321,36 @@ uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
checkHandshake();
int32_t bytes = 0;
for (int32_t retries = 0; retries < maxRecvRetries_; retries++) {
+ ERR_clear_error();
bytes = SSL_read(ssl_, buf, len);
if (bytes >= 0)
break;
- int errno_copy = THRIFT_GET_SOCKET_ERROR;
- if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
- if (ERR_get_error() == 0 && errno_copy == THRIFT_EINTR) {
+ int32_t errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int32_t error = SSL_get_error(ssl_, bytes);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ if ((errno_copy != THRIFT_EINTR)
+ || (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ if (retries++ >= maxRecvRetries_) {
+ // THRIFT_EINTR needs to be handled manually and we can tolerate
+ // a certain number
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ if (waitForEvent(error == SSL_ERROR_WANT_READ) == TSSL_EINTR ) {
+ // repeat operation
+ if (retries++ < maxRecvRetries_) {
+ // THRIFT_EINTR needs to be handled manually and we can tolerate
+ // a certain number
+ continue;
+ }
+ throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries");
+ }
continue;
- }
+ default:;// do nothing
}
string errors;
buildErrors(errors, errno_copy);
@@ -283,9 +364,23 @@ void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
// loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
uint32_t written = 0;
while (written < len) {
+ ERR_clear_error();
int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
if (bytes <= 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, bytes);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ if ((errno_copy != THRIFT_EINTR)
+ || (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ continue;
+ default:;// do nothing
+ }
string errors;
buildErrors(errors, errno_copy);
throw TSSLException("SSL_write: " + errors);
@@ -319,13 +414,76 @@ void TSSLSocket::checkHandshake() {
if (ssl_ != NULL) {
return;
}
+
+ // set underlying socket to non-blocking
+ int flags;
+ if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0
+ || THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) {
+ GlobalOutput.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ",
+ THRIFT_GET_SOCKET_ERROR);
+ ::THRIFT_CLOSESOCKET(socket_);
+ return;
+ }
+
ssl_ = ctx_->createSSL();
+
+ //set read and write bios to non-blocking
+ BIO* wbio = BIO_new(BIO_s_mem());
+ if (wbio == NULL) {
+ throw TSSLException("SSL_get_wbio returns NULL");
+ }
+ BIO_set_nbio(wbio, 1);
+
+ BIO* rbio = BIO_new(BIO_s_mem());
+ if (rbio == NULL) {
+ throw TSSLException("SSL_get_rbio returns NULL");
+ }
+ BIO_set_nbio(rbio, 1);
+
+ SSL_set_bio(ssl_, rbio, wbio);
+
SSL_set_fd(ssl_, static_cast<int>(socket_));
int rc;
if (server()) {
- rc = SSL_accept(ssl_);
+ do {
+ rc = SSL_accept(ssl_);
+ if (rc <= 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ if ((errno_copy != THRIFT_EINTR)
+ || (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ default:;// do nothing
+ }
+ }
+ } while (rc == 2);
} else {
- rc = SSL_connect(ssl_);
+ do {
+ rc = SSL_connect(ssl_);
+ if (rc <= 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ if ((errno_copy != THRIFT_EINTR)
+ || (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ default:;// do nothing
+ }
+ }
+ } while (rc == 2);
}
if (rc <= 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
@@ -443,6 +601,54 @@ void TSSLSocket::authorize() {
}
}
+unsigned int TSSLSocket::waitForEvent(bool wantRead) {
+ int fdSocket;
+ BIO* bio;
+
+ if (wantRead) {
+ bio = SSL_get_rbio(ssl_);
+ } else {
+ bio = SSL_get_wbio(ssl_);
+ }
+
+ if (bio == NULL) {
+ throw TSSLException("SSL_get_?bio returned NULL");
+ }
+
+ if (BIO_get_fd(bio, &fdSocket) <= 0) {
+ throw TSSLException("BIO_get_fd failed");
+ }
+
+ struct THRIFT_POLLFD fds[2];
+ std::memset(fds, 0, sizeof(fds));
+ fds[0].fd = fdSocket;
+ fds[0].events = wantRead ? THRIFT_POLLIN : THRIFT_POLLOUT;
+
+ if (interruptListener_) {
+ fds[1].fd = *(interruptListener_.get());
+ fds[1].events = THRIFT_POLLIN;
+ }
+
+ int ret = THRIFT_POLL(fds, interruptListener_ ? 2 : 1, -1);
+
+ if (ret < 0) {
+ // error cases
+ if (THRIFT_GET_SOCKET_ERROR == THRIFT_EINTR) {
+ return TSSL_EINTR; // repeat operation
+ }
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ GlobalOutput.perror("TSSLSocket::read THRIFT_POLL() ", errno_copy);
+ throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+ } else if (ret > 0){
+ if (fds[1].revents & THRIFT_POLLIN) {
+ throw TTransportException(TTransportException::INTERRUPTED, "Interrupted");
+ }
+ return TSSL_DATA;
+ } else {
+ throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_POLL (timed out)");
+ }
+}
+
// TSSLSocketFactory implementation
uint64_t TSSLSocketFactory::count_ = 0;
Mutex TSSLSocketFactory::mutex_;
@@ -475,18 +681,37 @@ boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
return ssl;
}
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+ boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, interruptListener));
+ setup(ssl);
+ return ssl;
+}
+
boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) {
boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
setup(ssl);
return ssl;
}
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+ boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket, interruptListener));
+ setup(ssl);
+ return ssl;
+}
+
boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port) {
boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
setup(ssl);
return ssl;
}
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+ boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port, interruptListener));
+ setup(ssl);
+ return ssl;
+}
+
+
void TSSLSocketFactory::setup(boost::shared_ptr<TSSLSocket> ssl) {
ssl->server(server());
if (access_ == NULL && !server()) {