diff options
-rw-r--r-- | src/cluster.c | 18 | ||||
-rw-r--r-- | src/connection.c | 15 | ||||
-rw-r--r-- | src/connection.h | 18 | ||||
-rw-r--r-- | src/networking.c | 6 | ||||
-rw-r--r-- | src/replication.c | 9 | ||||
-rw-r--r-- | src/socket.c | 11 | ||||
-rw-r--r-- | src/tls.c | 38 |
7 files changed, 72 insertions, 43 deletions
diff --git a/src/cluster.c b/src/cluster.c index a7d6f6205..b60b79153 100644 --- a/src/cluster.c +++ b/src/cluster.c @@ -119,6 +119,14 @@ dictType clusterNodesBlackListDictType = { NULL /* allow to expand */ }; +static int connTypeOfCluster() { + if (server.tls_cluster) { + return CONN_TYPE_TLS; + } + + return CONN_TYPE_SOCKET; +} + /* ----------------------------------------------------------------------------- * Initialization * -------------------------------------------------------------------------- */ @@ -865,6 +873,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) { int cport, cfd; int max = MAX_CLUSTER_ACCEPTS_PER_CALL; char cip[NET_IP_STR_LEN]; + int require_auth = TLS_CLIENT_AUTH_YES; UNUSED(el); UNUSED(mask); UNUSED(privdata); @@ -882,8 +891,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) { return; } - connection *conn = server.tls_cluster ? - connCreateAcceptedTLS(cfd, TLS_CLIENT_AUTH_YES) : connCreateAcceptedSocket(cfd); + connection *conn = connCreateAccepted(connTypeOfCluster(), cfd, &require_auth); /* Make sure connection is not in an error state */ if (connGetState(conn) != CONN_STATE_ACCEPTING) { @@ -3969,7 +3977,7 @@ static int clusterNodeCronHandleReconnect(clusterNode *node, mstime_t handshake_ if (node->link == NULL) { clusterLink *link = createClusterLink(node); - link->conn = server.tls_cluster ? connCreateTLS() : connCreateSocket(); + link->conn = connCreate(connTypeOfCluster()); connSetPrivateData(link->conn, link); if (connConnect(link->conn, node->ip, node->cport, server.bind_source_addr, clusterLinkConnectHandler) == -1) { @@ -6175,8 +6183,8 @@ migrateCachedSocket* migrateGetSocket(client *c, robj *host, robj *port, long ti dictDelete(server.migrate_cached_sockets,dictGetKey(de)); } - /* Create the socket */ - conn = server.tls_cluster ? connCreateTLS() : connCreateSocket(); + /* Create the connection */ + conn = connCreate(connTypeOfCluster()); if (connBlockingConnect(conn, host->ptr, atoi(port->ptr), timeout) != C_OK) { addReplyError(c,"-IOERR error or timeout connecting to the client"); diff --git a/src/connection.c b/src/connection.c index e28257fab..72db82212 100644 --- a/src/connection.c +++ b/src/connection.c @@ -152,3 +152,18 @@ void *connTypeGetClientCtx(int type) { return NULL; } +connection *connCreate(int type) { + ConnectionType *ct = connectionByType(type); + + serverAssert(ct && ct->conn_create); + + return ct->conn_create(); +} + +connection *connCreateAccepted(int type, int fd, void *priv) { + ConnectionType *ct = connectionByType(type); + + serverAssert(ct && ct->conn_create_accepted); + + return ct->conn_create_accepted(fd, priv); +} diff --git a/src/connection.h b/src/connection.h index 4cb74c4dc..4fca50fd1 100644 --- a/src/connection.h +++ b/src/connection.h @@ -74,6 +74,8 @@ typedef struct ConnectionType { int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote); /* create/close connection */ + connection* (*conn_create)(void); + connection* (*conn_create_accepted)(int fd, void *priv); void (*close)(struct connection *conn); /* connect & accept */ @@ -290,12 +292,6 @@ static inline int connAddrSockName(connection *conn, char *ip, size_t ip_len, in return connAddr(conn, ip, ip_len, port, 0); } -connection *connCreateSocket(); -connection *connCreateAcceptedSocket(int fd); - -connection *connCreateTLS(); -connection *connCreateAcceptedTLS(int fd, int require_auth); - static inline int connGetState(connection *conn) { return conn->state; } @@ -358,6 +354,16 @@ int connTypeInitialize(); /* Register a connection type into redis connection framework */ int connTypeRegister(ConnectionType *ct); +/* Lookup a connection type by index */ +ConnectionType *connectionByType(int type); + +/* Create a connection of specified type */ +connection *connCreate(int type); + +/* Create a accepted connection of specified type. + * @priv is connection type specified argument */ +connection *connCreateAccepted(int type, int fd, void *priv); + /* Configure a connection type. A typical case is to configure TLS. * @priv is connection type specified, * @reconfigure is boolean type to specify if overwrite the original config */ diff --git a/src/networking.c b/src/networking.c index 04ba54768..45ff585e2 100644 --- a/src/networking.c +++ b/src/networking.c @@ -1344,7 +1344,7 @@ void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask) { return; } serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); - acceptCommonHandler(connCreateAcceptedSocket(cfd),0,cip); + acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),0,cip); } } @@ -1364,7 +1364,7 @@ void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask) { return; } serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); - acceptCommonHandler(connCreateAcceptedTLS(cfd, server.tls_auth_clients),0,cip); + acceptCommonHandler(connCreateAccepted(CONN_TYPE_TLS, cfd, &server.tls_auth_clients),0,cip); } } @@ -1383,7 +1383,7 @@ void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) { return; } serverLog(LL_VERBOSE,"Accepted connection to %s", server.unixsocket); - acceptCommonHandler(connCreateAcceptedSocket(cfd),CLIENT_UNIX_SOCKET,NULL); + acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),CLIENT_UNIX_SOCKET,NULL); } } diff --git a/src/replication.c b/src/replication.c index 3ae130252..8de14c9f2 100644 --- a/src/replication.c +++ b/src/replication.c @@ -55,6 +55,13 @@ int cancelReplicationHandshake(int reconnect); int RDBGeneratedByReplication = 0; /* --------------------------- Utility functions ---------------------------- */ +static int connTypeOfReplication() { + if (server.tls_replication) { + return CONN_TYPE_TLS; + } + + return CONN_TYPE_SOCKET; +} /* Return the pointer to a string representing the slave ip:listening_port * pair. Mostly useful for logging, since we want to log a slave using its @@ -2864,7 +2871,7 @@ write_error: /* Handle sendCommand() errors. */ } int connectWithMaster(void) { - server.repl_transfer_s = server.tls_replication ? connCreateTLS() : connCreateSocket(); + server.repl_transfer_s = connCreate(connTypeOfReplication()); if (connConnect(server.repl_transfer_s, server.masterhost, server.masterport, server.bind_source_addr, syncWithMaster) == C_ERR) { serverLog(LL_WARNING,"Unable to connect to MASTER: %s", diff --git a/src/socket.c b/src/socket.c index ce293e444..5aea1954e 100644 --- a/src/socket.c +++ b/src/socket.c @@ -49,7 +49,7 @@ * depending on the implementation (for TCP they are; for TLS they aren't). */ -ConnectionType CT_Socket; +static ConnectionType CT_Socket; /* When a connection is created we must know its type already, but the * underlying socket may or may not exist: @@ -74,7 +74,7 @@ ConnectionType CT_Socket; * be embedded in different structs, not just client. */ -connection *connCreateSocket() { +static connection *connCreateSocket(void) { connection *conn = zcalloc(sizeof(connection)); conn->type = &CT_Socket; conn->fd = -1; @@ -92,7 +92,8 @@ connection *connCreateSocket() { * is not in an error state (which is not possible for a socket connection, * but could but possible with other protocols). */ -connection *connCreateAcceptedSocket(int fd) { +static connection *connCreateAcceptedSocket(int fd, void *priv) { + UNUSED(priv); connection *conn = connCreateSocket(); conn->fd = fd; conn->state = CONN_STATE_ACCEPTING; @@ -348,7 +349,7 @@ static int connSocketGetType(connection *conn) { return CONN_TYPE_SOCKET; } -ConnectionType CT_Socket = { +static ConnectionType CT_Socket = { /* connection type */ .get_type = connSocketGetType, @@ -362,6 +363,8 @@ ConnectionType CT_Socket = { .addr = connSocketAddr, /* create/close connection */ + .conn_create = connCreateSocket, + .conn_create_accepted = connCreateAcceptedSocket, .close = connSocketClose, /* connect & accept */ @@ -56,8 +56,6 @@ #define REDIS_TLS_PROTO_DEFAULT (REDIS_TLS_PROTO_TLSv1_2) #endif -extern ConnectionType CT_Socket; - static SSL_CTX *redis_tls_ctx = NULL; static SSL_CTX *redis_tls_client_ctx = NULL; @@ -421,7 +419,7 @@ error: #define TLSCONN_DEBUG(fmt, ...) #endif -ConnectionType CT_TLS; +static ConnectionType CT_TLS; /* Normal socket connections have a simple events/handler correlation. * @@ -466,7 +464,7 @@ static connection *createTLSConnection(int client_side) { return (connection *) conn; } -connection *connCreateTLS(void) { +static connection *connCreateTLS(void) { return createTLSConnection(1); } @@ -487,7 +485,8 @@ static void updateTLSError(tls_connection *conn) { * Callers should use connGetState() and verify the created connection * is not in an error state. */ -connection *connCreateAcceptedTLS(int fd, int require_auth) { +static connection *connCreateAcceptedTLS(int fd, void *priv) { + int require_auth = *(int *)priv; tls_connection *conn = (tls_connection *) createTLSConnection(0); conn->c.fd = fd; conn->c.state = CONN_STATE_ACCEPTING; @@ -550,7 +549,7 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType * return 0; } -void registerSSLEvent(tls_connection *conn, WantIOType want) { +static void registerSSLEvent(tls_connection *conn, WantIOType want) { int mask = aeGetFileEvents(server.el, conn->c.fd); switch (want) { @@ -570,7 +569,7 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) { } } -void updateSSLEvent(tls_connection *conn) { +static void updateSSLEvent(tls_connection *conn) { int mask = aeGetFileEvents(server.el, conn->c.fd); int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); @@ -744,7 +743,7 @@ static void connTLSClose(connection *conn_) { conn->pending_list_node = NULL; } - CT_Socket.close(conn_); + connectionByType(CONN_TYPE_SOCKET)->close(conn_); } static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { @@ -783,7 +782,7 @@ static int connTLSConnect(connection *conn_, const char *addr, int port, const c ERR_clear_error(); /* Initiate Socket connection first */ - if (CT_Socket.connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR; + if (connectionByType(CONN_TYPE_SOCKET)->connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR; /* Return now, once the socket is connected we'll initiate * TLS connection from the event handler. @@ -911,7 +910,7 @@ static const char *connTLSGetLastError(connection *conn_) { return NULL; } -int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { +static int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { conn->write_handler = func; if (barrier) conn->flags |= CONN_FLAG_WRITE_BARRIER; @@ -921,7 +920,7 @@ int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int ba return C_OK; } -int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { +static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { conn->read_handler = func; updateSSLEvent((tls_connection *) conn); return C_OK; @@ -946,7 +945,7 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, if (conn->c.state != CONN_STATE_NONE) return C_ERR; /* Initiate socket blocking connect first */ - if (CT_Socket.blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR; + if (connectionByType(CONN_TYPE_SOCKET)->blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR; /* Initiate TLS connection now. We set up a send/recv timeout on the socket, * which means the specified timeout will not be enforced accurately. */ @@ -1072,7 +1071,7 @@ static void *tlsGetClientCtx(void) { return redis_tls_client_ctx; } -ConnectionType CT_TLS = { +static ConnectionType CT_TLS = { /* connection type */ .get_type = connTLSGetType, @@ -1086,6 +1085,8 @@ ConnectionType CT_TLS = { .addr = connTLSAddr, /* create/close connection */ + .conn_create = connCreateTLS, + .conn_create_accepted = connCreateAcceptedTLS, .close = connTLSClose, /* connect & accept */ @@ -1126,15 +1127,4 @@ int RedisRegisterConnectionTypeTLS() return C_ERR; } -connection *connCreateTLS(void) { - return NULL; -} - -connection *connCreateAcceptedTLS(int fd, int require_auth) { - UNUSED(fd); - UNUSED(require_auth); - - return NULL; -} - #endif |