diff options
Diffstat (limited to 'SWIG/_ssl.i')
-rw-r--r-- | SWIG/_ssl.i | 260 |
1 files changed, 189 insertions, 71 deletions
diff --git a/SWIG/_ssl.i b/SWIG/_ssl.i index 28a247c..ccfece9 100644 --- a/SWIG/_ssl.i +++ b/SWIG/_ssl.i @@ -11,10 +11,13 @@ %{ #include <pythread.h> +#include <limits.h> #include <openssl/bio.h> #include <openssl/dh.h> #include <openssl/ssl.h> #include <openssl/x509.h> +#include <poll.h> +#include <sys/time.h> %} %apply Pointer NONNULL { SSL_CTX * }; @@ -155,6 +158,11 @@ extern long SSL_SESSION_set_timeout(SSL_SESSION *, long); %rename(ssl_session_get_timeout) SSL_SESSION_get_timeout; extern long SSL_SESSION_get_timeout(CONST SSL_SESSION *); +extern PyObject *ssl_accept(SSL *ssl, double timeout = -1); +extern PyObject *ssl_connect(SSL *ssl, double timeout = -1); +extern PyObject *ssl_read(SSL *ssl, int num, double timeout = -1); +extern int ssl_write(SSL *ssl, PyObject *blob, double timeout = -1); + %constant int ssl_error_none = SSL_ERROR_NONE; %constant int ssl_error_ssl = SSL_ERROR_SSL; %constant int ssl_error_want_read = SSL_ERROR_WANT_READ; @@ -210,14 +218,19 @@ extern long SSL_SESSION_get_timeout(CONST SSL_SESSION *); %constant int SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER = SSL_MODE_ENABLE_PARTIAL_WRITE; %constant int SSL_MODE_AUTO_RETRY = SSL_MODE_AUTO_RETRY; +%ignore ssl_handle_error; +%ignore ssl_sleep_with_timeout; %inline %{ static PyObject *_ssl_err; +static PyObject *_ssl_timeout_err; -void ssl_init(PyObject *ssl_err) { +void ssl_init(PyObject *ssl_err, PyObject *ssl_timeout_err) { SSL_library_init(); SSL_load_error_strings(); Py_INCREF(ssl_err); + Py_INCREF(ssl_timeout_err); _ssl_err = ssl_err; + _ssl_timeout_err = ssl_timeout_err; } void ssl_ctx_passphrase_callback(SSL_CTX *ctx, PyObject *pyfunc) { @@ -403,36 +416,130 @@ int ssl_set_fd(SSL *ssl, int fd) { return ret; } -PyObject *ssl_accept(SSL *ssl) { +static void ssl_handle_error(int ssl_err, int ret) { + int err; + + switch (ssl_err) { + case SSL_ERROR_SSL: + PyErr_SetString(_ssl_err, + ERR_reason_error_string(ERR_get_error())); + break; + case SSL_ERROR_SYSCALL: + err = ERR_get_error(); + if (err) + PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); + else if (ret == 0) + PyErr_SetString(_ssl_err, "unexpected eof"); + else if (ret == -1) + PyErr_SetFromErrno(_ssl_err); + else + assert(0); + break; + default: + PyErr_SetString(_ssl_err, "unexpected SSL error"); + } +} + +static int ssl_sleep_with_timeout(SSL *ssl, const struct timeval *start, + double timeout, int ssl_err) { + struct pollfd fd; + struct timeval tv; + int ms, tmp; + + assert(timeout > 0); + again: + gettimeofday(&tv, NULL); + /* tv >= start */ + if ((timeout + start->tv_sec - tv.tv_sec) > INT_MAX / 1000) + ms = -1; + else { + int fract; + + ms = ((start->tv_sec + (int)timeout) - tv.tv_sec) * 1000; + fract = (start->tv_usec + (timeout - (int)timeout) * 1000000 + - tv.tv_usec + 999) / 1000; + if (ms > 0 && fract > INT_MAX - ms) + ms = -1; + else { + ms += fract; + if (ms <= 0) + goto timeout; + } + } + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + fd.fd = SSL_get_rfd(ssl); + fd.events = POLLIN; + break; + + case SSL_ERROR_WANT_WRITE: + fd.fd = SSL_get_wfd(ssl); + fd.events = POLLOUT; + break; + + case SSL_ERROR_WANT_X509_LOOKUP: + return 0; /* FIXME: is this correct? */ + + default: + assert(0); + } + if (fd.fd == -1) { + PyErr_SetString(_ssl_err, "timeout on a non-FD SSL"); + return -1; + } + Py_BEGIN_ALLOW_THREADS + tmp = poll(&fd, 1, ms); + Py_END_ALLOW_THREADS + switch (tmp) { + case 1: + return 0; + case 0: + goto timeout; + case -1: + if (errno == EINTR) + goto again; + PyErr_SetFromErrno(_ssl_err); + return -1; + } + return 0; + + timeout: + PyErr_SetString(_ssl_timeout_err, "timed out"); + return -1; +} + +PyObject *ssl_accept(SSL *ssl, double timeout) { PyObject *obj = NULL; - int r, err; + int r, ssl_err; + struct timeval tv; + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_accept(ssl); + ssl_err = SSL_get_error(ssl, r); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: obj = PyInt_FromLong((long)1); break; case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: - obj = PyInt_FromLong((long)0); - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); + if (timeout <= 0) { + obj = PyInt_FromLong((long)0); + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; obj = NULL; break; + case SSL_ERROR_SSL: case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); obj = NULL; break; } @@ -441,36 +548,38 @@ PyObject *ssl_accept(SSL *ssl) { return obj; } -PyObject *ssl_connect(SSL *ssl) { +PyObject *ssl_connect(SSL *ssl, double timeout) { PyObject *obj = NULL; - int r, err; + int r, ssl_err; + struct timeval tv; + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_connect(ssl); + ssl_err = SSL_get_error(ssl, r); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: obj = PyInt_FromLong((long)1); break; case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: - obj = PyInt_FromLong((long)0); - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); + if (timeout <= 0) { + obj = PyInt_FromLong((long)0); + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; obj = NULL; break; + case SSL_ERROR_SSL: case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); obj = NULL; break; } @@ -483,10 +592,11 @@ void ssl_set_shutdown1(SSL *ssl, int mode) { SSL_set_shutdown(ssl, mode); } -PyObject *ssl_read(SSL *ssl, int num) { +PyObject *ssl_read(SSL *ssl, int num, double timeout) { PyObject *obj = NULL; void *buf; - int r, err; + int r; + struct timeval tv; if (!(buf = PyMem_Malloc(num))) { PyErr_SetString(PyExc_MemoryError, "ssl_read"); @@ -494,37 +604,44 @@ PyObject *ssl_read(SSL *ssl, int num) { } + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_read(ssl, buf, num); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { - case SSL_ERROR_NONE: - case SSL_ERROR_ZERO_RETURN: - buf = PyMem_Realloc(buf, r); - obj = PyString_FromStringAndSize(buf, r); - break; - case SSL_ERROR_WANT_WRITE: - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_X509_LOOKUP: - Py_INCREF(Py_None); - obj = Py_None; - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - obj = NULL; - break; - case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); - obj = NULL; - break; + if (r >= 0) { + buf = PyMem_Realloc(buf, r); + obj = PyString_FromStringAndSize(buf, r); + } else { + int ssl_err; + + ssl_err = SSL_get_error(ssl, r); + switch (ssl_err) { + case SSL_ERROR_NONE: + case SSL_ERROR_ZERO_RETURN: + assert(0); + + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_X509_LOOKUP: + if (timeout <= 0) { + Py_INCREF(Py_None); + obj = Py_None; + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; + obj = NULL; + break; + case SSL_ERROR_SSL: + case SSL_ERROR_SYSCALL: + ssl_handle_error(ssl_err, r); + obj = NULL; + break; + } } PyMem_Free(buf); @@ -582,22 +699,26 @@ PyObject *ssl_read_nbio(SSL *ssl, int num) { return obj; } -int ssl_write(SSL *ssl, PyObject *blob) { +int ssl_write(SSL *ssl, PyObject *blob, double timeout) { const void *buf; - int len, r, err, ret; + int len, r, ssl_err, ret; + struct timeval tv; if (m2_PyObject_AsReadBufferInt(blob, &buf, &len) == -1) { return -1; } - + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_write(ssl, buf, len); + ssl_err = SSL_get_error(ssl, r); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: ret = r; @@ -605,20 +726,17 @@ int ssl_write(SSL *ssl, PyObject *blob) { case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_X509_LOOKUP: + if (timeout <= 0) { + ret = -1; + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; ret = -1; break; case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - ret = -1; - break; case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); default: ret = -1; } |