summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cluster.c18
-rw-r--r--src/connection.c15
-rw-r--r--src/connection.h18
-rw-r--r--src/networking.c6
-rw-r--r--src/replication.c9
-rw-r--r--src/socket.c11
-rw-r--r--src/tls.c38
7 files changed, 72 insertions, 43 deletions
diff --git a/src/cluster.c b/src/cluster.c
index a7d6f6205..b60b79153 100644
--- a/src/cluster.c
+++ b/src/cluster.c
@@ -119,6 +119,14 @@ dictType clusterNodesBlackListDictType = {
NULL /* allow to expand */
};
+static int connTypeOfCluster() {
+ if (server.tls_cluster) {
+ return CONN_TYPE_TLS;
+ }
+
+ return CONN_TYPE_SOCKET;
+}
+
/* -----------------------------------------------------------------------------
* Initialization
* -------------------------------------------------------------------------- */
@@ -865,6 +873,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd;
int max = MAX_CLUSTER_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
+ int require_auth = TLS_CLIENT_AUTH_YES;
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
@@ -882,8 +891,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return;
}
- connection *conn = server.tls_cluster ?
- connCreateAcceptedTLS(cfd, TLS_CLIENT_AUTH_YES) : connCreateAcceptedSocket(cfd);
+ connection *conn = connCreateAccepted(connTypeOfCluster(), cfd, &require_auth);
/* Make sure connection is not in an error state */
if (connGetState(conn) != CONN_STATE_ACCEPTING) {
@@ -3969,7 +3977,7 @@ static int clusterNodeCronHandleReconnect(clusterNode *node, mstime_t handshake_
if (node->link == NULL) {
clusterLink *link = createClusterLink(node);
- link->conn = server.tls_cluster ? connCreateTLS() : connCreateSocket();
+ link->conn = connCreate(connTypeOfCluster());
connSetPrivateData(link->conn, link);
if (connConnect(link->conn, node->ip, node->cport, server.bind_source_addr,
clusterLinkConnectHandler) == -1) {
@@ -6175,8 +6183,8 @@ migrateCachedSocket* migrateGetSocket(client *c, robj *host, robj *port, long ti
dictDelete(server.migrate_cached_sockets,dictGetKey(de));
}
- /* Create the socket */
- conn = server.tls_cluster ? connCreateTLS() : connCreateSocket();
+ /* Create the connection */
+ conn = connCreate(connTypeOfCluster());
if (connBlockingConnect(conn, host->ptr, atoi(port->ptr), timeout)
!= C_OK) {
addReplyError(c,"-IOERR error or timeout connecting to the client");
diff --git a/src/connection.c b/src/connection.c
index e28257fab..72db82212 100644
--- a/src/connection.c
+++ b/src/connection.c
@@ -152,3 +152,18 @@ void *connTypeGetClientCtx(int type) {
return NULL;
}
+connection *connCreate(int type) {
+ ConnectionType *ct = connectionByType(type);
+
+ serverAssert(ct && ct->conn_create);
+
+ return ct->conn_create();
+}
+
+connection *connCreateAccepted(int type, int fd, void *priv) {
+ ConnectionType *ct = connectionByType(type);
+
+ serverAssert(ct && ct->conn_create_accepted);
+
+ return ct->conn_create_accepted(fd, priv);
+}
diff --git a/src/connection.h b/src/connection.h
index 4cb74c4dc..4fca50fd1 100644
--- a/src/connection.h
+++ b/src/connection.h
@@ -74,6 +74,8 @@ typedef struct ConnectionType {
int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote);
/* create/close connection */
+ connection* (*conn_create)(void);
+ connection* (*conn_create_accepted)(int fd, void *priv);
void (*close)(struct connection *conn);
/* connect & accept */
@@ -290,12 +292,6 @@ static inline int connAddrSockName(connection *conn, char *ip, size_t ip_len, in
return connAddr(conn, ip, ip_len, port, 0);
}
-connection *connCreateSocket();
-connection *connCreateAcceptedSocket(int fd);
-
-connection *connCreateTLS();
-connection *connCreateAcceptedTLS(int fd, int require_auth);
-
static inline int connGetState(connection *conn) {
return conn->state;
}
@@ -358,6 +354,16 @@ int connTypeInitialize();
/* Register a connection type into redis connection framework */
int connTypeRegister(ConnectionType *ct);
+/* Lookup a connection type by index */
+ConnectionType *connectionByType(int type);
+
+/* Create a connection of specified type */
+connection *connCreate(int type);
+
+/* Create a accepted connection of specified type.
+ * @priv is connection type specified argument */
+connection *connCreateAccepted(int type, int fd, void *priv);
+
/* Configure a connection type. A typical case is to configure TLS.
* @priv is connection type specified,
* @reconfigure is boolean type to specify if overwrite the original config */
diff --git a/src/networking.c b/src/networking.c
index 04ba54768..45ff585e2 100644
--- a/src/networking.c
+++ b/src/networking.c
@@ -1344,7 +1344,7 @@ void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
- acceptCommonHandler(connCreateAcceptedSocket(cfd),0,cip);
+ acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),0,cip);
}
}
@@ -1364,7 +1364,7 @@ void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
- acceptCommonHandler(connCreateAcceptedTLS(cfd, server.tls_auth_clients),0,cip);
+ acceptCommonHandler(connCreateAccepted(CONN_TYPE_TLS, cfd, &server.tls_auth_clients),0,cip);
}
}
@@ -1383,7 +1383,7 @@ void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return;
}
serverLog(LL_VERBOSE,"Accepted connection to %s", server.unixsocket);
- acceptCommonHandler(connCreateAcceptedSocket(cfd),CLIENT_UNIX_SOCKET,NULL);
+ acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),CLIENT_UNIX_SOCKET,NULL);
}
}
diff --git a/src/replication.c b/src/replication.c
index 3ae130252..8de14c9f2 100644
--- a/src/replication.c
+++ b/src/replication.c
@@ -55,6 +55,13 @@ int cancelReplicationHandshake(int reconnect);
int RDBGeneratedByReplication = 0;
/* --------------------------- Utility functions ---------------------------- */
+static int connTypeOfReplication() {
+ if (server.tls_replication) {
+ return CONN_TYPE_TLS;
+ }
+
+ return CONN_TYPE_SOCKET;
+}
/* Return the pointer to a string representing the slave ip:listening_port
* pair. Mostly useful for logging, since we want to log a slave using its
@@ -2864,7 +2871,7 @@ write_error: /* Handle sendCommand() errors. */
}
int connectWithMaster(void) {
- server.repl_transfer_s = server.tls_replication ? connCreateTLS() : connCreateSocket();
+ server.repl_transfer_s = connCreate(connTypeOfReplication());
if (connConnect(server.repl_transfer_s, server.masterhost, server.masterport,
server.bind_source_addr, syncWithMaster) == C_ERR) {
serverLog(LL_WARNING,"Unable to connect to MASTER: %s",
diff --git a/src/socket.c b/src/socket.c
index ce293e444..5aea1954e 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -49,7 +49,7 @@
* depending on the implementation (for TCP they are; for TLS they aren't).
*/
-ConnectionType CT_Socket;
+static ConnectionType CT_Socket;
/* When a connection is created we must know its type already, but the
* underlying socket may or may not exist:
@@ -74,7 +74,7 @@ ConnectionType CT_Socket;
* be embedded in different structs, not just client.
*/
-connection *connCreateSocket() {
+static connection *connCreateSocket(void) {
connection *conn = zcalloc(sizeof(connection));
conn->type = &CT_Socket;
conn->fd = -1;
@@ -92,7 +92,8 @@ connection *connCreateSocket() {
* is not in an error state (which is not possible for a socket connection,
* but could but possible with other protocols).
*/
-connection *connCreateAcceptedSocket(int fd) {
+static connection *connCreateAcceptedSocket(int fd, void *priv) {
+ UNUSED(priv);
connection *conn = connCreateSocket();
conn->fd = fd;
conn->state = CONN_STATE_ACCEPTING;
@@ -348,7 +349,7 @@ static int connSocketGetType(connection *conn) {
return CONN_TYPE_SOCKET;
}
-ConnectionType CT_Socket = {
+static ConnectionType CT_Socket = {
/* connection type */
.get_type = connSocketGetType,
@@ -362,6 +363,8 @@ ConnectionType CT_Socket = {
.addr = connSocketAddr,
/* create/close connection */
+ .conn_create = connCreateSocket,
+ .conn_create_accepted = connCreateAcceptedSocket,
.close = connSocketClose,
/* connect & accept */
diff --git a/src/tls.c b/src/tls.c
index 98c5d9d99..39108afed 100644
--- a/src/tls.c
+++ b/src/tls.c
@@ -56,8 +56,6 @@
#define REDIS_TLS_PROTO_DEFAULT (REDIS_TLS_PROTO_TLSv1_2)
#endif
-extern ConnectionType CT_Socket;
-
static SSL_CTX *redis_tls_ctx = NULL;
static SSL_CTX *redis_tls_client_ctx = NULL;
@@ -421,7 +419,7 @@ error:
#define TLSCONN_DEBUG(fmt, ...)
#endif
-ConnectionType CT_TLS;
+static ConnectionType CT_TLS;
/* Normal socket connections have a simple events/handler correlation.
*
@@ -466,7 +464,7 @@ static connection *createTLSConnection(int client_side) {
return (connection *) conn;
}
-connection *connCreateTLS(void) {
+static connection *connCreateTLS(void) {
return createTLSConnection(1);
}
@@ -487,7 +485,8 @@ static void updateTLSError(tls_connection *conn) {
* Callers should use connGetState() and verify the created connection
* is not in an error state.
*/
-connection *connCreateAcceptedTLS(int fd, int require_auth) {
+static connection *connCreateAcceptedTLS(int fd, void *priv) {
+ int require_auth = *(int *)priv;
tls_connection *conn = (tls_connection *) createTLSConnection(0);
conn->c.fd = fd;
conn->c.state = CONN_STATE_ACCEPTING;
@@ -550,7 +549,7 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *
return 0;
}
-void registerSSLEvent(tls_connection *conn, WantIOType want) {
+static void registerSSLEvent(tls_connection *conn, WantIOType want) {
int mask = aeGetFileEvents(server.el, conn->c.fd);
switch (want) {
@@ -570,7 +569,7 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) {
}
}
-void updateSSLEvent(tls_connection *conn) {
+static void updateSSLEvent(tls_connection *conn) {
int mask = aeGetFileEvents(server.el, conn->c.fd);
int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ);
int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE);
@@ -744,7 +743,7 @@ static void connTLSClose(connection *conn_) {
conn->pending_list_node = NULL;
}
- CT_Socket.close(conn_);
+ connectionByType(CONN_TYPE_SOCKET)->close(conn_);
}
static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
@@ -783,7 +782,7 @@ static int connTLSConnect(connection *conn_, const char *addr, int port, const c
ERR_clear_error();
/* Initiate Socket connection first */
- if (CT_Socket.connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR;
+ if (connectionByType(CONN_TYPE_SOCKET)->connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR;
/* Return now, once the socket is connected we'll initiate
* TLS connection from the event handler.
@@ -911,7 +910,7 @@ static const char *connTLSGetLastError(connection *conn_) {
return NULL;
}
-int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) {
+static int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) {
conn->write_handler = func;
if (barrier)
conn->flags |= CONN_FLAG_WRITE_BARRIER;
@@ -921,7 +920,7 @@ int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int ba
return C_OK;
}
-int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) {
+static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) {
conn->read_handler = func;
updateSSLEvent((tls_connection *) conn);
return C_OK;
@@ -946,7 +945,7 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
if (conn->c.state != CONN_STATE_NONE) return C_ERR;
/* Initiate socket blocking connect first */
- if (CT_Socket.blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR;
+ if (connectionByType(CONN_TYPE_SOCKET)->blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR;
/* Initiate TLS connection now. We set up a send/recv timeout on the socket,
* which means the specified timeout will not be enforced accurately. */
@@ -1072,7 +1071,7 @@ static void *tlsGetClientCtx(void) {
return redis_tls_client_ctx;
}
-ConnectionType CT_TLS = {
+static ConnectionType CT_TLS = {
/* connection type */
.get_type = connTLSGetType,
@@ -1086,6 +1085,8 @@ ConnectionType CT_TLS = {
.addr = connTLSAddr,
/* create/close connection */
+ .conn_create = connCreateTLS,
+ .conn_create_accepted = connCreateAcceptedTLS,
.close = connTLSClose,
/* connect & accept */
@@ -1126,15 +1127,4 @@ int RedisRegisterConnectionTypeTLS()
return C_ERR;
}
-connection *connCreateTLS(void) {
- return NULL;
-}
-
-connection *connCreateAcceptedTLS(int fd, int require_auth) {
- UNUSED(fd);
- UNUSED(require_auth);
-
- return NULL;
-}
-
#endif