From 14ae3070964b882ced21557aabdf4c9930d2cc6a Mon Sep 17 00:00:00 2001 From: Alan Antonuk Date: Thu, 16 Apr 2015 23:25:41 -0700 Subject: Add nonblocking sockets in OpenSSL socket impl --- librabbitmq/amqp_openssl.c | 60 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/librabbitmq/amqp_openssl.c b/librabbitmq/amqp_openssl.c index 57589d8..ceb0641 100644 --- a/librabbitmq/amqp_openssl.c +++ b/librabbitmq/amqp_openssl.c @@ -33,6 +33,7 @@ #include "amqp_socket.h" #include "amqp_hostcheck.h" #include "amqp_private.h" +#include "amqp_timer.h" #include "threads.h" #include @@ -97,6 +98,12 @@ amqp_ssl_socket_send(void *base, /* TODO: Close connection if it isn't already? */ /* TODO: Possibly be more intelligent in reporting WHAT went wrong */ switch (self->internal_error) { + case SSL_ERROR_WANT_READ: + res = AMQP_PRIVATE_STATUS_SOCKET_NEEDREAD; + break; + case SSL_ERROR_WANT_WRITE: + res = AMQP_PRIVATE_STATUS_SOCKET_NEEDWRITE; + break; case SSL_ERROR_ZERO_RETURN: res = AMQP_STATUS_CONNECTION_CLOSED; break; @@ -166,7 +173,13 @@ amqp_ssl_socket_recv(void *base, received = SSL_read(self->ssl, buf, len); if (0 >= received) { self->internal_error = SSL_get_error(self->ssl, received); - switch(self->internal_error) { + switch (self->internal_error) { + case SSL_ERROR_WANT_READ: + received = AMQP_PRIVATE_STATUS_SOCKET_NEEDREAD; + break; + case SSL_ERROR_WANT_WRITE: + received = AMQP_PRIVATE_STATUS_SOCKET_NEEDWRITE; + break; case SSL_ERROR_ZERO_RETURN: received = AMQP_STATUS_CONNECTION_CLOSED; break; @@ -289,6 +302,7 @@ amqp_ssl_socket_open(void *base, const char *host, int port, struct timeval *tim struct amqp_ssl_socket_t *self = (struct amqp_ssl_socket_t *)base; long result; int status; + amqp_timer_t timer; if (-1 != self->sockfd) { return AMQP_STATUS_SOCKET_INUSE; } @@ -301,8 +315,12 @@ amqp_ssl_socket_open(void *base, const char *host, int port, struct timeval *tim goto exit; } - SSL_set_mode(self->ssl, SSL_MODE_AUTO_RETRY); - self->sockfd = amqp_open_socket_noblock(host, port, timeout); + status = amqp_timer_start(&timer, timeout); + if (AMQP_STATUS_OK != status) { + return status; + } + + self->sockfd = amqp_open_socket_inner(host, port, timer); if (0 > self->sockfd) { status = self->sockfd; self->internal_error = amqp_os_socket_error(); @@ -317,10 +335,23 @@ amqp_ssl_socket_open(void *base, const char *host, int port, struct timeval *tim goto error_out2; } +start_connect: status = SSL_connect(self->ssl); if (!status) { self->internal_error = SSL_get_error(self->ssl, status); - status = AMQP_STATUS_SSL_CONNECTION_FAILED; + switch (self->internal_error) { + case SSL_ERROR_WANT_READ: + status = amqp_poll_read(self->sockfd, timer); + break; + case SSL_ERROR_WANT_WRITE: + status = amqp_poll_write(self->sockfd, timer); + break; + default: + status = AMQP_STATUS_SSL_CONNECTION_FAILED; + } + if (AMQP_STATUS_OK == status) { + goto start_connect; + } goto error_out2; } @@ -359,13 +390,32 @@ error_out1: static int amqp_ssl_socket_close(void *base) { + int res; struct amqp_ssl_socket_t *self = (struct amqp_ssl_socket_t *)base; if (-1 == self->sockfd) { return AMQP_STATUS_SOCKET_CLOSED; } - SSL_shutdown(self->ssl); +start_shutdown: + res = SSL_shutdown(self->ssl); + if (0 == res) { + goto start_shutdown; + } else if (-1 == res) { + self->internal_error = SSL_get_error(self->ssl, res); + switch (self->internal_error) { + case SSL_ERROR_WANT_READ: + res = amqp_poll_read(self->sockfd, amqp_timer_start_infinite()); + break; + case SSL_ERROR_WANT_WRITE: + res = amqp_poll_write(self->sockfd, amqp_timer_start_infinite()); + break; + } + if (AMQP_STATUS_OK == res) { + goto start_shutdown; + } + /* Swallow errors in poll, just consider the connection dead */ + } SSL_free(self->ssl); self->ssl = NULL; -- cgit v1.2.1