/* * Copyright (c) 2019, Redis Labs * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #define REDISMODULE_CORE_MODULE /* A module that's part of the redis core, uses server.h too. */ #include "server.h" #include "connhelpers.h" #include "adlist.h" #if (USE_OPENSSL == 1 /* BUILD_YES */ ) || ((USE_OPENSSL == 2 /* BUILD_MODULE */) && (BUILD_TLS_MODULE == 2)) #include #include #include #include #include #if OPENSSL_VERSION_NUMBER >= 0x30000000L #include #endif #include #include #define REDIS_TLS_PROTO_TLSv1 (1<<0) #define REDIS_TLS_PROTO_TLSv1_1 (1<<1) #define REDIS_TLS_PROTO_TLSv1_2 (1<<2) #define REDIS_TLS_PROTO_TLSv1_3 (1<<3) /* Use safe defaults */ #ifdef TLS1_3_VERSION #define REDIS_TLS_PROTO_DEFAULT (REDIS_TLS_PROTO_TLSv1_2|REDIS_TLS_PROTO_TLSv1_3) #else #define REDIS_TLS_PROTO_DEFAULT (REDIS_TLS_PROTO_TLSv1_2) #endif SSL_CTX *redis_tls_ctx = NULL; SSL_CTX *redis_tls_client_ctx = NULL; static int parseProtocolsConfig(const char *str) { int i, count = 0; int protocols = 0; if (!str) return REDIS_TLS_PROTO_DEFAULT; sds *tokens = sdssplitlen(str, strlen(str), " ", 1, &count); if (!tokens) { serverLog(LL_WARNING, "Invalid tls-protocols configuration string"); return -1; } for (i = 0; i < count; i++) { if (!strcasecmp(tokens[i], "tlsv1")) protocols |= REDIS_TLS_PROTO_TLSv1; else if (!strcasecmp(tokens[i], "tlsv1.1")) protocols |= REDIS_TLS_PROTO_TLSv1_1; else if (!strcasecmp(tokens[i], "tlsv1.2")) protocols |= REDIS_TLS_PROTO_TLSv1_2; else if (!strcasecmp(tokens[i], "tlsv1.3")) { #ifdef TLS1_3_VERSION protocols |= REDIS_TLS_PROTO_TLSv1_3; #else serverLog(LL_WARNING, "TLSv1.3 is specified in tls-protocols but not supported by OpenSSL."); protocols = -1; break; #endif } else { serverLog(LL_WARNING, "Invalid tls-protocols specified. " "Use a combination of 'TLSv1', 'TLSv1.1', 'TLSv1.2' and 'TLSv1.3'."); protocols = -1; break; } } sdsfreesplitres(tokens, count); return protocols; } /* list of connections with pending data already read from the socket, but not * served to the reader yet. */ static list *pending_list = NULL; /** * OpenSSL global initialization and locking handling callbacks. * Note that this is only required for OpenSSL < 1.1.0. */ #if OPENSSL_VERSION_NUMBER < 0x10100000L #define USE_CRYPTO_LOCKS #endif #ifdef USE_CRYPTO_LOCKS static pthread_mutex_t *openssl_locks; static void sslLockingCallback(int mode, int lock_id, const char *f, int line) { pthread_mutex_t *mt = openssl_locks + lock_id; if (mode & CRYPTO_LOCK) { pthread_mutex_lock(mt); } else { pthread_mutex_unlock(mt); } (void)f; (void)line; } static void initCryptoLocks(void) { unsigned i, nlocks; if (CRYPTO_get_locking_callback() != NULL) { /* Someone already set the callback before us. Don't destroy it! */ return; } nlocks = CRYPTO_num_locks(); openssl_locks = zmalloc(sizeof(*openssl_locks) * nlocks); for (i = 0; i < nlocks; i++) { pthread_mutex_init(openssl_locks + i, NULL); } CRYPTO_set_locking_callback(sslLockingCallback); } #endif /* USE_CRYPTO_LOCKS */ static void tlsInit(void) { /* Enable configuring OpenSSL using the standard openssl.cnf * OPENSSL_config()/OPENSSL_init_crypto() should be the first * call to the OpenSSL* library. * - OPENSSL_config() should be used for OpenSSL versions < 1.1.0 * - OPENSSL_init_crypto() should be used for OpenSSL versions >= 1.1.0 */ #if OPENSSL_VERSION_NUMBER < 0x10100000L OPENSSL_config(NULL); SSL_load_error_strings(); SSL_library_init(); #elif OPENSSL_VERSION_NUMBER < 0x10101000L OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CONFIG, NULL); #else OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CONFIG|OPENSSL_INIT_ATFORK, NULL); #endif #ifdef USE_CRYPTO_LOCKS initCryptoLocks(); #endif if (!RAND_poll()) { serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator."); } pending_list = listCreate(); } static void tlsCleanup(void) { if (redis_tls_ctx) { SSL_CTX_free(redis_tls_ctx); redis_tls_ctx = NULL; } if (redis_tls_client_ctx) { SSL_CTX_free(redis_tls_client_ctx); redis_tls_client_ctx = NULL; } #if OPENSSL_VERSION_NUMBER >= 0x10100000L && !defined(LIBRESSL_VERSION_NUMBER) // unavailable on LibreSSL OPENSSL_cleanup(); #endif } /* Callback for passing a keyfile password stored as an sds to OpenSSL */ static int tlsPasswordCallback(char *buf, int size, int rwflag, void *u) { UNUSED(rwflag); const char *pass = u; size_t pass_len; if (!pass) return -1; pass_len = strlen(pass); if (pass_len > (size_t) size) return -1; memcpy(buf, pass, pass_len); return (int) pass_len; } /* Create a *base* SSL_CTX using the SSL configuration provided. The base context * includes everything that's common for both client-side and server-side connections. */ static SSL_CTX *createSSLContext(redisTLSContextConfig *ctx_config, int protocols, int client) { const char *cert_file = client ? ctx_config->client_cert_file : ctx_config->cert_file; const char *key_file = client ? ctx_config->client_key_file : ctx_config->key_file; const char *key_file_pass = client ? ctx_config->client_key_file_pass : ctx_config->key_file_pass; char errbuf[256]; SSL_CTX *ctx = NULL; ctx = SSL_CTX_new(SSLv23_method()); SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2|SSL_OP_NO_SSLv3); #ifdef SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); #endif if (!(protocols & REDIS_TLS_PROTO_TLSv1)) SSL_CTX_set_options(ctx, SSL_OP_NO_TLSv1); if (!(protocols & REDIS_TLS_PROTO_TLSv1_1)) SSL_CTX_set_options(ctx, SSL_OP_NO_TLSv1_1); #ifdef SSL_OP_NO_TLSv1_2 if (!(protocols & REDIS_TLS_PROTO_TLSv1_2)) SSL_CTX_set_options(ctx, SSL_OP_NO_TLSv1_2); #endif #ifdef SSL_OP_NO_TLSv1_3 if (!(protocols & REDIS_TLS_PROTO_TLSv1_3)) SSL_CTX_set_options(ctx, SSL_OP_NO_TLSv1_3); #endif #ifdef SSL_OP_NO_COMPRESSION SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION); #endif SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE|SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); SSL_CTX_set_default_passwd_cb(ctx, tlsPasswordCallback); SSL_CTX_set_default_passwd_cb_userdata(ctx, (void *) key_file_pass); if (SSL_CTX_use_certificate_chain_file(ctx, cert_file) <= 0) { ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); serverLog(LL_WARNING, "Failed to load certificate: %s: %s", cert_file, errbuf); goto error; } if (SSL_CTX_use_PrivateKey_file(ctx, key_file, SSL_FILETYPE_PEM) <= 0) { ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); serverLog(LL_WARNING, "Failed to load private key: %s: %s", key_file, errbuf); goto error; } if ((ctx_config->ca_cert_file || ctx_config->ca_cert_dir) && SSL_CTX_load_verify_locations(ctx, ctx_config->ca_cert_file, ctx_config->ca_cert_dir) <= 0) { ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); serverLog(LL_WARNING, "Failed to configure CA certificate(s) file/directory: %s", errbuf); goto error; } if (ctx_config->ciphers && !SSL_CTX_set_cipher_list(ctx, ctx_config->ciphers)) { serverLog(LL_WARNING, "Failed to configure ciphers: %s", ctx_config->ciphers); goto error; } #ifdef TLS1_3_VERSION if (ctx_config->ciphersuites && !SSL_CTX_set_ciphersuites(ctx, ctx_config->ciphersuites)) { serverLog(LL_WARNING, "Failed to configure ciphersuites: %s", ctx_config->ciphersuites); goto error; } #endif return ctx; error: if (ctx) SSL_CTX_free(ctx); return NULL; } /* Attempt to configure/reconfigure TLS. This operation is atomic and will * leave the SSL_CTX unchanged if fails. * @priv: config of redisTLSContextConfig. * @reconfigure: if true, ignore the previous configure; if false, only * configure from @ctx_config if redis_tls_ctx is NULL. */ static int tlsConfigure(void *priv, int reconfigure) { redisTLSContextConfig *ctx_config = (redisTLSContextConfig *)priv; char errbuf[256]; SSL_CTX *ctx = NULL; SSL_CTX *client_ctx = NULL; if (!reconfigure && redis_tls_ctx) { return C_OK; } if (!ctx_config->cert_file) { serverLog(LL_WARNING, "No tls-cert-file configured!"); goto error; } if (!ctx_config->key_file) { serverLog(LL_WARNING, "No tls-key-file configured!"); goto error; } if (((server.tls_auth_clients != TLS_CLIENT_AUTH_NO) || server.tls_cluster || server.tls_replication) && !ctx_config->ca_cert_file && !ctx_config->ca_cert_dir) { serverLog(LL_WARNING, "Either tls-ca-cert-file or tls-ca-cert-dir must be specified when tls-cluster, tls-replication or tls-auth-clients are enabled!"); goto error; } int protocols = parseProtocolsConfig(ctx_config->protocols); if (protocols == -1) goto error; /* Create server side/general context */ ctx = createSSLContext(ctx_config, protocols, 0); if (!ctx) goto error; if (ctx_config->session_caching) { SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_SERVER); SSL_CTX_sess_set_cache_size(ctx, ctx_config->session_cache_size); SSL_CTX_set_timeout(ctx, ctx_config->session_cache_timeout); SSL_CTX_set_session_id_context(ctx, (void *) "redis", 5); } else { SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF); } #ifdef SSL_OP_NO_CLIENT_RENEGOTIATION SSL_CTX_set_options(ctx, SSL_OP_NO_CLIENT_RENEGOTIATION); #endif if (ctx_config->prefer_server_ciphers) SSL_CTX_set_options(ctx, SSL_OP_CIPHER_SERVER_PREFERENCE); #if ((OPENSSL_VERSION_NUMBER < 0x30000000L) && defined(SSL_CTX_set_ecdh_auto)) SSL_CTX_set_ecdh_auto(ctx, 1); #endif SSL_CTX_set_options(ctx, SSL_OP_SINGLE_DH_USE); if (ctx_config->dh_params_file) { FILE *dhfile = fopen(ctx_config->dh_params_file, "r"); if (!dhfile) { serverLog(LL_WARNING, "Failed to load %s: %s", ctx_config->dh_params_file, strerror(errno)); goto error; } #if (OPENSSL_VERSION_NUMBER >= 0x30000000L) EVP_PKEY *pkey = NULL; OSSL_DECODER_CTX *dctx = OSSL_DECODER_CTX_new_for_pkey( &pkey, "PEM", NULL, "DH", OSSL_KEYMGMT_SELECT_DOMAIN_PARAMETERS, NULL, NULL); if (!dctx) { serverLog(LL_WARNING, "No decoder for DH params."); fclose(dhfile); goto error; } if (!OSSL_DECODER_from_fp(dctx, dhfile)) { serverLog(LL_WARNING, "%s: failed to read DH params.", ctx_config->dh_params_file); OSSL_DECODER_CTX_free(dctx); fclose(dhfile); goto error; } OSSL_DECODER_CTX_free(dctx); fclose(dhfile); if (SSL_CTX_set0_tmp_dh_pkey(ctx, pkey) <= 0) { ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); serverLog(LL_WARNING, "Failed to load DH params file: %s: %s", ctx_config->dh_params_file, errbuf); EVP_PKEY_free(pkey); goto error; } /* Not freeing pkey, it is owned by OpenSSL now */ #else DH *dh = PEM_read_DHparams(dhfile, NULL, NULL, NULL); fclose(dhfile); if (!dh) { serverLog(LL_WARNING, "%s: failed to read DH params.", ctx_config->dh_params_file); goto error; } if (SSL_CTX_set_tmp_dh(ctx, dh) <= 0) { ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); serverLog(LL_WARNING, "Failed to load DH params file: %s: %s", ctx_config->dh_params_file, errbuf); DH_free(dh); goto error; } DH_free(dh); #endif } else { #if (OPENSSL_VERSION_NUMBER >= 0x30000000L) SSL_CTX_set_dh_auto(ctx, 1); #endif } /* If a client-side certificate is configured, create an explicit client context */ if (ctx_config->client_cert_file && ctx_config->client_key_file) { client_ctx = createSSLContext(ctx_config, protocols, 1); if (!client_ctx) goto error; } SSL_CTX_free(redis_tls_ctx); SSL_CTX_free(redis_tls_client_ctx); redis_tls_ctx = ctx; redis_tls_client_ctx = client_ctx; return C_OK; error: if (ctx) SSL_CTX_free(ctx); if (client_ctx) SSL_CTX_free(client_ctx); return C_ERR; } #ifdef TLS_DEBUGGING #define TLSCONN_DEBUG(fmt, ...) \ serverLog(LL_DEBUG, "TLSCONN: " fmt, __VA_ARGS__) #else #define TLSCONN_DEBUG(fmt, ...) #endif static ConnectionType CT_TLS; /* Normal socket connections have a simple events/handler correlation. * * With TLS connections we need to handle cases where during a logical read * or write operation, the SSL library asks to block for the opposite * socket operation. * * When this happens, we need to do two things: * 1. Make sure we register for the event. * 2. Make sure we know which handler needs to execute when the * event fires. That is, if we notify the caller of a write operation * that it blocks, and SSL asks for a read, we need to trigger the * write handler again on the next read event. * */ typedef enum { WANT_READ = 1, WANT_WRITE } WantIOType; #define TLS_CONN_FLAG_READ_WANT_WRITE (1<<0) #define TLS_CONN_FLAG_WRITE_WANT_READ (1<<1) #define TLS_CONN_FLAG_FD_SET (1<<2) typedef struct tls_connection { connection c; int flags; SSL *ssl; char *ssl_error; listNode *pending_list_node; } tls_connection; static connection *createTLSConnection(int client_side) { SSL_CTX *ctx = redis_tls_ctx; if (client_side && redis_tls_client_ctx) ctx = redis_tls_client_ctx; tls_connection *conn = zcalloc(sizeof(tls_connection)); conn->c.type = &CT_TLS; conn->c.fd = -1; conn->ssl = SSL_new(ctx); return (connection *) conn; } static connection *connCreateTLS(void) { return createTLSConnection(1); } /* Fetch the latest OpenSSL error and store it in the connection */ static void updateTLSError(tls_connection *conn) { conn->c.last_errno = 0; if (conn->ssl_error) zfree(conn->ssl_error); conn->ssl_error = zmalloc(512); ERR_error_string_n(ERR_get_error(), conn->ssl_error, 512); } /* Create a new TLS connection that is already associated with * an accepted underlying file descriptor. * * The socket is not ready for I/O until connAccept() was called and * invoked the connection-level accept handler. * * Callers should use connGetState() and verify the created connection * is not in an error state. */ static connection *connCreateAcceptedTLS(int fd, void *priv) { int require_auth = *(int *)priv; tls_connection *conn = (tls_connection *) createTLSConnection(0); conn->c.fd = fd; conn->c.state = CONN_STATE_ACCEPTING; if (!conn->ssl) { updateTLSError(conn); conn->c.state = CONN_STATE_ERROR; return (connection *) conn; } switch (require_auth) { case TLS_CLIENT_AUTH_NO: SSL_set_verify(conn->ssl, SSL_VERIFY_NONE, NULL); break; case TLS_CLIENT_AUTH_OPTIONAL: SSL_set_verify(conn->ssl, SSL_VERIFY_PEER, NULL); break; default: /* TLS_CLIENT_AUTH_YES, also fall-secure */ SSL_set_verify(conn->ssl, SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); break; } SSL_set_fd(conn->ssl, conn->c.fd); SSL_set_accept_state(conn->ssl); return (connection *) conn; } static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask); static void updateSSLEvent(tls_connection *conn); /* Process the return code received from OpenSSL> * Update the want parameter with expected I/O. * Update the connection's error state if a real error has occurred. * Returns an SSL error code, or 0 if no further handling is required. */ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *want) { if (ret_value <= 0) { int ssl_err = SSL_get_error(conn->ssl, ret_value); switch (ssl_err) { case SSL_ERROR_WANT_WRITE: *want = WANT_WRITE; return 0; case SSL_ERROR_WANT_READ: *want = WANT_READ; return 0; case SSL_ERROR_SYSCALL: conn->c.last_errno = errno; if (conn->ssl_error) zfree(conn->ssl_error); conn->ssl_error = errno ? zstrdup(strerror(errno)) : NULL; break; default: /* Error! */ updateTLSError(conn); break; } return ssl_err; } return 0; } /* Handle OpenSSL return code following SSL_write() or SSL_read(): * * - Updates conn state and last_errno. * - If update_event is nonzero, calls updateSSLEvent() when necessary. * * Returns ret_value, or -1 on error or dropped connection. */ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update_event) { /* If system call was interrupted, there's no need to go through the full * OpenSSL error handling and just report this for the caller to retry the * operation. */ if (errno == EINTR) { conn->c.last_errno = EINTR; return -1; } if (ret_value <= 0) { WantIOType want = 0; int ssl_err; if (!(ssl_err = handleSSLReturnCode(conn, ret_value, &want))) { if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; if (update_event) updateSSLEvent(conn); errno = EAGAIN; return -1; } else { if (ssl_err == SSL_ERROR_ZERO_RETURN || ((ssl_err == SSL_ERROR_SYSCALL && !errno))) { conn->c.state = CONN_STATE_CLOSED; return -1; } else { conn->c.state = CONN_STATE_ERROR; return -1; } } } return ret_value; } static void registerSSLEvent(tls_connection *conn, WantIOType want) { int mask = aeGetFileEvents(server.el, conn->c.fd); switch (want) { case WANT_READ: if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); break; case WANT_WRITE: if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); break; default: serverAssert(0); break; } } static void updateSSLEvent(tls_connection *conn) { int mask = aeGetFileEvents(server.el, conn->c.fd); int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); if (need_read && !(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); if (!need_read && (mask & AE_READABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); if (need_write && !(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); } static void tlsHandleEvent(tls_connection *conn, int mask) { int ret, conn_error; TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d", fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL, conn->flags); ERR_clear_error(); switch (conn->c.state) { case CONN_STATE_CONNECTING: conn_error = anetGetError(conn->c.fd); if (conn_error) { conn->c.last_errno = conn_error; conn->c.state = CONN_STATE_ERROR; } else { if (!(conn->flags & TLS_CONN_FLAG_FD_SET)) { SSL_set_fd(conn->ssl, conn->c.fd); conn->flags |= TLS_CONN_FLAG_FD_SET; } ret = SSL_connect(conn->ssl); if (ret <= 0) { WantIOType want = 0; if (!handleSSLReturnCode(conn, ret, &want)) { registerSSLEvent(conn, want); /* Avoid hitting UpdateSSLEvent, which knows nothing * of what SSL_connect() wants and instead looks at our * R/W handlers. */ return; } /* If not handled, it's an error */ conn->c.state = CONN_STATE_ERROR; } else { conn->c.state = CONN_STATE_CONNECTED; } } if (!callHandler((connection *) conn, conn->c.conn_handler)) return; conn->c.conn_handler = NULL; break; case CONN_STATE_ACCEPTING: ret = SSL_accept(conn->ssl); if (ret <= 0) { WantIOType want = 0; if (!handleSSLReturnCode(conn, ret, &want)) { /* Avoid hitting UpdateSSLEvent, which knows nothing * of what SSL_connect() wants and instead looks at our * R/W handlers. */ registerSSLEvent(conn, want); return; } /* If not handled, it's an error */ conn->c.state = CONN_STATE_ERROR; } else { conn->c.state = CONN_STATE_CONNECTED; } if (!callHandler((connection *) conn, conn->c.conn_handler)) return; conn->c.conn_handler = NULL; break; case CONN_STATE_CONNECTED: { int call_read = ((mask & AE_READABLE) && conn->c.read_handler) || ((mask & AE_WRITABLE) && (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE)); int call_write = ((mask & AE_WRITABLE) && conn->c.write_handler) || ((mask & AE_READABLE) && (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ)); /* Normally we execute the readable event first, and the writable * event laster. This is useful as sometimes we may be able * to serve the reply of a query immediately after processing the * query. * * However if WRITE_BARRIER is set in the mask, our application is * asking us to do the reverse: never fire the writable event * after the readable. In such a case, we invert the calls. * This is useful when, for instance, we want to do things * in the beforeSleep() hook, like fsynching a file to disk, * before replying to a client. */ int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER; if (!invert && call_read) { conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *) conn, conn->c.read_handler)) return; } /* Fire the writable event. */ if (call_write) { conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; if (!callHandler((connection *) conn, conn->c.write_handler)) return; } /* If we have to invert the call, fire the readable event now * after the writable one. */ if (invert && call_read) { conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *) conn, conn->c.read_handler)) return; } /* If SSL has pending that, already read from the socket, we're at * risk of not calling the read handler again, make sure to add it * to a list of pending connection that should be handled anyway. */ if ((mask & AE_READABLE)) { if (SSL_pending(conn->ssl) > 0) { if (!conn->pending_list_node) { listAddNodeTail(pending_list, conn); conn->pending_list_node = listLast(pending_list); } } else if (conn->pending_list_node) { listDelNode(pending_list, conn->pending_list_node); conn->pending_list_node = NULL; } } break; } default: break; } updateSSLEvent(conn); } static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) { UNUSED(el); UNUSED(fd); tls_connection *conn = clientData; tlsHandleEvent(conn, mask); } static void tlsAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) { int cport, cfd, max = MAX_ACCEPTS_PER_CALL; char cip[NET_IP_STR_LEN]; UNUSED(el); UNUSED(mask); UNUSED(privdata); while(max--) { cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport); if (cfd == ANET_ERR) { if (errno != EWOULDBLOCK) serverLog(LL_WARNING, "Accepting client connection: %s", server.neterr); return; } serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); acceptCommonHandler(connCreateAcceptedTLS(cfd, &server.tls_auth_clients),0,cip); } } static int connTLSAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) { return anetFdToString(conn->fd, ip, ip_len, port, remote); } static int connTLSIsLocal(connection *conn) { return connectionTypeTcp()->is_local(conn); } static int connTLSListen(connListener *listener) { return listenToPort(listener); } static void connTLSShutdown(connection *conn_) { tls_connection *conn = (tls_connection *) conn_; if (conn->ssl) { if (conn->c.state == CONN_STATE_CONNECTED) SSL_shutdown(conn->ssl); SSL_free(conn->ssl); conn->ssl = NULL; } connectionTypeTcp()->shutdown(conn_); } static void connTLSClose(connection *conn_) { tls_connection *conn = (tls_connection *) conn_; if (conn->ssl) { if (conn->c.state == CONN_STATE_CONNECTED) SSL_shutdown(conn->ssl); SSL_free(conn->ssl); conn->ssl = NULL; } if (conn->ssl_error) { zfree(conn->ssl_error); conn->ssl_error = NULL; } if (conn->pending_list_node) { listDelNode(pending_list, conn->pending_list_node); conn->pending_list_node = NULL; } connectionTypeTcp()->close(conn_); } static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { tls_connection *conn = (tls_connection *) _conn; int ret; if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR; ERR_clear_error(); /* Try to accept */ conn->c.conn_handler = accept_handler; ret = SSL_accept(conn->ssl); if (ret <= 0) { WantIOType want = 0; if (!handleSSLReturnCode(conn, ret, &want)) { registerSSLEvent(conn, want); /* We'll fire back */ return C_OK; } else { conn->c.state = CONN_STATE_ERROR; return C_ERR; } } conn->c.state = CONN_STATE_CONNECTED; if (!callHandler((connection *) conn, conn->c.conn_handler)) return C_OK; conn->c.conn_handler = NULL; return C_OK; } static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) { tls_connection *conn = (tls_connection *) conn_; unsigned char addr_buf[sizeof(struct in6_addr)]; if (conn->c.state != CONN_STATE_NONE) return C_ERR; ERR_clear_error(); /* Check whether addr is an IP address, if not, use the value for Server Name Indication */ if (inet_pton(AF_INET, addr, addr_buf) != 1 && inet_pton(AF_INET6, addr, addr_buf) != 1) { SSL_set_tlsext_host_name(conn->ssl, addr); } /* Initiate Socket connection first */ if (connectionTypeTcp()->connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR; /* Return now, once the socket is connected we'll initiate * TLS connection from the event handler. */ return C_OK; } static int connTLSWrite(connection *conn_, const void *data, size_t data_len) { tls_connection *conn = (tls_connection *) conn_; int ret; if (conn->c.state != CONN_STATE_CONNECTED) return -1; ERR_clear_error(); ret = SSL_write(conn->ssl, data, data_len); return updateStateAfterSSLIO(conn, ret, 1); } static int connTLSWritev(connection *conn_, const struct iovec *iov, int iovcnt) { if (iovcnt == 1) return connTLSWrite(conn_, iov[0].iov_base, iov[0].iov_len); /* Accumulate the amount of bytes of each buffer and check if it exceeds NET_MAX_WRITES_PER_EVENT. */ size_t iov_bytes_len = 0; for (int i = 0; i < iovcnt; i++) { iov_bytes_len += iov[i].iov_len; if (iov_bytes_len > NET_MAX_WRITES_PER_EVENT) break; } /* The amount of all buffers is greater than NET_MAX_WRITES_PER_EVENT, * which is not worth doing so much memory copying to reduce system calls, * therefore, invoke connTLSWrite() multiple times to avoid memory copies. */ if (iov_bytes_len > NET_MAX_WRITES_PER_EVENT) { ssize_t tot_sent = 0; for (int i = 0; i < iovcnt; i++) { ssize_t sent = connTLSWrite(conn_, iov[i].iov_base, iov[i].iov_len); if (sent <= 0) return tot_sent > 0 ? tot_sent : sent; tot_sent += sent; if ((size_t) sent != iov[i].iov_len) break; } return tot_sent; } /* The amount of all buffers is less than NET_MAX_WRITES_PER_EVENT, * which is worth doing more memory copies in exchange for fewer system calls, * so concatenate these scattered buffers into a contiguous piece of memory * and send it away by one call to connTLSWrite(). */ char buf[iov_bytes_len]; size_t offset = 0; for (int i = 0; i < iovcnt; i++) { memcpy(buf + offset, iov[i].iov_base, iov[i].iov_len); offset += iov[i].iov_len; } return connTLSWrite(conn_, buf, iov_bytes_len); } static int connTLSRead(connection *conn_, void *buf, size_t buf_len) { tls_connection *conn = (tls_connection *) conn_; int ret; if (conn->c.state != CONN_STATE_CONNECTED) return -1; ERR_clear_error(); ret = SSL_read(conn->ssl, buf, buf_len); return updateStateAfterSSLIO(conn, ret, 1); } static const char *connTLSGetLastError(connection *conn_) { tls_connection *conn = (tls_connection *) conn_; if (conn->ssl_error) return conn->ssl_error; return NULL; } static int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { conn->write_handler = func; if (barrier) conn->flags |= CONN_FLAG_WRITE_BARRIER; else conn->flags &= ~CONN_FLAG_WRITE_BARRIER; updateSSLEvent((tls_connection *) conn); return C_OK; } static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { conn->read_handler = func; updateSSLEvent((tls_connection *) conn); return C_OK; } static void setBlockingTimeout(tls_connection *conn, long long timeout) { anetBlock(NULL, conn->c.fd); anetSendTimeout(NULL, conn->c.fd, timeout); anetRecvTimeout(NULL, conn->c.fd, timeout); } static void unsetBlockingTimeout(tls_connection *conn) { anetNonBlock(NULL, conn->c.fd); anetSendTimeout(NULL, conn->c.fd, 0); anetRecvTimeout(NULL, conn->c.fd, 0); } static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, long long timeout) { tls_connection *conn = (tls_connection *) conn_; int ret; if (conn->c.state != CONN_STATE_NONE) return C_ERR; /* Initiate socket blocking connect first */ if (connectionTypeTcp()->blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR; /* Initiate TLS connection now. We set up a send/recv timeout on the socket, * which means the specified timeout will not be enforced accurately. */ SSL_set_fd(conn->ssl, conn->c.fd); setBlockingTimeout(conn, timeout); if ((ret = SSL_connect(conn->ssl)) <= 0) { conn->c.state = CONN_STATE_ERROR; return C_ERR; } unsetBlockingTimeout(conn); conn->c.state = CONN_STATE_CONNECTED; return C_OK; } static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; setBlockingTimeout(conn, timeout); SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); ERR_clear_error(); int ret = SSL_write(conn->ssl, ptr, size); ret = updateStateAfterSSLIO(conn, ret, 0); SSL_set_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); unsetBlockingTimeout(conn); return ret; } static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; setBlockingTimeout(conn, timeout); ERR_clear_error(); int ret = SSL_read(conn->ssl, ptr, size); ret = updateStateAfterSSLIO(conn, ret, 0); unsetBlockingTimeout(conn); return ret; } static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; ssize_t nread = 0; setBlockingTimeout(conn, timeout); size--; while(size) { char c; ERR_clear_error(); int ret = SSL_read(conn->ssl, &c, 1); ret = updateStateAfterSSLIO(conn, ret, 0); if (ret <= 0) { nread = -1; goto exit; } if (c == '\n') { *ptr = '\0'; if (nread && *(ptr-1) == '\r') *(ptr-1) = '\0'; goto exit; } else { *ptr++ = c; *ptr = '\0'; nread++; } size--; } exit: unsetBlockingTimeout(conn); return nread; } static const char *connTLSGetType(connection *conn_) { (void) conn_; return CONN_TYPE_TLS; } static int tlsHasPendingData() { if (!pending_list) return 0; return listLength(pending_list) > 0; } static int tlsProcessPendingData() { listIter li; listNode *ln; int processed = listLength(pending_list); listRewind(pending_list,&li); while((ln = listNext(&li))) { tls_connection *conn = listNodeValue(ln); tlsHandleEvent(conn, AE_READABLE); } return processed; } /* Fetch the peer certificate used for authentication on the specified * connection and return it as a PEM-encoded sds. */ static sds connTLSGetPeerCert(connection *conn_) { tls_connection *conn = (tls_connection *) conn_; if ((conn_->type != connectionTypeTls()) || !conn->ssl) return NULL; X509 *cert = SSL_get_peer_certificate(conn->ssl); if (!cert) return NULL; BIO *bio = BIO_new(BIO_s_mem()); if (bio == NULL || !PEM_write_bio_X509(bio, cert)) { if (bio != NULL) BIO_free(bio); return NULL; } const char *bio_ptr; long long bio_len = BIO_get_mem_data(bio, &bio_ptr); sds cert_pem = sdsnewlen(bio_ptr, bio_len); BIO_free(bio); return cert_pem; } static ConnectionType CT_TLS = { /* connection type */ .get_type = connTLSGetType, /* connection type initialize & finalize & configure */ .init = tlsInit, .cleanup = tlsCleanup, .configure = tlsConfigure, /* ae & accept & listen & error & address handler */ .ae_handler = tlsEventHandler, .accept_handler = tlsAcceptHandler, .addr = connTLSAddr, .is_local = connTLSIsLocal, .listen = connTLSListen, /* create/shutdown/close connection */ .conn_create = connCreateTLS, .conn_create_accepted = connCreateAcceptedTLS, .shutdown = connTLSShutdown, .close = connTLSClose, /* connect & accept */ .connect = connTLSConnect, .blocking_connect = connTLSBlockingConnect, .accept = connTLSAccept, /* IO */ .read = connTLSRead, .write = connTLSWrite, .writev = connTLSWritev, .set_write_handler = connTLSSetWriteHandler, .set_read_handler = connTLSSetReadHandler, .get_last_error = connTLSGetLastError, .sync_write = connTLSSyncWrite, .sync_read = connTLSSyncRead, .sync_readline = connTLSSyncReadLine, /* pending data */ .has_pending_data = tlsHasPendingData, .process_pending_data = tlsProcessPendingData, /* TLS specified methods */ .get_peer_cert = connTLSGetPeerCert, }; int RedisRegisterConnectionTypeTLS() { return connTypeRegister(&CT_TLS); } #else /* USE_OPENSSL */ int RedisRegisterConnectionTypeTLS() { serverLog(LL_VERBOSE, "Connection type %s not builtin", CONN_TYPE_TLS); return C_ERR; } #endif #if BUILD_TLS_MODULE == 2 /* BUILD_MODULE */ #include "release.h" int RedisModule_OnLoad(void *ctx, RedisModuleString **argv, int argc) { UNUSED(argv); UNUSED(argc); /* Connection modules must be part of the same build as redis. */ if (strcmp(REDIS_BUILD_ID_RAW, redisBuildIdRaw())) { serverLog(LL_NOTICE, "Connection type %s was not built together with the redis-server used.", CONN_TYPE_TLS); return REDISMODULE_ERR; } if (RedisModule_Init(ctx,"tls",1,REDISMODULE_APIVER_1) == REDISMODULE_ERR) return REDISMODULE_ERR; /* Connection modules is available only bootup. */ if ((RedisModule_GetContextFlags(ctx) & REDISMODULE_CTX_FLAGS_SERVER_STARTUP) == 0) { serverLog(LL_NOTICE, "Connection type %s can be loaded only during bootup", CONN_TYPE_TLS); return REDISMODULE_ERR; } RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD); if(connTypeRegister(&CT_TLS) != C_OK) return REDISMODULE_ERR; return REDISMODULE_OK; } int RedisModule_OnUnload(void *arg) { UNUSED(arg); serverLog(LL_NOTICE, "Connection type %s can not be unloaded", CONN_TYPE_TLS); return REDISMODULE_ERR; } #endif