author     Harkrishn Patro <30795839+hpatro@users.noreply.github.com>  2022-01-03 01:54:47 +0100
committer  GitHub <noreply@github.com>  2022-01-02 16:54:47 -0800
commit     9f8885760b53e6d3952b9c9b41f9e6c48dfa6cec (patch)
tree       770dfdbff19a1a2a1c71a642ebd844d592ef3d26 /src
parent     b8ba942ac2aabf51fd96134d9fa21b47d3baff4a (diff)
download   redis-9f8885760b53e6d3952b9c9b41f9e6c48dfa6cec.tar.gz
Sharded pubsub implementation (#8621)
This commit implements sharded pubsub, based on shard channels.
Co-authored-by: Harkrishn Patro <harkrisp@amazon.com>
Co-authored-by: Madelyn Olson <madelyneolson@gmail.com>
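
For context (editorial note, not part of the commit message): sharded pubsub mirrors the classic SUBSCRIBE/PUBLISH interface, but a shard channel is hashed to a cluster slot like a key, so a message only propagates within the shard (primary plus replicas) serving that slot instead of across the whole cluster bus. A minimal session against a hypothetical cluster node on port 7000, with made-up channel names:

    127.0.0.1:7000> SSUBSCRIBE orders:{user1}
    Reading messages... (press Ctrl-C to quit)
    1) "ssubscribe"
    2) "orders:{user1}"
    3) (integer) 1

    127.0.0.1:7000> SPUBLISH orders:{user1} hello      (from a second client on the same shard)
    (integer) 1

    127.0.0.1:7000> PUBSUB SHARDCHANNELS
    1) "orders:{user1}"

As implemented below, SPUBLISH replies with the receiver count observed on the node that served the command, and PUBSUB SHARDNUMSUB reports per-channel subscriber counts, analogous to the unsharded variants.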
Diffstat (limited to 'src')
 -rw-r--r--  src/acl.c                              |  20
 -rw-r--r--  src/cluster.c                          | 169
 -rw-r--r--  src/cluster.h                          |   5
 -rw-r--r--  src/commands.c                         |  70
 -rw-r--r--  src/commands/pubsub-shardchannels.json |  23
 -rw-r--r--  src/commands/pubsub-shardnumsub.json   |  16
 -rw-r--r--  src/commands/spublish.json             |  46
 -rw-r--r--  src/commands/ssubscribe.json           |  42
 -rw-r--r--  src/commands/sunsubscribe.json         |  43
 -rw-r--r--  src/config.c                           |   1
 -rw-r--r--  src/module.c                           |   1
 -rw-r--r--  src/networking.c                       |   4
 -rw-r--r--  src/pubsub.c                           | 350
 -rw-r--r--  src/redis-cli.c                        |   3
 -rw-r--r--  src/server.c                           |  10
 -rw-r--r--  src/server.h                           |  23
 -rw-r--r--  src/tracking.c                         |   6
17 files changed, 756 insertions, 76 deletions
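
The heart of the cluster-side bookkeeping is the new slots_to_channels radix tree, initialized in clusterInit() and maintained by slotToChannelUpdate() in the diff below: every shard channel is indexed under a two-byte big-endian slot prefix. A standalone sketch of that key layout (editorial, not the Redis code; the rax API itself is omitted):

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    /* Sketch of the key layout used for slots_to_channels:
     *
     *     [slot >> 8][slot & 0xff][channel bytes...]
     *
     * The big-endian two-byte prefix sorts numerically, so every
     * channel of one slot is contiguous in the radix tree and a
     * range scan from the prefix enumerates exactly that slot. */
    static unsigned char *encode_slot_channel(uint16_t slot, const char *channel,
                                              size_t len, size_t *outlen) {
        unsigned char *key = malloc(len + 2);
        if (key == NULL) return NULL;
        key[0] = (slot >> 8) & 0xff;  /* high byte first */
        key[1] = slot & 0xff;
        memcpy(key + 2, channel, len);
        *outlen = len + 2;
        return key;
    }

    int main(void) {
        size_t klen;
        /* slot 866 and channel "news" are arbitrary example inputs */
        unsigned char *k = encode_slot_channel(866, "news", 4, &klen);
        printf("key length: %zu, prefix: %02x %02x\n", klen, k[0], k[1]);
        free(k);
        return 0;
    }

This is why countChannelsInSlot() and removeChannelsInSlot() below can raxSeek() to the two prefix bytes and stop at the first key whose prefix differs.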
@@ -1307,8 +1307,11 @@ int ACLCheckCommandPerm(const user *u, struct redisCommand *cmd, robj **argv, in } /* Check if the user can execute commands explicitly touching the keys - * mentioned in the command arguments. */ + * mentioned in the command arguments. Shard channels are treated as + * special keys for client library to rely on `COMMAND` command + * to discover the node to connect to. These don't need acl key check. */ if (!(u->flags & USER_FLAG_ALLKEYS) && + !(cmd->flags & CMD_PUBSUB) && (cmd->getkeys_proc || cmd->key_specs_num)) { getKeysResult result = GETKEYS_RESULT_INIT; @@ -1392,6 +1395,7 @@ void ACLKillPubsubClientsIfNeeded(user *u, list *upcoming) { } /* Check for channel violations. */ if (!kill) { + /* Check for global channels violation. */ dictIterator *di = dictGetIterator(c->pubsub_channels); dictEntry *de; while (!kill && ((de = dictNext(di)) != NULL)) { @@ -1400,6 +1404,16 @@ void ACLKillPubsubClientsIfNeeded(user *u, list *upcoming) { ACL_DENIED_CHANNEL); } dictReleaseIterator(di); + + /* Check for shard channels violation. */ + di = dictGetIterator(c->pubsubshard_channels); + while (!kill && ((de = dictNext(di)) != NULL)) { + o = dictGetKey(de); + kill = (ACLCheckPubsubChannelPerm(o->ptr,upcoming,0) == + ACL_DENIED_CHANNEL); + } + + dictReleaseIterator(di); } /* Kill it. */ @@ -1448,9 +1462,9 @@ int ACLCheckAllUserCommandPerm(const user *u, struct redisCommand *cmd, robj **a int acl_retval = ACLCheckCommandPerm(u,cmd,argv,argc,idxptr); if (acl_retval != ACL_OK) return acl_retval; - if (cmd->proc == publishCommand) + if (cmd->proc == publishCommand || cmd->proc == spublishCommand) acl_retval = ACLCheckPubsubPerm(u,argv,1,1,0,idxptr); - else if (cmd->proc == subscribeCommand) + else if (cmd->proc == subscribeCommand || cmd->proc == ssubscribeCommand) acl_retval = ACLCheckPubsubPerm(u,argv,1,argc-1,0,idxptr); else if (cmd->proc == psubscribeCommand) acl_retval = ACLCheckPubsubPerm(u,argv,1,argc-1,1,idxptr); diff --git a/src/cluster.c b/src/cluster.c index 78e273f34..81322a8aa 100644 --- a/src/cluster.c +++ b/src/cluster.c @@ -57,6 +57,7 @@ void clusterUpdateState(void); int clusterNodeGetSlotBit(clusterNode *n, int slot); sds clusterGenNodesDescription(int filter, int use_pport); clusterNode *clusterLookupNode(const char *name); +list *clusterGetNodesServingMySlots(clusterNode *node); int clusterNodeAddSlave(clusterNode *master, clusterNode *slave); int clusterAddSlot(clusterNode *n, int slot); int clusterDelSlot(int slot); @@ -77,7 +78,9 @@ uint64_t clusterGetMaxEpoch(void); int clusterBumpConfigEpochWithoutConsensus(void); void moduleCallClusterReceivers(const char *sender_id, uint64_t module_id, uint8_t type, const unsigned char *payload, uint32_t len); const char *clusterGetMessageTypeString(int type); +void removeChannelsInSlot(unsigned int slot); unsigned int countKeysInSlot(unsigned int hashslot); +unsigned int countChannelsInSlot(unsigned int hashslot); unsigned int delKeysInSlot(unsigned int hashslot); /* Links to the next and previous entries for keys in the same slot are stored @@ -631,6 +634,9 @@ void clusterInit(void) { /* Initialize data for the Slot to key API. */ slotToKeyInit(server.db); + /* The slots -> channels map is a radix tree. Initialize it here. */ + server.cluster->slots_to_channels = raxNew(); + /* Set myself->port/cport/pport to my listening ports, we'll just need to * discover the IP address via MEET messages. 
*/ deriveAnnouncedPorts(&myself->port, &myself->pport, &myself->cport); @@ -1146,6 +1152,17 @@ clusterNode *clusterLookupNode(const char *name) { return dictGetVal(de); } +/* Get all the nodes serving the same slots as myself. */ +list *clusterGetNodesServingMySlots(clusterNode *node) { + list *nodes_for_slot = listCreate(); + clusterNode *my_primary = nodeIsMaster(node) ? node : node->slaveof; + listAddNodeTail(nodes_for_slot, my_primary); + for (int i=0; i < my_primary->numslaves; i++) { + listAddNodeTail(nodes_for_slot, my_primary->slaves[i]); + } + return nodes_for_slot; +} + /* This is only used after the handshake. When we connect a given IP/PORT * as a result of CLUSTER MEET we don't have the node name yet, so we * pick a random one, and will fix it when we receive the PONG request using @@ -1921,7 +1938,7 @@ int clusterProcessPacket(clusterLink *link) { explen += sizeof(clusterMsgDataFail); if (totlen != explen) return 1; - } else if (type == CLUSTERMSG_TYPE_PUBLISH) { + } else if (type == CLUSTERMSG_TYPE_PUBLISH || type == CLUSTERMSG_TYPE_PUBLISHSHARD) { uint32_t explen = sizeof(clusterMsg)-sizeof(union clusterMsgData); explen += sizeof(clusterMsgDataPublish) - @@ -2278,7 +2295,7 @@ int clusterProcessPacket(clusterLink *link) { "Ignoring FAIL message from unknown node %.40s about %.40s", hdr->sender, hdr->data.fail.about.nodename); } - } else if (type == CLUSTERMSG_TYPE_PUBLISH) { + } else if (type == CLUSTERMSG_TYPE_PUBLISH || type == CLUSTERMSG_TYPE_PUBLISHSHARD) { if (!sender) return 1; /* We don't know that node. */ robj *channel, *message; @@ -2286,8 +2303,10 @@ int clusterProcessPacket(clusterLink *link) { /* Don't bother creating useless objects if there are no * Pub/Sub subscribers. */ - if (dictSize(server.pubsub_channels) || - dictSize(server.pubsub_patterns)) + if ((type == CLUSTERMSG_TYPE_PUBLISH + && serverPubsubSubscriptionCount() > 0) + || (type == CLUSTERMSG_TYPE_PUBLISHSHARD + && serverPubsubShardSubscriptionCount() > 0)) { channel_len = ntohl(hdr->data.publish.msg.channel_len); message_len = ntohl(hdr->data.publish.msg.message_len); @@ -2296,7 +2315,11 @@ int clusterProcessPacket(clusterLink *link) { message = createStringObject( (char*)hdr->data.publish.msg.bulk_data+channel_len, message_len); - pubsubPublishMessage(channel,message); + if (type == CLUSTERMSG_TYPE_PUBLISHSHARD) { + pubsubPublishMessageShard(channel, message); + } else { + pubsubPublishMessage(channel,message); + } decrRefCount(channel); decrRefCount(message); } @@ -2841,7 +2864,7 @@ void clusterBroadcastPong(int target) { * the 'bulk_data', sanitizer generates an out-of-bounds error which is a false * positive in this context. */ REDIS_NO_SANITIZE("bounds") -void clusterSendPublish(clusterLink *link, robj *channel, robj *message) { +void clusterSendPublish(clusterLink *link, robj *channel, robj *message, uint16_t type) { unsigned char *payload; clusterMsg buf[1]; clusterMsg *hdr = (clusterMsg*) buf; @@ -2853,7 +2876,7 @@ void clusterSendPublish(clusterLink *link, robj *channel, robj *message) { channel_len = sdslen(channel->ptr); message_len = sdslen(message->ptr); - clusterBuildMessageHdr(hdr,CLUSTERMSG_TYPE_PUBLISH); + clusterBuildMessageHdr(hdr,type); totlen = sizeof(clusterMsg)-sizeof(union clusterMsgData); totlen += sizeof(clusterMsgDataPublish) - 8 + channel_len + message_len; @@ -2976,7 +2999,28 @@ int clusterSendModuleMessageToTarget(const char *target, uint64_t module_id, uin * messages to hosts without receives for a given channel. 
* -------------------------------------------------------------------------- */ void clusterPropagatePublish(robj *channel, robj *message) { - clusterSendPublish(NULL, channel, message); + clusterSendPublish(NULL, channel, message, CLUSTERMSG_TYPE_PUBLISH); +} + +/* ----------------------------------------------------------------------------- + * CLUSTER Pub/Sub shard support + * + * Publish this message across the slot (primary/replica). + * -------------------------------------------------------------------------- */ +void clusterPropagatePublishShard(robj *channel, robj *message) { + list *nodes_for_slot = clusterGetNodesServingMySlots(server.cluster->myself); + if (listLength(nodes_for_slot) != 0) { + listIter li; + listNode *ln; + listRewind(nodes_for_slot, &li); + while((ln = listNext(&li))) { + clusterNode *node = listNodeValue(ln); + if (node != myself) { + clusterSendPublish(node->link, channel, message, CLUSTERMSG_TYPE_PUBLISHSHARD); + } + } + } + listRelease(nodes_for_slot); } /* ----------------------------------------------------------------------------- @@ -4075,6 +4119,14 @@ int clusterDelSlot(int slot) { clusterNode *n = server.cluster->slots[slot]; if (!n) return C_ERR; + + /* Cleanup the channels in master/replica as part of slot deletion. */ + list *nodes_for_slot = clusterGetNodesServingMySlots(n); + listNode *ln = listSearchKey(nodes_for_slot, myself); + if (ln != NULL) { + removeChannelsInSlot(slot); + } + listRelease(nodes_for_slot); serverAssert(clusterNodeClearSlotBit(n,slot) == 1); server.cluster->slots[slot] = NULL; return C_OK; @@ -4574,6 +4626,7 @@ const char *clusterGetMessageTypeString(int type) { case CLUSTERMSG_TYPE_MEET: return "meet"; case CLUSTERMSG_TYPE_FAIL: return "fail"; case CLUSTERMSG_TYPE_PUBLISH: return "publish"; + case CLUSTERMSG_TYPE_PUBLISHSHARD: return "publishshard"; case CLUSTERMSG_TYPE_FAILOVER_AUTH_REQUEST: return "auth-req"; case CLUSTERMSG_TYPE_FAILOVER_AUTH_ACK: return "auth-ack"; case CLUSTERMSG_TYPE_UPDATE: return "update"; @@ -5362,6 +5415,30 @@ NULL } } +void removeChannelsInSlot(unsigned int slot) { + unsigned int channelcount = countChannelsInSlot(slot); + if (channelcount == 0) return; + + /* Retrieve all the channels for the slot. */ + robj **channels = zmalloc(sizeof(robj*)*channelcount); + raxIterator iter; + int j = 0; + unsigned char indexed[2]; + + indexed[0] = (slot >> 8) & 0xff; + indexed[1] = slot & 0xff; + raxStart(&iter,server.cluster->slots_to_channels); + raxSeek(&iter,">=",indexed,2); + while(raxNext(&iter)) { + if (iter.key[0] != indexed[0] || iter.key[1] != indexed[1]) break; + channels[j++] = createStringObject((char*)iter.key + 2, iter.key_len - 2); + } + raxStop(&iter); + + pubsubUnsubscribeShardChannels(channels, channelcount); + zfree(channels); +} + /* ----------------------------------------------------------------------------- * DUMP, RESTORE and MIGRATE commands * -------------------------------------------------------------------------- */ @@ -6121,6 +6198,10 @@ clusterNode *getNodeByQuery(client *c, struct redisCommand *cmd, robj **argv, in mc.cmd = cmd; } + int is_pubsubshard = cmd->proc == ssubscribeCommand || + cmd->proc == sunsubscribeCommand || + cmd->proc == spublishCommand; + /* Check that all the keys are in the same hash slot, and obtain this * slot and the node associated. 
*/ for (i = 0; i < ms->count; i++) { @@ -6172,8 +6253,8 @@ clusterNode *getNodeByQuery(client *c, struct redisCommand *cmd, robj **argv, in importing_slot = 1; } } else { - /* If it is not the first key, make sure it is exactly - * the same key as the first we saw. */ + /* If it is not the first key/channel, make sure it is exactly + * the same key/channel as the first we saw. */ if (!equalStringObjects(firstkey,thiskey)) { if (slot != thisslot) { /* Error: multiple keys from different slots. */ @@ -6183,15 +6264,20 @@ clusterNode *getNodeByQuery(client *c, struct redisCommand *cmd, robj **argv, in return NULL; } else { /* Flag this request as one with multiple different - * keys. */ + * keys/channels. */ multiple_keys = 1; } } } - /* Migrating / Importing slot? Count keys we don't have. */ + /* Migrating / Importing slot? Count keys we don't have. + * If it is pubsubshard command, it isn't required to check + * the channel being present or not in the node during the + * slot migration, the channel will be served from the source + * node until the migration completes with CLUSTER SETSLOT <slot> + * NODE <node-id>. */ int flags = LOOKUP_NOTOUCH | LOOKUP_NOSTATS | LOOKUP_NONOTIFY; - if ((migrating_slot || importing_slot) && + if ((migrating_slot || importing_slot) && !is_pubsubshard && lookupKeyReadWithFlags(&server.db[0], thiskey, flags) == NULL) { missing_keys++; @@ -6207,7 +6293,12 @@ clusterNode *getNodeByQuery(client *c, struct redisCommand *cmd, robj **argv, in /* Cluster is globally down but we got keys? We only serve the request * if it is a read command and when allow_reads_when_down is enabled. */ if (server.cluster->state != CLUSTER_OK) { - if (!server.cluster_allow_reads_when_down) { + if (is_pubsubshard) { + if (!server.cluster_allow_pubsubshard_when_down) { + if (error_code) *error_code = CLUSTER_REDIR_DOWN_STATE; + return NULL; + } + } else if (!server.cluster_allow_reads_when_down) { /* The cluster is configured to block commands when the * cluster is down. */ if (error_code) *error_code = CLUSTER_REDIR_DOWN_STATE; @@ -6259,7 +6350,7 @@ clusterNode *getNodeByQuery(client *c, struct redisCommand *cmd, robj **argv, in * is serving, we can reply without redirection. */ int is_write_command = (c->cmd->flags & CMD_WRITE) || (c->cmd->proc == execCommand && (c->mstate.cmd_flags & CMD_WRITE)); - if (c->flags & CLIENT_READONLY && + if (((c->flags & CLIENT_READONLY) || is_pubsubshard) && !is_write_command && nodeIsSlave(myself) && myself->slaveof == n) @@ -6482,3 +6573,51 @@ unsigned int delKeysInSlot(unsigned int hashslot) { unsigned int countKeysInSlot(unsigned int hashslot) { return (*server.db->slots_to_keys).by_slot[hashslot].count; } + +/* ----------------------------------------------------------------------------- + * Operation(s) on channel rax tree. 
+ * -------------------------------------------------------------------------- */ + +void slotToChannelUpdate(sds channel, int add) { + size_t keylen = sdslen(channel); + unsigned int hashslot = keyHashSlot(channel,keylen); + unsigned char buf[64]; + unsigned char *indexed = buf; + + if (keylen+2 > 64) indexed = zmalloc(keylen+2); + indexed[0] = (hashslot >> 8) & 0xff; + indexed[1] = hashslot & 0xff; + memcpy(indexed+2,channel,keylen); + if (add) { + raxInsert(server.cluster->slots_to_channels,indexed,keylen+2,NULL,NULL); + } else { + raxRemove(server.cluster->slots_to_channels,indexed,keylen+2,NULL); + } + if (indexed != buf) zfree(indexed); +} + +void slotToChannelAdd(sds channel) { + slotToChannelUpdate(channel,1); +} + +void slotToChannelDel(sds channel) { + slotToChannelUpdate(channel,0); +} + +/* Get the count of the channels for a given slot. */ +unsigned int countChannelsInSlot(unsigned int hashslot) { + raxIterator iter; + int j = 0; + unsigned char indexed[2]; + + indexed[0] = (hashslot >> 8) & 0xff; + indexed[1] = hashslot & 0xff; + raxStart(&iter,server.cluster->slots_to_channels); + raxSeek(&iter,">=",indexed,2); + while(raxNext(&iter)) { + if (iter.key[0] != indexed[0] || iter.key[1] != indexed[1]) break; + j++; + } + raxStop(&iter); + return j; +} diff --git a/src/cluster.h b/src/cluster.h index d64e2a5b9..a28176e4b 100644 --- a/src/cluster.h +++ b/src/cluster.h @@ -97,6 +97,7 @@ typedef struct clusterLink { #define CLUSTERMSG_TYPE_MFSTART 8 /* Pause clients for manual failover */ #define CLUSTERMSG_TYPE_MODULE 9 /* Module cluster API message. */ #define CLUSTERMSG_TYPE_COUNT 10 /* Total number of message types. */ +#define CLUSTERMSG_TYPE_PUBLISHSHARD 11 /* Pub/Sub Publish shard propagation */ /* Flags that a module can set in order to prevent certain Redis Cluster * features to be enabled. Useful when implementing a different distributed @@ -173,6 +174,7 @@ typedef struct clusterState { clusterNode *migrating_slots_to[CLUSTER_SLOTS]; clusterNode *importing_slots_from[CLUSTER_SLOTS]; clusterNode *slots[CLUSTER_SLOTS]; + rax *slots_to_channels; /* The following fields are used to take the slave state on elections. */ mstime_t failover_auth_time; /* Time of previous or next election. */ int failover_auth_count; /* Number of votes received so far. 
*/ @@ -320,6 +322,7 @@ int verifyClusterConfigWithData(void); unsigned long getClusterConnectionsCount(void); int clusterSendModuleMessageToTarget(const char *target, uint64_t module_id, uint8_t type, unsigned char *payload, uint32_t len); void clusterPropagatePublish(robj *channel, robj *message); +void clusterPropagatePublishShard(robj *channel, robj *message); unsigned int keyHashSlot(char *key, int keylen); void slotToKeyAddEntry(dictEntry *entry, redisDb *db); void slotToKeyDelEntry(dictEntry *entry, redisDb *db); @@ -329,5 +332,7 @@ void slotToKeyFlush(redisDb *db); void slotToKeyDestroy(redisDb *db); void clusterUpdateMyselfFlags(void); void clusterUpdateMyselfIp(void); +void slotToChannelAdd(sds channel); +void slotToChannelDel(sds channel); #endif /* __CLUSTER_H */ diff --git a/src/commands.c b/src/commands.c index bf0537b4f..6e1a511ca 100644 --- a/src/commands.c +++ b/src/commands.c @@ -2913,12 +2913,36 @@ struct redisCommandArg PUBSUB_NUMSUB_Args[] = { {0} }; +/********** PUBSUB SHARDCHANNELS ********************/ + +/* PUBSUB SHARDCHANNELS history */ +#define PUBSUB_SHARDCHANNELS_History NULL + +/* PUBSUB SHARDCHANNELS hints */ +#define PUBSUB_SHARDCHANNELS_Hints NULL + +/* PUBSUB SHARDCHANNELS argument table */ +struct redisCommandArg PUBSUB_SHARDCHANNELS_Args[] = { +{"pattern",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_OPTIONAL}, +{0} +}; + +/********** PUBSUB SHARDNUMSUB ********************/ + +/* PUBSUB SHARDNUMSUB history */ +#define PUBSUB_SHARDNUMSUB_History NULL + +/* PUBSUB SHARDNUMSUB hints */ +#define PUBSUB_SHARDNUMSUB_Hints NULL + /* PUBSUB command table */ struct redisCommand PUBSUB_Subcommands[] = { {"channels","List active channels","O(N) where N is the number of active channels, and assuming constant time pattern matching (relatively short channels and patterns)","2.8.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_CHANNELS_History,PUBSUB_CHANNELS_Hints,pubsubCommand,-2,CMD_PUBSUB|CMD_LOADING|CMD_STALE,0,.args=PUBSUB_CHANNELS_Args}, {"help","Show helpful text about the different subcommands","O(1)","6.2.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_HELP_History,PUBSUB_HELP_Hints,pubsubCommand,2,CMD_LOADING|CMD_STALE,0}, {"numpat","Get the count of unique patterns pattern subscriptions","O(1)","2.8.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_NUMPAT_History,PUBSUB_NUMPAT_Hints,pubsubCommand,2,CMD_PUBSUB|CMD_LOADING|CMD_STALE,0}, {"numsub","Get the count of subscribers for channels","O(N) for the NUMSUB subcommand, where N is the number of requested channels","2.8.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_NUMSUB_History,PUBSUB_NUMSUB_Hints,pubsubCommand,-2,CMD_PUBSUB|CMD_LOADING|CMD_STALE,0,.args=PUBSUB_NUMSUB_Args}, +{"shardchannels","List active shard channels","O(N) where N is the number of active shard channels, and assuming constant time pattern matching (relatively short channels).","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_SHARDCHANNELS_History,PUBSUB_SHARDCHANNELS_Hints,pubsubCommand,-2,CMD_PUBSUB|CMD_LOADING|CMD_STALE,0,.args=PUBSUB_SHARDCHANNELS_Args}, +{"shardnumsub","Get the count of subscribers for shard channels","O(N) for the SHARDNUMSUB subcommand, where N is the number of requested channels","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_SHARDNUMSUB_History,PUBSUB_SHARDNUMSUB_Hints,pubsubCommand,-2,CMD_PUBSUB|CMD_LOADING|CMD_STALE,0}, {0} }; @@ -2944,6 +2968,35 @@ struct redisCommandArg PUNSUBSCRIBE_Args[] = { {0} }; +/********** SPUBLISH ********************/ + +/* SPUBLISH history 
*/ +#define SPUBLISH_History NULL + +/* SPUBLISH hints */ +#define SPUBLISH_Hints NULL + +/* SPUBLISH argument table */ +struct redisCommandArg SPUBLISH_Args[] = { +{"channel",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, +{"message",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, +{0} +}; + +/********** SSUBSCRIBE ********************/ + +/* SSUBSCRIBE history */ +#define SSUBSCRIBE_History NULL + +/* SSUBSCRIBE hints */ +#define SSUBSCRIBE_Hints NULL + +/* SSUBSCRIBE argument table */ +struct redisCommandArg SSUBSCRIBE_Args[] = { +{"channel",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_MULTIPLE}, +{0} +}; + /********** SUBSCRIBE ********************/ /* SUBSCRIBE history */ @@ -2961,6 +3014,20 @@ struct redisCommandArg SUBSCRIBE_Args[] = { {0} }; +/********** SUNSUBSCRIBE ********************/ + +/* SUNSUBSCRIBE history */ +#define SUNSUBSCRIBE_History NULL + +/* SUNSUBSCRIBE hints */ +#define SUNSUBSCRIBE_Hints NULL + +/* SUNSUBSCRIBE argument table */ +struct redisCommandArg SUNSUBSCRIBE_Args[] = { +{"channel",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_OPTIONAL|CMD_ARG_MULTIPLE}, +{0} +}; + /********** UNSUBSCRIBE ********************/ /* UNSUBSCRIBE history */ @@ -6511,7 +6578,10 @@ struct redisCommand redisCommandTable[] = { {"publish","Post a message to a channel","O(N+M) where N is the number of clients subscribed to the receiving channel and M is the total number of subscribed patterns (by any client).","2.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBLISH_History,PUBLISH_Hints,publishCommand,3,CMD_PUBSUB|CMD_LOADING|CMD_STALE|CMD_FAST|CMD_MAY_REPLICATE|CMD_SENTINEL,0,.args=PUBLISH_Args}, {"pubsub","A container for Pub/Sun commands","Depends on subcommand.","2.8.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUBSUB_History,PUBSUB_Hints,NULL,-2,0,0,.subcommands=PUBSUB_Subcommands}, {"punsubscribe","Stop listening for messages posted to channels matching the given patterns","O(N+M) where N is the number of patterns the client is already subscribed and M is the number of total patterns subscribed in the system (by any client).","2.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,PUNSUBSCRIBE_History,PUNSUBSCRIBE_Hints,punsubscribeCommand,-1,CMD_PUBSUB|CMD_NOSCRIPT|CMD_LOADING|CMD_STALE|CMD_SENTINEL,0,.args=PUNSUBSCRIBE_Args}, +{"spublish","Post a message to a shard channel","O(N) where N is the number of clients subscribed to the receiving shard channel.","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,SPUBLISH_History,SPUBLISH_Hints,spublishCommand,3,CMD_PUBSUB|CMD_LOADING|CMD_STALE|CMD_FAST|CMD_MAY_REPLICATE,0,{{CMD_KEY_SHARD_CHANNEL,KSPEC_BS_INDEX,.bs.index={1},KSPEC_FK_RANGE,.fk.range={0,1,0}}},.args=SPUBLISH_Args}, +{"ssubscribe","Listen for messages published to the given shard channels","O(N) where N is the number of shard channels to subscribe to.","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,SSUBSCRIBE_History,SSUBSCRIBE_Hints,ssubscribeCommand,-2,CMD_PUBSUB|CMD_NOSCRIPT|CMD_LOADING|CMD_STALE,0,{{CMD_KEY_SHARD_CHANNEL,KSPEC_BS_INDEX,.bs.index={1},KSPEC_FK_RANGE,.fk.range={-1,1,0}}},.args=SSUBSCRIBE_Args}, {"subscribe","Listen for messages published to the given channels","O(N) where N is the number of channels to subscribe to.","2.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,SUBSCRIBE_History,SUBSCRIBE_Hints,subscribeCommand,-2,CMD_PUBSUB|CMD_NOSCRIPT|CMD_LOADING|CMD_STALE|CMD_SENTINEL,0,.args=SUBSCRIBE_Args}, +{"sunsubscribe","Stop listening for messages posted to the given shard channels","O(N) where N is the number of clients already subscribed to a 
channel.","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,SUNSUBSCRIBE_History,SUNSUBSCRIBE_Hints,sunsubscribeCommand,-1,CMD_PUBSUB|CMD_NOSCRIPT|CMD_LOADING|CMD_STALE,0,{{CMD_KEY_SHARD_CHANNEL,KSPEC_BS_INDEX,.bs.index={1},KSPEC_FK_RANGE,.fk.range={-1,1,0}}},.args=SUNSUBSCRIBE_Args}, {"unsubscribe","Stop listening for messages posted to the given channels","O(N) where N is the number of clients already subscribed to a channel.","2.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_PUBSUB,UNSUBSCRIBE_History,UNSUBSCRIBE_Hints,unsubscribeCommand,-1,CMD_PUBSUB|CMD_NOSCRIPT|CMD_LOADING|CMD_STALE|CMD_SENTINEL,0,.args=UNSUBSCRIBE_Args}, /* scripting */ {"eval","Execute a Lua script server side","Depends on the script that is executed.","2.6.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,EVAL_History,EVAL_Hints,evalCommand,-3,CMD_NOSCRIPT|CMD_SKIP_MONITOR|CMD_MAY_REPLICATE|CMD_NO_MANDATORY_KEYS,ACL_CATEGORY_SCRIPTING,{{CMD_KEY_WRITE|CMD_KEY_READ,KSPEC_BS_INDEX,.bs.index={2},KSPEC_FK_KEYNUM,.fk.keynum={0,1,1}}},evalGetKeys,.args=EVAL_Args}, diff --git a/src/commands/pubsub-shardchannels.json b/src/commands/pubsub-shardchannels.json new file mode 100644 index 000000000..450cd8dcd --- /dev/null +++ b/src/commands/pubsub-shardchannels.json @@ -0,0 +1,23 @@ +{ + "SHARDCHANNELS": { + "summary": "List active shard channels", + "complexity": "O(N) where N is the number of active shard channels, and assuming constant time pattern matching (relatively short channels).", + "group": "pubsub", + "since": "7.0.0", + "arity": -2, + "container": "PUBSUB", + "function": "pubsubCommand", + "command_flags": [ + "PUBSUB", + "LOADING", + "STALE" + ], + "arguments": [ + { + "name": "pattern", + "type": "string", + "optional": true + } + ] + } +} diff --git a/src/commands/pubsub-shardnumsub.json b/src/commands/pubsub-shardnumsub.json new file mode 100644 index 000000000..167132b31 --- /dev/null +++ b/src/commands/pubsub-shardnumsub.json @@ -0,0 +1,16 @@ +{ + "SHARDNUMSUB": { + "summary": "Get the count of subscribers for shard channels", + "complexity": "O(N) for the SHARDNUMSUB subcommand, where N is the number of requested channels", + "group": "pubsub", + "since": "7.0.0", + "arity": -2, + "container": "PUBSUB", + "function": "pubsubCommand", + "command_flags": [ + "PUBSUB", + "LOADING", + "STALE" + ] + } +} diff --git a/src/commands/spublish.json b/src/commands/spublish.json new file mode 100644 index 000000000..34835b9dc --- /dev/null +++ b/src/commands/spublish.json @@ -0,0 +1,46 @@ +{ + "SPUBLISH": { + "summary": "Post a message to a shard channel", + "complexity": "O(N) where N is the number of clients subscribed to the receiving shard channel.", + "group": "pubsub", + "since": "7.0.0", + "arity": 3, + "function": "spublishCommand", + "command_flags": [ + "PUBSUB", + "LOADING", + "STALE", + "FAST", + "MAY_REPLICATE" + ], + "arguments": [ + { + "name": "channel", + "type": "string" + }, + { + "name": "message", + "type": "string" + } + ], + "key_specs": [ + { + "flags": [ + "SHARD_CHANNEL" + ], + "begin_search": { + "index": { + "pos": 1 + } + }, + "find_keys": { + "range": { + "lastkey": 0, + "step": 1, + "limit": 0 + } + } + } + ] + } +} diff --git a/src/commands/ssubscribe.json b/src/commands/ssubscribe.json new file mode 100644 index 000000000..541e4aac0 --- /dev/null +++ b/src/commands/ssubscribe.json @@ -0,0 +1,42 @@ +{ + "SSUBSCRIBE": { + "summary": "Listen for messages published to the given shard channels", + "complexity": "O(N) where N is the number of shard channels to subscribe to.", + "group": 
"pubsub", + "since": "7.0.0", + "arity": -2, + "function": "ssubscribeCommand", + "command_flags": [ + "PUBSUB", + "NOSCRIPT", + "LOADING", + "STALE" + ], + "arguments": [ + { + "name": "channel", + "type": "string", + "multiple": true + } + ], + "key_specs": [ + { + "flags": [ + "SHARD_CHANNEL" + ], + "begin_search": { + "index": { + "pos": 1 + } + }, + "find_keys": { + "range": { + "lastkey": -1, + "step": 1, + "limit": 0 + } + } + } + ] + } +} diff --git a/src/commands/sunsubscribe.json b/src/commands/sunsubscribe.json new file mode 100644 index 000000000..ba43a02b6 --- /dev/null +++ b/src/commands/sunsubscribe.json @@ -0,0 +1,43 @@ +{ + "SUNSUBSCRIBE": { + "summary": "Stop listening for messages posted to the given shard channels", + "complexity": "O(N) where N is the number of clients already subscribed to a channel.", + "group": "pubsub", + "since": "7.0.0", + "arity": -1, + "function": "sunsubscribeCommand", + "command_flags": [ + "PUBSUB", + "NOSCRIPT", + "LOADING", + "STALE" + ], + "arguments": [ + { + "name": "channel", + "type": "string", + "optional": true, + "multiple": true + } + ], + "key_specs": [ + { + "flags": [ + "SHARD_CHANNEL" + ], + "begin_search": { + "index": { + "pos": 1 + } + }, + "find_keys": { + "range": { + "lastkey": -1, + "step": 1, + "limit": 0 + } + } + } + ] + } +} diff --git a/src/config.c b/src/config.c index 03b6af29a..317b92ea2 100644 --- a/src/config.c +++ b/src/config.c @@ -2636,6 +2636,7 @@ standardConfig configs[] = { createBoolConfig("cluster-enabled", NULL, IMMUTABLE_CONFIG, server.cluster_enabled, 0, NULL, NULL), createBoolConfig("appendonly", NULL, MODIFIABLE_CONFIG | DENY_LOADING_CONFIG, server.aof_enabled, 0, NULL, updateAppendonly), createBoolConfig("cluster-allow-reads-when-down", NULL, MODIFIABLE_CONFIG, server.cluster_allow_reads_when_down, 0, NULL, NULL), + createBoolConfig("cluster-allow-pubsubshard-when-down", NULL, MODIFIABLE_CONFIG, server.cluster_allow_pubsubshard_when_down, 1, NULL, NULL), createBoolConfig("crash-log-enabled", NULL, MODIFIABLE_CONFIG, server.crashlog_enabled, 1, NULL, updateSighandlerEnabled), createBoolConfig("crash-memcheck-enabled", NULL, MODIFIABLE_CONFIG, server.memcheck_enabled, 1, NULL, NULL), createBoolConfig("use-exit-on-panic", NULL, MODIFIABLE_CONFIG | HIDDEN_CONFIG, server.use_exit_on_panic, 0, NULL, NULL), diff --git a/src/module.c b/src/module.c index 9cae23db5..58e3bb666 100644 --- a/src/module.c +++ b/src/module.c @@ -790,6 +790,7 @@ int64_t commandKeySpecsFlagsFromString(const char *s) { char *t = tokens[j]; if (!strcasecmp(t,"write")) flags |= CMD_KEY_WRITE; else if (!strcasecmp(t,"read")) flags |= CMD_KEY_READ; + else if (!strcasecmp(t,"shard_channel")) flags |= CMD_KEY_SHARD_CHANNEL; else if (!strcasecmp(t,"incomplete")) flags |= CMD_KEY_INCOMPLETE; else break; } diff --git a/src/networking.c b/src/networking.c index 20d05a9e3..50c7d99ca 100644 --- a/src/networking.c +++ b/src/networking.c @@ -190,6 +190,7 @@ client *createClient(connection *conn) { c->watched_keys = listCreate(); c->pubsub_channels = dictCreate(&objectKeyPointerValueDictType); c->pubsub_patterns = listCreate(); + c->pubsubshard_channels = dictCreate(&objectKeyPointerValueDictType); c->peerid = NULL; c->sockname = NULL; c->client_list_node = NULL; @@ -1424,9 +1425,11 @@ void freeClient(client *c) { /* Unsubscribe from all the pubsub channels */ pubsubUnsubscribeAllChannels(c,0); + pubsubUnsubscribeShardAllChannels(c, 0); pubsubUnsubscribeAllPatterns(c,0); dictRelease(c->pubsub_channels); listRelease(c->pubsub_patterns); + 
dictRelease(c->pubsubshard_channels); /* Free data structures. */ listRelease(c->reply); @@ -2592,6 +2595,7 @@ void resetCommand(client *c) { discardTransaction(c); pubsubUnsubscribeAllChannels(c,0); + pubsubUnsubscribeShardAllChannels(c, 0); pubsubUnsubscribeAllPatterns(c,0); if (c->name) { diff --git a/src/pubsub.c b/src/pubsub.c index 6da5b18cf..e805b16ef 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -30,8 +30,68 @@ #include "server.h" #include "cluster.h" +/* Structure to hold the pubsub related metadata. Currently used + * for pubsub and pubsubshard feature. */ +typedef struct pubsubtype { + int shard; + dict *(*clientPubSubChannels)(client*); + int (*subscriptionCount)(client*); + dict **serverPubSubChannels; + robj **subscribeMsg; + robj **unsubscribeMsg; +}pubsubtype; + +/* + * Get client's global Pub/Sub channels subscription count. + */ int clientSubscriptionsCount(client *c); +/* + * Get client's shard level Pub/Sub channels subscription count. + */ +int clientShardSubscriptionsCount(client *c); + +/* + * Get client's global Pub/Sub channels dict. + */ +dict* getClientPubSubChannels(client *c); + +/* + * Get client's shard level Pub/Sub channels dict. + */ +dict* getClientPubSubShardChannels(client *c); + +/* + * Get list of channels client is subscribed to. + * If a pattern is provided, the subset of channels is returned + * matching the pattern. + */ +void channelList(client *c, sds pat, dict* pubsub_channels); + +/* + * Pub/Sub type for global channels. + */ +pubsubtype pubSubType = { + .shard = 0, + .clientPubSubChannels = getClientPubSubChannels, + .subscriptionCount = clientSubscriptionsCount, + .serverPubSubChannels = &server.pubsub_channels, + .subscribeMsg = &shared.subscribebulk, + .unsubscribeMsg = &shared.unsubscribebulk, +}; + +/* + * Pub/Sub type for shard level channels bounded to a slot. + */ +pubsubtype pubSubShardType = { + .shard = 1, + .clientPubSubChannels = getClientPubSubShardChannels, + .subscriptionCount = clientShardSubscriptionsCount, + .serverPubSubChannels = &server.pubsubshard_channels, + .subscribeMsg = &shared.ssubscribebulk, + .unsubscribeMsg = &shared.sunsubscribebulk +}; + /*----------------------------------------------------------------------------- * Pubsub client replies API *----------------------------------------------------------------------------*/ @@ -66,31 +126,31 @@ void addReplyPubsubPatMessage(client *c, robj *pat, robj *channel, robj *msg) { } /* Send the pubsub subscription notification to the client. */ -void addReplyPubsubSubscribed(client *c, robj *channel) { +void addReplyPubsubSubscribed(client *c, robj *channel, pubsubtype type) { if (c->resp == 2) addReply(c,shared.mbulkhdr[3]); else addReplyPushLen(c,3); - addReply(c,shared.subscribebulk); + addReply(c,*type.subscribeMsg); addReplyBulk(c,channel); - addReplyLongLong(c,clientSubscriptionsCount(c)); + addReplyLongLong(c,type.subscriptionCount(c)); } /* Send the pubsub unsubscription notification to the client. * Channel can be NULL: this is useful when the client sends a mass * unsubscribe command but there are no channels to unsubscribe from: we * still send a notification. 
*/ -void addReplyPubsubUnsubscribed(client *c, robj *channel) { +void addReplyPubsubUnsubscribed(client *c, robj *channel, pubsubtype type) { if (c->resp == 2) addReply(c,shared.mbulkhdr[3]); else addReplyPushLen(c,3); - addReply(c,shared.unsubscribebulk); + addReply(c, *type.unsubscribeMsg); if (channel) addReplyBulk(c,channel); else addReplyNull(c); - addReplyLongLong(c,clientSubscriptionsCount(c)); + addReplyLongLong(c,type.subscriptionCount(c)); } /* Send the pubsub pattern subscription notification to the client. */ @@ -125,28 +185,57 @@ void addReplyPubsubPatUnsubscribed(client *c, robj *pattern) { * Pubsub low level API *----------------------------------------------------------------------------*/ +/* Return the number of pubsub channels + patterns is handled. */ +int serverPubsubSubscriptionCount() { + return dictSize(server.pubsub_channels) + dictSize(server.pubsub_patterns); +} + +/* Return the number of pubsub shard level channels is handled. */ +int serverPubsubShardSubscriptionCount() { + return dictSize(server.pubsubshard_channels); +} + + /* Return the number of channels + patterns a client is subscribed to. */ int clientSubscriptionsCount(client *c) { - return dictSize(c->pubsub_channels)+ - listLength(c->pubsub_patterns); + return dictSize(c->pubsub_channels) + listLength(c->pubsub_patterns); +} + +/* Return the number of shard level channels a client is subscribed to. */ +int clientShardSubscriptionsCount(client *c) { + return dictSize(c->pubsubshard_channels); +} + +dict* getClientPubSubChannels(client *c) { + return c->pubsub_channels; +} + +dict* getClientPubSubShardChannels(client *c) { + return c->pubsubshard_channels; +} + +/* Return the number of pubsub + pubsub shard level channels + * a client is subscribed to. */ +int clientTotalPubSubSubscriptionCount(client *c) { + return clientSubscriptionsCount(c) + clientShardSubscriptionsCount(c); } /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or * 0 if the client was already subscribed to that channel. */ -int pubsubSubscribeChannel(client *c, robj *channel) { +int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { dictEntry *de; list *clients = NULL; int retval = 0; /* Add the channel to the client -> channels hash table */ - if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) { + if (dictAdd(type.clientPubSubChannels(c),channel,NULL) == DICT_OK) { retval = 1; incrRefCount(channel); /* Add the client to the channel -> list of clients hash table */ - de = dictFind(server.pubsub_channels,channel); + de = dictFind(*type.serverPubSubChannels, channel); if (de == NULL) { clients = listCreate(); - dictAdd(server.pubsub_channels,channel,clients); + dictAdd(*type.serverPubSubChannels, channel, clients); incrRefCount(channel); } else { clients = dictGetVal(de); @@ -154,13 +243,13 @@ int pubsubSubscribeChannel(client *c, robj *channel) { listAddNodeTail(clients,c); } /* Notify the client */ - addReplyPubsubSubscribed(c,channel); + addReplyPubsubSubscribed(c,channel,type); return retval; } /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or * 0 if the client was not subscribed to the specified channel. 
*/ -int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { +int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) { dictEntry *de; list *clients; listNode *ln; @@ -169,10 +258,10 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { /* Remove the channel from the client -> channels hash table */ incrRefCount(channel); /* channel may be just a pointer to the same object we have in the hash tables. Protect it... */ - if (dictDelete(c->pubsub_channels,channel) == DICT_OK) { + if (dictDelete(type.clientPubSubChannels(c),channel) == DICT_OK) { retval = 1; /* Remove the client from the channel -> clients list hash table */ - de = dictFind(server.pubsub_channels,channel); + de = dictFind(*type.serverPubSubChannels, channel); serverAssertWithInfo(c,NULL,de != NULL); clients = dictGetVal(de); ln = listSearchKey(clients,c); @@ -182,15 +271,53 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { /* Free the list and associated hash entry at all if this was * the latest client, so that it will be possible to abuse * Redis PUBSUB creating millions of channels. */ - dictDelete(server.pubsub_channels,channel); + dictDelete(*type.serverPubSubChannels, channel); + /* As this channel isn't subscribed by anyone, it's safe + * to remove the channel from the slot. */ + if (server.cluster_enabled & type.shard) { + slotToChannelDel(channel->ptr); + } } } /* Notify the client */ - if (notify) addReplyPubsubUnsubscribed(c,channel); + if (notify) { + addReplyPubsubUnsubscribed(c,channel,type); + } decrRefCount(channel); /* it is finally safe to release it */ return retval; } +void pubsubShardUnsubscribeAllClients(robj *channel) { + int retval; + dictEntry *de = dictFind(server.pubsubshard_channels, channel); + serverAssertWithInfo(NULL,channel,de != NULL); + list *clients = dictGetVal(de); + if (listLength(clients) > 0) { + /* For each client subscribed to the channel, unsubscribe it. */ + listIter li; + listNode *ln; + listRewind(clients, &li); + while ((ln = listNext(&li)) != NULL) { + client *c = listNodeValue(ln); + retval = dictDelete(c->pubsubshard_channels, channel); + serverAssertWithInfo(c,channel,retval == DICT_OK); + addReplyPubsubUnsubscribed(c, channel, pubSubShardType); + /* If the client has no other pubsub subscription, + * move out of pubsub mode. */ + if (clientTotalPubSubSubscriptionCount(c) == 0) { + c->flags &= ~CLIENT_PUBSUB; + } + } + } + /* Delete the channel from server pubsubshard channels hash table. */ + retval = dictDelete(server.pubsubshard_channels, channel); + /* Delete the channel from slots_to_channel mapping. */ + slotToChannelDel(channel->ptr); + serverAssertWithInfo(NULL,channel,retval == DICT_OK); + decrRefCount(channel); /* it is finally safe to release it */ +} + + /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */ int pubsubSubscribePattern(client *c, robj *pattern) { dictEntry *de; @@ -250,24 +377,53 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { /* Unsubscribe from all the channels. Return the number of channels the * client was subscribed to. 
*/ -int pubsubUnsubscribeAllChannels(client *c, int notify) { +int pubsubUnsubscribeAllChannelsInternal(client *c, int notify, pubsubtype type) { int count = 0; - if (dictSize(c->pubsub_channels) > 0) { - dictIterator *di = dictGetSafeIterator(c->pubsub_channels); + if (dictSize(type.clientPubSubChannels(c)) > 0) { + dictIterator *di = dictGetSafeIterator(type.clientPubSubChannels(c)); dictEntry *de; while((de = dictNext(di)) != NULL) { robj *channel = dictGetKey(de); - count += pubsubUnsubscribeChannel(c,channel,notify); + count += pubsubUnsubscribeChannel(c,channel,notify,type); } dictReleaseIterator(di); } /* We were subscribed to nothing? Still reply to the client. */ - if (notify && count == 0) addReplyPubsubUnsubscribed(c,NULL); + if (notify && count == 0) { + addReplyPubsubUnsubscribed(c,NULL,type); + } + return count; +} + +/* + * Unsubscribe a client from all global channels. + */ +int pubsubUnsubscribeAllChannels(client *c, int notify) { + int count = pubsubUnsubscribeAllChannelsInternal(c,notify,pubSubType); + return count; +} + +/* + * Unsubscribe a client from all shard subscribed channels. + */ +int pubsubUnsubscribeShardAllChannels(client *c, int notify) { + int count = pubsubUnsubscribeAllChannelsInternal(c, notify, pubSubShardType); return count; } +/* + * Unsubscribe a client from provided shard subscribed channel(s). + */ +void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count) { + for (unsigned int j = 0; j < count; j++) { + /* Remove the channel from server and from the clients + * subscribed to it as well as notify them. */ + pubsubShardUnsubscribeAllClients(channels[j]); + } +} + /* Unsubscribe from all the patterns. Return the number of patterns the * client was subscribed from. */ int pubsubUnsubscribeAllPatterns(client *c, int notify) { @@ -285,8 +441,10 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) { return count; } -/* Publish a message */ -int pubsubPublishMessage(robj *channel, robj *message) { +/* + * Publish a message to all the subscribers. + */ +int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) { int receivers = 0; dictEntry *de; dictIterator *di; @@ -294,7 +452,7 @@ int pubsubPublishMessage(robj *channel, robj *message) { listIter li; /* Send to clients listening for that channel */ - de = dictFind(server.pubsub_channels,channel); + de = dictFind(*type.serverPubSubChannels, channel); if (de) { list *list = dictGetVal(de); listNode *ln; @@ -308,6 +466,12 @@ int pubsubPublishMessage(robj *channel, robj *message) { receivers++; } } + + if (type.shard) { + /* Shard pubsub ignores patterns. */ + return receivers; + } + /* Send to clients listening to matching channels */ di = dictGetIterator(server.pubsub_patterns); if (di) { @@ -334,6 +498,17 @@ int pubsubPublishMessage(robj *channel, robj *message) { return receivers; } +/* Publish a message to all the subscribers. */ +int pubsubPublishMessage(robj *channel, robj *message) { + return pubsubPublishMessageInternal(channel,message,pubSubType); +} + +/* Publish a shard message to all the subscribers. 
*/ +int pubsubPublishMessageShard(robj *channel, robj *message) { + return pubsubPublishMessageInternal(channel, message, pubSubShardType); +} + + /*----------------------------------------------------------------------------- * Pubsub commands implementation *----------------------------------------------------------------------------*/ @@ -352,13 +527,12 @@ void subscribeCommand(client *c) { addReplyError(c, "SUBSCRIBE isn't allowed for a DENY BLOCKING client"); return; } - for (j = 1; j < c->argc; j++) - pubsubSubscribeChannel(c,c->argv[j]); + pubsubSubscribeChannel(c,c->argv[j],pubSubType); c->flags |= CLIENT_PUBSUB; } -/* UNSUBSCRIBE [channel [channel ...]] */ +/* UNSUBSCRIBE [channel ...] */ void unsubscribeCommand(client *c) { if (c->argc == 1) { pubsubUnsubscribeAllChannels(c,1); @@ -366,9 +540,9 @@ void unsubscribeCommand(client *c) { int j; for (j = 1; j < c->argc; j++) - pubsubUnsubscribeChannel(c,c->argv[j],1); + pubsubUnsubscribeChannel(c,c->argv[j],1,pubSubType); } - if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; + if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; } /* PSUBSCRIBE pattern [pattern ...] */ @@ -401,7 +575,7 @@ void punsubscribeCommand(client *c) { for (j = 1; j < c->argc; j++) pubsubUnsubscribePattern(c,c->argv[j],1); } - if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; + if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; } /* PUBLISH <channel> <message> */ @@ -429,7 +603,11 @@ void pubsubCommand(client *c) { " Return number of subscriptions to patterns.", "NUMSUB [<channel> ...]", " Return the number of subscribers for the specified channels, excluding", -" pattern subscriptions(default: no channels).", +" pattern subscriptions(default: no channels)." +"SHARDCHANNELS [<pattern>]", +" Return the currently active shard level channels matching a <pattern> (default: '*').", +"SHARDNUMSUB [<channel> ...]", +" Return the number of subscribers for the specified shard level channel(s)", NULL }; addReplyHelp(c, help); @@ -438,25 +616,7 @@ NULL { /* PUBSUB CHANNELS [<pattern>] */ sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr; - dictIterator *di = dictGetIterator(server.pubsub_channels); - dictEntry *de; - long mblen = 0; - void *replylen; - - replylen = addReplyDeferredLen(c); - while((de = dictNext(di)) != NULL) { - robj *cobj = dictGetKey(de); - sds channel = cobj->ptr; - - if (!pat || stringmatchlen(pat, sdslen(pat), - channel, sdslen(channel),0)) - { - addReplyBulk(c,cobj); - mblen++; - } - } - dictReleaseIterator(di); - setDeferredArrayLen(c,replylen,mblen); + channelList(c, pat, server.pubsub_channels); } else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) { /* PUBSUB NUMSUB [Channel_1 ... Channel_N] */ int j; @@ -471,7 +631,93 @@ NULL } else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) { /* PUBSUB NUMPAT */ addReplyLongLong(c,dictSize(server.pubsub_patterns)); + } else if (!strcasecmp(c->argv[1]->ptr,"shardchannels") && + (c->argc == 2 || c->argc == 3)) + { + /* PUBSUB SHARDCHANNELS */ + sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr; + channelList(c,pat,server.pubsubshard_channels); + } else if (!strcasecmp(c->argv[1]->ptr,"shardnumsub") && c->argc >= 2) { + /* PUBSUB SHARDNUMSUB [Channel_1 ... Channel_N] */ + int j; + + addReplyArrayLen(c, (c->argc-2)*2); + for (j = 2; j < c->argc; j++) { + list *l = dictFetchValue(server.pubsubshard_channels, c->argv[j]); + + addReplyBulk(c,c->argv[j]); + addReplyLongLong(c,l ? 
listLength(l) : 0); + } } else { addReplySubcommandSyntaxError(c); } } + +void channelList(client *c, sds pat, dict *pubsub_channels) { + dictIterator *di = dictGetIterator(pubsub_channels); + dictEntry *de; + long mblen = 0; + void *replylen; + + replylen = addReplyDeferredLen(c); + while((de = dictNext(di)) != NULL) { + robj *cobj = dictGetKey(de); + sds channel = cobj->ptr; + + if (!pat || stringmatchlen(pat, sdslen(pat), + channel, sdslen(channel),0)) + { + addReplyBulk(c,cobj); + mblen++; + } + } + dictReleaseIterator(di); + setDeferredArrayLen(c,replylen,mblen); +} + +/* SPUBLISH <channel> <message> */ +void spublishCommand(client *c) { + int receivers = pubsubPublishMessageInternal(c->argv[1], c->argv[2], pubSubShardType); + if (server.cluster_enabled) { + clusterPropagatePublishShard(c->argv[1], c->argv[2]); + } else { + forceCommandPropagation(c,PROPAGATE_REPL); + } + addReplyLongLong(c,receivers); +} + +/* SSUBSCRIBE channel [channel ...] */ +void ssubscribeCommand(client *c) { + if (c->flags & CLIENT_DENY_BLOCKING) { + /* A client that has CLIENT_DENY_BLOCKING flag on + * expect a reply per command and so can not execute subscribe. */ + addReplyError(c, "SSUBSCRIBE isn't allowed for a DENY BLOCKING client"); + return; + } + + for (int j = 1; j < c->argc; j++) { + /* A channel is only considered to be added, if a + * subscriber exists for it. And if a subscriber + * already exists the slotToChannel doesn't needs + * to be incremented. */ + if (server.cluster_enabled & + (dictFind(*pubSubShardType.serverPubSubChannels, c->argv[j]) == NULL)) { + slotToChannelAdd(c->argv[j]->ptr); + } + pubsubSubscribeChannel(c, c->argv[j], pubSubShardType); + } + c->flags |= CLIENT_PUBSUB; +} + + +/* SUNSUBSCRIBE [channel ...] */ +void sunsubscribeCommand(client *c) { + if (c->argc == 1) { + pubsubUnsubscribeShardAllChannels(c, 1); + } else { + for (int j = 1; j < c->argc; j++) { + pubsubUnsubscribeChannel(c, c->argv[j], 1, pubSubShardType); + } + } + if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; +} diff --git a/src/redis-cli.c b/src/redis-cli.c index 397e3bfa8..94697f99f 100644 --- a/src/redis-cli.c +++ b/src/redis-cli.c @@ -1381,7 +1381,8 @@ static int cliSendCommand(int argc, char **argv, long repeat) { if (!strcasecmp(command,"shutdown")) config.shutdown = 1; if (!strcasecmp(command,"monitor")) config.monitor_mode = 1; if (!strcasecmp(command,"subscribe") || - !strcasecmp(command,"psubscribe")) config.pubsub_mode = 1; + !strcasecmp(command,"psubscribe") || + !strcasecmp(command,"ssubscribe")) config.pubsub_mode = 1; if (!strcasecmp(command,"sync") || !strcasecmp(command,"psync")) config.slave_mode = 1; diff --git a/src/server.c b/src/server.c index e7cf5740b..4770a5df0 100644 --- a/src/server.c +++ b/src/server.c @@ -1648,6 +1648,8 @@ void createSharedObjects(void) { shared.pmessagebulk = createStringObject("$8\r\npmessage\r\n",14); shared.subscribebulk = createStringObject("$9\r\nsubscribe\r\n",15); shared.unsubscribebulk = createStringObject("$11\r\nunsubscribe\r\n",18); + shared.ssubscribebulk = createStringObject("$10\r\nssubscribe\r\n", 17); + shared.sunsubscribebulk = createStringObject("$12\r\nsunsubscribe\r\n", 19); shared.psubscribebulk = createStringObject("$10\r\npsubscribe\r\n",17); shared.punsubscribebulk = createStringObject("$12\r\npunsubscribe\r\n",19); @@ -2367,6 +2369,7 @@ void initServer(void) { evictionPoolAlloc(); /* Initialize the LRU keys pool. 
*/ server.pubsub_channels = dictCreate(&keylistDictType); server.pubsub_patterns = dictCreate(&keylistDictType); + server.pubsubshard_channels = dictCreate(&keylistDictType); server.cronloops = 0; server.in_script = 0; server.in_exec = 0; @@ -3499,14 +3502,16 @@ int processCommand(client *c) { if ((c->flags & CLIENT_PUBSUB && c->resp == 2) && c->cmd->proc != pingCommand && c->cmd->proc != subscribeCommand && + c->cmd->proc != ssubscribeCommand && c->cmd->proc != unsubscribeCommand && + c->cmd->proc != sunsubscribeCommand && c->cmd->proc != psubscribeCommand && c->cmd->proc != punsubscribeCommand && c->cmd->proc != quitCommand && c->cmd->proc != resetCommand) { rejectCommandFormat(c, - "Can't execute '%s': only (P)SUBSCRIBE / " - "(P)UNSUBSCRIBE / PING / QUIT / RESET are allowed in this context", + "Can't execute '%s': only (P|S)SUBSCRIBE / " + "(P|S)UNSUBSCRIBE / PING / QUIT / RESET are allowed in this context", c->cmd->name); return C_OK; } @@ -4001,6 +4006,7 @@ void addReplyFlagsForKeyArgs(client *c, uint64_t flags) { void *flaglen = addReplyDeferredLen(c); flagcount += addReplyCommandFlag(c,flags,CMD_KEY_WRITE, "write"); flagcount += addReplyCommandFlag(c,flags,CMD_KEY_READ, "read"); + flagcount += addReplyCommandFlag(c,flags,CMD_KEY_SHARD_CHANNEL, "shard_channel"); flagcount += addReplyCommandFlag(c,flags,CMD_KEY_INCOMPLETE, "incomplete"); setDeferredSetLen(c, flaglen, flagcount); } diff --git a/src/server.h b/src/server.h index 792eb30a1..c1a0af355 100644 --- a/src/server.h +++ b/src/server.h @@ -233,9 +233,12 @@ extern int configOOMScoreAdjValuesDefaults[CONFIG_OOM_COUNT]; /* Key argument flags. Please check the command table defined in the server.c file * for more information about the meaning of every flag. */ -#define CMD_KEY_WRITE (1ULL<<0) -#define CMD_KEY_READ (1ULL<<1) -#define CMD_KEY_INCOMPLETE (1ULL<<2) /* meaning that the keyspec might not point out to all keys it should cover */ +#define CMD_KEY_WRITE (1ULL<<0) /* "write" flag */ +#define CMD_KEY_READ (1ULL<<1) /* "read" flag */ +#define CMD_KEY_SHARD_CHANNEL (1ULL<<2) /* "shard_channel" flag */ +#define CMD_KEY_INCOMPLETE (1ULL<<3) /* "incomplete" flag (meaning that + * the keyspec might not point out + * to all keys it should cover) */ /* AOF states */ #define AOF_OFF 0 /* AOF is off */ @@ -1086,6 +1089,7 @@ typedef struct client { list *watched_keys; /* Keys WATCHED for MULTI/EXEC CAS */ dict *pubsub_channels; /* channels a client is interested in (SUBSCRIBE) */ list *pubsub_patterns; /* patterns a client is interested in (SUBSCRIBE) */ + dict *pubsubshard_channels; /* shard level channels a client is interested in (SSUBSCRIBE) */ sds peerid; /* Cached peer ID. */ sds sockname; /* Cached connection target address. */ listNode *client_list_node; /* list node in client list */ @@ -1174,6 +1178,7 @@ struct sharedObjectsStruct { *time, *pxat, *absttl, *retrycount, *force, *justid, *lastid, *ping, *setid, *keepttl, *load, *createconsumer, *getack, *special_asterick, *special_equals, *default_username, *redacted, + *ssubscribebulk,*sunsubscribebulk, *select[PROTO_SHARED_SELECT_CMDS], *integers[OBJ_SHARED_INTEGERS], *mbulkhdr[OBJ_SHARED_BULKHDR_LEN], /* "*<value>\r\n" */ @@ -1751,6 +1756,7 @@ struct redisServer { dict *pubsub_patterns; /* A dict of pubsub_patterns */ int notify_keyspace_events; /* Events to propagate via Pub/Sub. This is an xor of NOTIFY_... flags. */ + dict *pubsubshard_channels; /* Map channels to list of subscribed clients */ /* Cluster */ int cluster_enabled; /* Is cluster enabled? 
*/ int cluster_port; /* Set the cluster port for a node. */ @@ -1821,6 +1827,8 @@ struct redisServer { * failover then any replica can be used. */ int target_replica_port; /* Failover target port */ int failover_state; /* Failover state */ + int cluster_allow_pubsubshard_when_down; /* Is pubsubshard allowed when the cluster + is down, doesn't affect pubsub global. */ }; #define MAX_KEYS_BUFFER 256 @@ -2816,9 +2824,14 @@ robj *hashTypeDup(robj *o); /* Pub / Sub */ int pubsubUnsubscribeAllChannels(client *c, int notify); +int pubsubUnsubscribeShardAllChannels(client *c, int notify); +void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count); int pubsubUnsubscribeAllPatterns(client *c, int notify); int pubsubPublishMessage(robj *channel, robj *message); +int pubsubPublishMessageShard(robj *channel, robj *message); void addReplyPubsubMessage(client *c, robj *channel, robj *msg); +int serverPubsubSubscriptionCount(); +int serverPubsubShardSubscriptionCount(); /* Keyspace events notification */ void notifyKeyspaceEvent(int type, char *event, robj *key, int dbid); @@ -2902,6 +2915,7 @@ void freeReplicationBacklogRefMemAsync(list *blocks, rax *index); /* API to get key arguments from commands */ int *getKeysPrepareResult(getKeysResult *result, int numkeys); int getKeysFromCommand(struct redisCommand *cmd, robj **argv, int argc, getKeysResult *result); +int getChannelsFromCommand(struct redisCommand *cmd, int argc, getKeysResult *result); void getKeysFreeResult(getKeysResult *result); int sintercardGetKeys(struct redisCommand *cmd,robj **argv, int argc, getKeysResult *result); int zunionInterDiffGetKeys(struct redisCommand *cmd,robj **argv, int argc, getKeysResult *result); @@ -3184,6 +3198,9 @@ void psubscribeCommand(client *c); void punsubscribeCommand(client *c); void publishCommand(client *c); void pubsubCommand(client *c); +void spublishCommand(client *c); +void ssubscribeCommand(client *c); +void sunsubscribeCommand(client *c); void watchCommand(client *c); void unwatchCommand(client *c); void clusterCommand(client *c); diff --git a/src/tracking.c b/src/tracking.c index 11e2587e2..bb36f742d 100644 --- a/src/tracking.c +++ b/src/tracking.c @@ -228,6 +228,12 @@ void trackingRememberKeys(client *c) { getKeysFreeResult(&result); return; } + /* Shard channels are treated as special keys for client + * library to rely on `COMMAND` command to discover the node + * to connect to. These channels doesn't need to be tracked. */ + if (c->cmd->flags & CMD_PUBSUB) { + return; + } int *keys = result.keys; |
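
A closing note on routing (editorial): the key specs above flag channel arguments as SHARD_CHANNEL precisely so that client libraries can discover, via COMMAND, that SSUBSCRIBE/SPUBLISH must be routed like keyed commands. The rule is the ordinary hash-slot computation from keyHashSlot() in cluster.c, sketched here self-contained; the CRC16 is the same CCITT/XMODEM variant Redis implements table-driven in crc16.c, re-implemented bit-by-bit:

    #include <stdio.h>
    #include <string.h>

    /* CRC16 (poly 0x1021, init 0, no reflection) -- the XMODEM variant
     * that Redis's crc16.c implements with a lookup table. */
    static unsigned int crc16(const char *buf, int len) {
        unsigned int crc = 0;
        for (int i = 0; i < len; i++) {
            crc ^= (unsigned int)(unsigned char)buf[i] << 8;
            for (int j = 0; j < 8; j++)
                crc = (crc & 0x8000) ? ((crc << 1) ^ 0x1021) & 0xffff
                                     : (crc << 1) & 0xffff;
        }
        return crc;
    }

    /* Mirrors keyHashSlot(): hash only the content of a non-empty
     * {tag} when one is present, else the whole name. */
    static unsigned int channel_hash_slot(const char *key, int keylen) {
        int s, e;
        for (s = 0; s < keylen; s++) if (key[s] == '{') break;
        if (s == keylen) return crc16(key, keylen) & 16383;
        for (e = s + 1; e < keylen; e++) if (key[e] == '}') break;
        if (e == keylen || e == s + 1) return crc16(key, keylen) & 16383;
        return crc16(key + s + 1, e - s - 1) & 16383;
    }

    int main(void) {
        const char *c1 = "orders:{user1}", *c2 = "alerts:{user1}";
        /* Both hash the tag "user1", so they land in the same slot
         * and therefore on the same shard. */
        printf("%s -> slot %u\n", c1, channel_hash_slot(c1, (int)strlen(c1)));
        printf("%s -> slot %u\n", c2, channel_hash_slot(c2, (int)strlen(c2)));
        return 0;
    }

Because only the {tag} is hashed when present, an application can pin a shard channel to the same shard as its related keys; it is also how a multi-channel SSUBSCRIBE passes the same-slot check in getNodeByQuery().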