From 9f8885760b53e6d3952b9c9b41f9e6c48dfa6cec Mon Sep 17 00:00:00 2001 From: Harkrishn Patro <30795839+hpatro@users.noreply.github.com> Date: Mon, 3 Jan 2022 01:54:47 +0100 Subject: Sharded pubsub implementation (#8621) This commit implements a sharded pubsub implementation based off of shard channels. Co-authored-by: Harkrishn Patro Co-authored-by: Madelyn Olson --- src/pubsub.c | 350 ++++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 298 insertions(+), 52 deletions(-) (limited to 'src/pubsub.c') diff --git a/src/pubsub.c b/src/pubsub.c index 6da5b18cf..e805b16ef 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -30,8 +30,68 @@ #include "server.h" #include "cluster.h" +/* Structure to hold the pubsub related metadata. Currently used + * for pubsub and pubsubshard feature. */ +typedef struct pubsubtype { + int shard; + dict *(*clientPubSubChannels)(client*); + int (*subscriptionCount)(client*); + dict **serverPubSubChannels; + robj **subscribeMsg; + robj **unsubscribeMsg; +}pubsubtype; + +/* + * Get client's global Pub/Sub channels subscription count. + */ int clientSubscriptionsCount(client *c); +/* + * Get client's shard level Pub/Sub channels subscription count. + */ +int clientShardSubscriptionsCount(client *c); + +/* + * Get client's global Pub/Sub channels dict. + */ +dict* getClientPubSubChannels(client *c); + +/* + * Get client's shard level Pub/Sub channels dict. + */ +dict* getClientPubSubShardChannels(client *c); + +/* + * Get list of channels client is subscribed to. + * If a pattern is provided, the subset of channels is returned + * matching the pattern. + */ +void channelList(client *c, sds pat, dict* pubsub_channels); + +/* + * Pub/Sub type for global channels. + */ +pubsubtype pubSubType = { + .shard = 0, + .clientPubSubChannels = getClientPubSubChannels, + .subscriptionCount = clientSubscriptionsCount, + .serverPubSubChannels = &server.pubsub_channels, + .subscribeMsg = &shared.subscribebulk, + .unsubscribeMsg = &shared.unsubscribebulk, +}; + +/* + * Pub/Sub type for shard level channels bounded to a slot. + */ +pubsubtype pubSubShardType = { + .shard = 1, + .clientPubSubChannels = getClientPubSubShardChannels, + .subscriptionCount = clientShardSubscriptionsCount, + .serverPubSubChannels = &server.pubsubshard_channels, + .subscribeMsg = &shared.ssubscribebulk, + .unsubscribeMsg = &shared.sunsubscribebulk +}; + /*----------------------------------------------------------------------------- * Pubsub client replies API *----------------------------------------------------------------------------*/ @@ -66,31 +126,31 @@ void addReplyPubsubPatMessage(client *c, robj *pat, robj *channel, robj *msg) { } /* Send the pubsub subscription notification to the client. */ -void addReplyPubsubSubscribed(client *c, robj *channel) { +void addReplyPubsubSubscribed(client *c, robj *channel, pubsubtype type) { if (c->resp == 2) addReply(c,shared.mbulkhdr[3]); else addReplyPushLen(c,3); - addReply(c,shared.subscribebulk); + addReply(c,*type.subscribeMsg); addReplyBulk(c,channel); - addReplyLongLong(c,clientSubscriptionsCount(c)); + addReplyLongLong(c,type.subscriptionCount(c)); } /* Send the pubsub unsubscription notification to the client. * Channel can be NULL: this is useful when the client sends a mass * unsubscribe command but there are no channels to unsubscribe from: we * still send a notification. */ -void addReplyPubsubUnsubscribed(client *c, robj *channel) { +void addReplyPubsubUnsubscribed(client *c, robj *channel, pubsubtype type) { if (c->resp == 2) addReply(c,shared.mbulkhdr[3]); else addReplyPushLen(c,3); - addReply(c,shared.unsubscribebulk); + addReply(c, *type.unsubscribeMsg); if (channel) addReplyBulk(c,channel); else addReplyNull(c); - addReplyLongLong(c,clientSubscriptionsCount(c)); + addReplyLongLong(c,type.subscriptionCount(c)); } /* Send the pubsub pattern subscription notification to the client. */ @@ -125,28 +185,57 @@ void addReplyPubsubPatUnsubscribed(client *c, robj *pattern) { * Pubsub low level API *----------------------------------------------------------------------------*/ +/* Return the number of pubsub channels + patterns is handled. */ +int serverPubsubSubscriptionCount() { + return dictSize(server.pubsub_channels) + dictSize(server.pubsub_patterns); +} + +/* Return the number of pubsub shard level channels is handled. */ +int serverPubsubShardSubscriptionCount() { + return dictSize(server.pubsubshard_channels); +} + + /* Return the number of channels + patterns a client is subscribed to. */ int clientSubscriptionsCount(client *c) { - return dictSize(c->pubsub_channels)+ - listLength(c->pubsub_patterns); + return dictSize(c->pubsub_channels) + listLength(c->pubsub_patterns); +} + +/* Return the number of shard level channels a client is subscribed to. */ +int clientShardSubscriptionsCount(client *c) { + return dictSize(c->pubsubshard_channels); +} + +dict* getClientPubSubChannels(client *c) { + return c->pubsub_channels; +} + +dict* getClientPubSubShardChannels(client *c) { + return c->pubsubshard_channels; +} + +/* Return the number of pubsub + pubsub shard level channels + * a client is subscribed to. */ +int clientTotalPubSubSubscriptionCount(client *c) { + return clientSubscriptionsCount(c) + clientShardSubscriptionsCount(c); } /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or * 0 if the client was already subscribed to that channel. */ -int pubsubSubscribeChannel(client *c, robj *channel) { +int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { dictEntry *de; list *clients = NULL; int retval = 0; /* Add the channel to the client -> channels hash table */ - if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) { + if (dictAdd(type.clientPubSubChannels(c),channel,NULL) == DICT_OK) { retval = 1; incrRefCount(channel); /* Add the client to the channel -> list of clients hash table */ - de = dictFind(server.pubsub_channels,channel); + de = dictFind(*type.serverPubSubChannels, channel); if (de == NULL) { clients = listCreate(); - dictAdd(server.pubsub_channels,channel,clients); + dictAdd(*type.serverPubSubChannels, channel, clients); incrRefCount(channel); } else { clients = dictGetVal(de); @@ -154,13 +243,13 @@ int pubsubSubscribeChannel(client *c, robj *channel) { listAddNodeTail(clients,c); } /* Notify the client */ - addReplyPubsubSubscribed(c,channel); + addReplyPubsubSubscribed(c,channel,type); return retval; } /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or * 0 if the client was not subscribed to the specified channel. */ -int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { +int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) { dictEntry *de; list *clients; listNode *ln; @@ -169,10 +258,10 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { /* Remove the channel from the client -> channels hash table */ incrRefCount(channel); /* channel may be just a pointer to the same object we have in the hash tables. Protect it... */ - if (dictDelete(c->pubsub_channels,channel) == DICT_OK) { + if (dictDelete(type.clientPubSubChannels(c),channel) == DICT_OK) { retval = 1; /* Remove the client from the channel -> clients list hash table */ - de = dictFind(server.pubsub_channels,channel); + de = dictFind(*type.serverPubSubChannels, channel); serverAssertWithInfo(c,NULL,de != NULL); clients = dictGetVal(de); ln = listSearchKey(clients,c); @@ -182,15 +271,53 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { /* Free the list and associated hash entry at all if this was * the latest client, so that it will be possible to abuse * Redis PUBSUB creating millions of channels. */ - dictDelete(server.pubsub_channels,channel); + dictDelete(*type.serverPubSubChannels, channel); + /* As this channel isn't subscribed by anyone, it's safe + * to remove the channel from the slot. */ + if (server.cluster_enabled & type.shard) { + slotToChannelDel(channel->ptr); + } } } /* Notify the client */ - if (notify) addReplyPubsubUnsubscribed(c,channel); + if (notify) { + addReplyPubsubUnsubscribed(c,channel,type); + } decrRefCount(channel); /* it is finally safe to release it */ return retval; } +void pubsubShardUnsubscribeAllClients(robj *channel) { + int retval; + dictEntry *de = dictFind(server.pubsubshard_channels, channel); + serverAssertWithInfo(NULL,channel,de != NULL); + list *clients = dictGetVal(de); + if (listLength(clients) > 0) { + /* For each client subscribed to the channel, unsubscribe it. */ + listIter li; + listNode *ln; + listRewind(clients, &li); + while ((ln = listNext(&li)) != NULL) { + client *c = listNodeValue(ln); + retval = dictDelete(c->pubsubshard_channels, channel); + serverAssertWithInfo(c,channel,retval == DICT_OK); + addReplyPubsubUnsubscribed(c, channel, pubSubShardType); + /* If the client has no other pubsub subscription, + * move out of pubsub mode. */ + if (clientTotalPubSubSubscriptionCount(c) == 0) { + c->flags &= ~CLIENT_PUBSUB; + } + } + } + /* Delete the channel from server pubsubshard channels hash table. */ + retval = dictDelete(server.pubsubshard_channels, channel); + /* Delete the channel from slots_to_channel mapping. */ + slotToChannelDel(channel->ptr); + serverAssertWithInfo(NULL,channel,retval == DICT_OK); + decrRefCount(channel); /* it is finally safe to release it */ +} + + /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */ int pubsubSubscribePattern(client *c, robj *pattern) { dictEntry *de; @@ -250,24 +377,53 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { /* Unsubscribe from all the channels. Return the number of channels the * client was subscribed to. */ -int pubsubUnsubscribeAllChannels(client *c, int notify) { +int pubsubUnsubscribeAllChannelsInternal(client *c, int notify, pubsubtype type) { int count = 0; - if (dictSize(c->pubsub_channels) > 0) { - dictIterator *di = dictGetSafeIterator(c->pubsub_channels); + if (dictSize(type.clientPubSubChannels(c)) > 0) { + dictIterator *di = dictGetSafeIterator(type.clientPubSubChannels(c)); dictEntry *de; while((de = dictNext(di)) != NULL) { robj *channel = dictGetKey(de); - count += pubsubUnsubscribeChannel(c,channel,notify); + count += pubsubUnsubscribeChannel(c,channel,notify,type); } dictReleaseIterator(di); } /* We were subscribed to nothing? Still reply to the client. */ - if (notify && count == 0) addReplyPubsubUnsubscribed(c,NULL); + if (notify && count == 0) { + addReplyPubsubUnsubscribed(c,NULL,type); + } + return count; +} + +/* + * Unsubscribe a client from all global channels. + */ +int pubsubUnsubscribeAllChannels(client *c, int notify) { + int count = pubsubUnsubscribeAllChannelsInternal(c,notify,pubSubType); + return count; +} + +/* + * Unsubscribe a client from all shard subscribed channels. + */ +int pubsubUnsubscribeShardAllChannels(client *c, int notify) { + int count = pubsubUnsubscribeAllChannelsInternal(c, notify, pubSubShardType); return count; } +/* + * Unsubscribe a client from provided shard subscribed channel(s). + */ +void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count) { + for (unsigned int j = 0; j < count; j++) { + /* Remove the channel from server and from the clients + * subscribed to it as well as notify them. */ + pubsubShardUnsubscribeAllClients(channels[j]); + } +} + /* Unsubscribe from all the patterns. Return the number of patterns the * client was subscribed from. */ int pubsubUnsubscribeAllPatterns(client *c, int notify) { @@ -285,8 +441,10 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) { return count; } -/* Publish a message */ -int pubsubPublishMessage(robj *channel, robj *message) { +/* + * Publish a message to all the subscribers. + */ +int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) { int receivers = 0; dictEntry *de; dictIterator *di; @@ -294,7 +452,7 @@ int pubsubPublishMessage(robj *channel, robj *message) { listIter li; /* Send to clients listening for that channel */ - de = dictFind(server.pubsub_channels,channel); + de = dictFind(*type.serverPubSubChannels, channel); if (de) { list *list = dictGetVal(de); listNode *ln; @@ -308,6 +466,12 @@ int pubsubPublishMessage(robj *channel, robj *message) { receivers++; } } + + if (type.shard) { + /* Shard pubsub ignores patterns. */ + return receivers; + } + /* Send to clients listening to matching channels */ di = dictGetIterator(server.pubsub_patterns); if (di) { @@ -334,6 +498,17 @@ int pubsubPublishMessage(robj *channel, robj *message) { return receivers; } +/* Publish a message to all the subscribers. */ +int pubsubPublishMessage(robj *channel, robj *message) { + return pubsubPublishMessageInternal(channel,message,pubSubType); +} + +/* Publish a shard message to all the subscribers. */ +int pubsubPublishMessageShard(robj *channel, robj *message) { + return pubsubPublishMessageInternal(channel, message, pubSubShardType); +} + + /*----------------------------------------------------------------------------- * Pubsub commands implementation *----------------------------------------------------------------------------*/ @@ -352,13 +527,12 @@ void subscribeCommand(client *c) { addReplyError(c, "SUBSCRIBE isn't allowed for a DENY BLOCKING client"); return; } - for (j = 1; j < c->argc; j++) - pubsubSubscribeChannel(c,c->argv[j]); + pubsubSubscribeChannel(c,c->argv[j],pubSubType); c->flags |= CLIENT_PUBSUB; } -/* UNSUBSCRIBE [channel [channel ...]] */ +/* UNSUBSCRIBE [channel ...] */ void unsubscribeCommand(client *c) { if (c->argc == 1) { pubsubUnsubscribeAllChannels(c,1); @@ -366,9 +540,9 @@ void unsubscribeCommand(client *c) { int j; for (j = 1; j < c->argc; j++) - pubsubUnsubscribeChannel(c,c->argv[j],1); + pubsubUnsubscribeChannel(c,c->argv[j],1,pubSubType); } - if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; + if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; } /* PSUBSCRIBE pattern [pattern ...] */ @@ -401,7 +575,7 @@ void punsubscribeCommand(client *c) { for (j = 1; j < c->argc; j++) pubsubUnsubscribePattern(c,c->argv[j],1); } - if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; + if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; } /* PUBLISH */ @@ -429,7 +603,11 @@ void pubsubCommand(client *c) { " Return number of subscriptions to patterns.", "NUMSUB [ ...]", " Return the number of subscribers for the specified channels, excluding", -" pattern subscriptions(default: no channels).", +" pattern subscriptions(default: no channels)." +"SHARDCHANNELS []", +" Return the currently active shard level channels matching a (default: '*').", +"SHARDNUMSUB [ ...]", +" Return the number of subscribers for the specified shard level channel(s)", NULL }; addReplyHelp(c, help); @@ -438,25 +616,7 @@ NULL { /* PUBSUB CHANNELS [] */ sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr; - dictIterator *di = dictGetIterator(server.pubsub_channels); - dictEntry *de; - long mblen = 0; - void *replylen; - - replylen = addReplyDeferredLen(c); - while((de = dictNext(di)) != NULL) { - robj *cobj = dictGetKey(de); - sds channel = cobj->ptr; - - if (!pat || stringmatchlen(pat, sdslen(pat), - channel, sdslen(channel),0)) - { - addReplyBulk(c,cobj); - mblen++; - } - } - dictReleaseIterator(di); - setDeferredArrayLen(c,replylen,mblen); + channelList(c, pat, server.pubsub_channels); } else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) { /* PUBSUB NUMSUB [Channel_1 ... Channel_N] */ int j; @@ -471,7 +631,93 @@ NULL } else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) { /* PUBSUB NUMPAT */ addReplyLongLong(c,dictSize(server.pubsub_patterns)); + } else if (!strcasecmp(c->argv[1]->ptr,"shardchannels") && + (c->argc == 2 || c->argc == 3)) + { + /* PUBSUB SHARDCHANNELS */ + sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr; + channelList(c,pat,server.pubsubshard_channels); + } else if (!strcasecmp(c->argv[1]->ptr,"shardnumsub") && c->argc >= 2) { + /* PUBSUB SHARDNUMSUB [Channel_1 ... Channel_N] */ + int j; + + addReplyArrayLen(c, (c->argc-2)*2); + for (j = 2; j < c->argc; j++) { + list *l = dictFetchValue(server.pubsubshard_channels, c->argv[j]); + + addReplyBulk(c,c->argv[j]); + addReplyLongLong(c,l ? listLength(l) : 0); + } } else { addReplySubcommandSyntaxError(c); } } + +void channelList(client *c, sds pat, dict *pubsub_channels) { + dictIterator *di = dictGetIterator(pubsub_channels); + dictEntry *de; + long mblen = 0; + void *replylen; + + replylen = addReplyDeferredLen(c); + while((de = dictNext(di)) != NULL) { + robj *cobj = dictGetKey(de); + sds channel = cobj->ptr; + + if (!pat || stringmatchlen(pat, sdslen(pat), + channel, sdslen(channel),0)) + { + addReplyBulk(c,cobj); + mblen++; + } + } + dictReleaseIterator(di); + setDeferredArrayLen(c,replylen,mblen); +} + +/* SPUBLISH */ +void spublishCommand(client *c) { + int receivers = pubsubPublishMessageInternal(c->argv[1], c->argv[2], pubSubShardType); + if (server.cluster_enabled) { + clusterPropagatePublishShard(c->argv[1], c->argv[2]); + } else { + forceCommandPropagation(c,PROPAGATE_REPL); + } + addReplyLongLong(c,receivers); +} + +/* SSUBSCRIBE channel [channel ...] */ +void ssubscribeCommand(client *c) { + if (c->flags & CLIENT_DENY_BLOCKING) { + /* A client that has CLIENT_DENY_BLOCKING flag on + * expect a reply per command and so can not execute subscribe. */ + addReplyError(c, "SSUBSCRIBE isn't allowed for a DENY BLOCKING client"); + return; + } + + for (int j = 1; j < c->argc; j++) { + /* A channel is only considered to be added, if a + * subscriber exists for it. And if a subscriber + * already exists the slotToChannel doesn't needs + * to be incremented. */ + if (server.cluster_enabled & + (dictFind(*pubSubShardType.serverPubSubChannels, c->argv[j]) == NULL)) { + slotToChannelAdd(c->argv[j]->ptr); + } + pubsubSubscribeChannel(c, c->argv[j], pubSubShardType); + } + c->flags |= CLIENT_PUBSUB; +} + + +/* SUNSUBSCRIBE [channel ...] */ +void sunsubscribeCommand(client *c) { + if (c->argc == 1) { + pubsubUnsubscribeShardAllChannels(c, 1); + } else { + for (int j = 1; j < c->argc; j++) { + pubsubUnsubscribeChannel(c, c->argv[j], 1, pubSubShardType); + } + } + if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; +} -- cgit v1.2.1