summaryrefslogtreecommitdiff
path: root/src/pubsub.c
diff options
context:
space:
mode:
authorHarkrishn Patro <30795839+hpatro@users.noreply.github.com>2022-01-03 01:54:47 +0100
committerGitHub <noreply@github.com>2022-01-02 16:54:47 -0800
commit9f8885760b53e6d3952b9c9b41f9e6c48dfa6cec (patch)
tree770dfdbff19a1a2a1c71a642ebd844d592ef3d26 /src/pubsub.c
parentb8ba942ac2aabf51fd96134d9fa21b47d3baff4a (diff)
downloadredis-9f8885760b53e6d3952b9c9b41f9e6c48dfa6cec.tar.gz
Sharded pubsub implementation (#8621)
This commit implements a sharded pubsub implementation based off of shard channels. Co-authored-by: Harkrishn Patro <harkrisp@amazon.com> Co-authored-by: Madelyn Olson <madelyneolson@gmail.com>
Diffstat (limited to 'src/pubsub.c')
-rw-r--r--src/pubsub.c350
1 files changed, 298 insertions, 52 deletions
diff --git a/src/pubsub.c b/src/pubsub.c
index 6da5b18cf..e805b16ef 100644
--- a/src/pubsub.c
+++ b/src/pubsub.c
@@ -30,8 +30,68 @@
#include "server.h"
#include "cluster.h"
+/* Structure to hold the pubsub related metadata. Currently used
+ * for pubsub and pubsubshard feature. */
+typedef struct pubsubtype {
+ int shard;
+ dict *(*clientPubSubChannels)(client*);
+ int (*subscriptionCount)(client*);
+ dict **serverPubSubChannels;
+ robj **subscribeMsg;
+ robj **unsubscribeMsg;
+}pubsubtype;
+
+/*
+ * Get client's global Pub/Sub channels subscription count.
+ */
int clientSubscriptionsCount(client *c);
+/*
+ * Get client's shard level Pub/Sub channels subscription count.
+ */
+int clientShardSubscriptionsCount(client *c);
+
+/*
+ * Get client's global Pub/Sub channels dict.
+ */
+dict* getClientPubSubChannels(client *c);
+
+/*
+ * Get client's shard level Pub/Sub channels dict.
+ */
+dict* getClientPubSubShardChannels(client *c);
+
+/*
+ * Get list of channels client is subscribed to.
+ * If a pattern is provided, the subset of channels is returned
+ * matching the pattern.
+ */
+void channelList(client *c, sds pat, dict* pubsub_channels);
+
+/*
+ * Pub/Sub type for global channels.
+ */
+pubsubtype pubSubType = {
+ .shard = 0,
+ .clientPubSubChannels = getClientPubSubChannels,
+ .subscriptionCount = clientSubscriptionsCount,
+ .serverPubSubChannels = &server.pubsub_channels,
+ .subscribeMsg = &shared.subscribebulk,
+ .unsubscribeMsg = &shared.unsubscribebulk,
+};
+
+/*
+ * Pub/Sub type for shard level channels bounded to a slot.
+ */
+pubsubtype pubSubShardType = {
+ .shard = 1,
+ .clientPubSubChannels = getClientPubSubShardChannels,
+ .subscriptionCount = clientShardSubscriptionsCount,
+ .serverPubSubChannels = &server.pubsubshard_channels,
+ .subscribeMsg = &shared.ssubscribebulk,
+ .unsubscribeMsg = &shared.sunsubscribebulk
+};
+
/*-----------------------------------------------------------------------------
* Pubsub client replies API
*----------------------------------------------------------------------------*/
@@ -66,31 +126,31 @@ void addReplyPubsubPatMessage(client *c, robj *pat, robj *channel, robj *msg) {
}
/* Send the pubsub subscription notification to the client. */
-void addReplyPubsubSubscribed(client *c, robj *channel) {
+void addReplyPubsubSubscribed(client *c, robj *channel, pubsubtype type) {
if (c->resp == 2)
addReply(c,shared.mbulkhdr[3]);
else
addReplyPushLen(c,3);
- addReply(c,shared.subscribebulk);
+ addReply(c,*type.subscribeMsg);
addReplyBulk(c,channel);
- addReplyLongLong(c,clientSubscriptionsCount(c));
+ addReplyLongLong(c,type.subscriptionCount(c));
}
/* Send the pubsub unsubscription notification to the client.
* Channel can be NULL: this is useful when the client sends a mass
* unsubscribe command but there are no channels to unsubscribe from: we
* still send a notification. */
-void addReplyPubsubUnsubscribed(client *c, robj *channel) {
+void addReplyPubsubUnsubscribed(client *c, robj *channel, pubsubtype type) {
if (c->resp == 2)
addReply(c,shared.mbulkhdr[3]);
else
addReplyPushLen(c,3);
- addReply(c,shared.unsubscribebulk);
+ addReply(c, *type.unsubscribeMsg);
if (channel)
addReplyBulk(c,channel);
else
addReplyNull(c);
- addReplyLongLong(c,clientSubscriptionsCount(c));
+ addReplyLongLong(c,type.subscriptionCount(c));
}
/* Send the pubsub pattern subscription notification to the client. */
@@ -125,28 +185,57 @@ void addReplyPubsubPatUnsubscribed(client *c, robj *pattern) {
* Pubsub low level API
*----------------------------------------------------------------------------*/
+/* Return the number of pubsub channels + patterns is handled. */
+int serverPubsubSubscriptionCount() {
+ return dictSize(server.pubsub_channels) + dictSize(server.pubsub_patterns);
+}
+
+/* Return the number of pubsub shard level channels is handled. */
+int serverPubsubShardSubscriptionCount() {
+ return dictSize(server.pubsubshard_channels);
+}
+
+
/* Return the number of channels + patterns a client is subscribed to. */
int clientSubscriptionsCount(client *c) {
- return dictSize(c->pubsub_channels)+
- listLength(c->pubsub_patterns);
+ return dictSize(c->pubsub_channels) + listLength(c->pubsub_patterns);
+}
+
+/* Return the number of shard level channels a client is subscribed to. */
+int clientShardSubscriptionsCount(client *c) {
+ return dictSize(c->pubsubshard_channels);
+}
+
+dict* getClientPubSubChannels(client *c) {
+ return c->pubsub_channels;
+}
+
+dict* getClientPubSubShardChannels(client *c) {
+ return c->pubsubshard_channels;
+}
+
+/* Return the number of pubsub + pubsub shard level channels
+ * a client is subscribed to. */
+int clientTotalPubSubSubscriptionCount(client *c) {
+ return clientSubscriptionsCount(c) + clientShardSubscriptionsCount(c);
}
/* Subscribe a client to a channel. Returns 1 if the operation succeeded, or
* 0 if the client was already subscribed to that channel. */
-int pubsubSubscribeChannel(client *c, robj *channel) {
+int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
dictEntry *de;
list *clients = NULL;
int retval = 0;
/* Add the channel to the client -> channels hash table */
- if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) {
+ if (dictAdd(type.clientPubSubChannels(c),channel,NULL) == DICT_OK) {
retval = 1;
incrRefCount(channel);
/* Add the client to the channel -> list of clients hash table */
- de = dictFind(server.pubsub_channels,channel);
+ de = dictFind(*type.serverPubSubChannels, channel);
if (de == NULL) {
clients = listCreate();
- dictAdd(server.pubsub_channels,channel,clients);
+ dictAdd(*type.serverPubSubChannels, channel, clients);
incrRefCount(channel);
} else {
clients = dictGetVal(de);
@@ -154,13 +243,13 @@ int pubsubSubscribeChannel(client *c, robj *channel) {
listAddNodeTail(clients,c);
}
/* Notify the client */
- addReplyPubsubSubscribed(c,channel);
+ addReplyPubsubSubscribed(c,channel,type);
return retval;
}
/* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or
* 0 if the client was not subscribed to the specified channel. */
-int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) {
+int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) {
dictEntry *de;
list *clients;
listNode *ln;
@@ -169,10 +258,10 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) {
/* Remove the channel from the client -> channels hash table */
incrRefCount(channel); /* channel may be just a pointer to the same object
we have in the hash tables. Protect it... */
- if (dictDelete(c->pubsub_channels,channel) == DICT_OK) {
+ if (dictDelete(type.clientPubSubChannels(c),channel) == DICT_OK) {
retval = 1;
/* Remove the client from the channel -> clients list hash table */
- de = dictFind(server.pubsub_channels,channel);
+ de = dictFind(*type.serverPubSubChannels, channel);
serverAssertWithInfo(c,NULL,de != NULL);
clients = dictGetVal(de);
ln = listSearchKey(clients,c);
@@ -182,15 +271,53 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) {
/* Free the list and associated hash entry at all if this was
* the latest client, so that it will be possible to abuse
* Redis PUBSUB creating millions of channels. */
- dictDelete(server.pubsub_channels,channel);
+ dictDelete(*type.serverPubSubChannels, channel);
+ /* As this channel isn't subscribed by anyone, it's safe
+ * to remove the channel from the slot. */
+ if (server.cluster_enabled & type.shard) {
+ slotToChannelDel(channel->ptr);
+ }
}
}
/* Notify the client */
- if (notify) addReplyPubsubUnsubscribed(c,channel);
+ if (notify) {
+ addReplyPubsubUnsubscribed(c,channel,type);
+ }
decrRefCount(channel); /* it is finally safe to release it */
return retval;
}
+void pubsubShardUnsubscribeAllClients(robj *channel) {
+ int retval;
+ dictEntry *de = dictFind(server.pubsubshard_channels, channel);
+ serverAssertWithInfo(NULL,channel,de != NULL);
+ list *clients = dictGetVal(de);
+ if (listLength(clients) > 0) {
+ /* For each client subscribed to the channel, unsubscribe it. */
+ listIter li;
+ listNode *ln;
+ listRewind(clients, &li);
+ while ((ln = listNext(&li)) != NULL) {
+ client *c = listNodeValue(ln);
+ retval = dictDelete(c->pubsubshard_channels, channel);
+ serverAssertWithInfo(c,channel,retval == DICT_OK);
+ addReplyPubsubUnsubscribed(c, channel, pubSubShardType);
+ /* If the client has no other pubsub subscription,
+ * move out of pubsub mode. */
+ if (clientTotalPubSubSubscriptionCount(c) == 0) {
+ c->flags &= ~CLIENT_PUBSUB;
+ }
+ }
+ }
+ /* Delete the channel from server pubsubshard channels hash table. */
+ retval = dictDelete(server.pubsubshard_channels, channel);
+ /* Delete the channel from slots_to_channel mapping. */
+ slotToChannelDel(channel->ptr);
+ serverAssertWithInfo(NULL,channel,retval == DICT_OK);
+ decrRefCount(channel); /* it is finally safe to release it */
+}
+
+
/* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */
int pubsubSubscribePattern(client *c, robj *pattern) {
dictEntry *de;
@@ -250,24 +377,53 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) {
/* Unsubscribe from all the channels. Return the number of channels the
* client was subscribed to. */
-int pubsubUnsubscribeAllChannels(client *c, int notify) {
+int pubsubUnsubscribeAllChannelsInternal(client *c, int notify, pubsubtype type) {
int count = 0;
- if (dictSize(c->pubsub_channels) > 0) {
- dictIterator *di = dictGetSafeIterator(c->pubsub_channels);
+ if (dictSize(type.clientPubSubChannels(c)) > 0) {
+ dictIterator *di = dictGetSafeIterator(type.clientPubSubChannels(c));
dictEntry *de;
while((de = dictNext(di)) != NULL) {
robj *channel = dictGetKey(de);
- count += pubsubUnsubscribeChannel(c,channel,notify);
+ count += pubsubUnsubscribeChannel(c,channel,notify,type);
}
dictReleaseIterator(di);
}
/* We were subscribed to nothing? Still reply to the client. */
- if (notify && count == 0) addReplyPubsubUnsubscribed(c,NULL);
+ if (notify && count == 0) {
+ addReplyPubsubUnsubscribed(c,NULL,type);
+ }
+ return count;
+}
+
+/*
+ * Unsubscribe a client from all global channels.
+ */
+int pubsubUnsubscribeAllChannels(client *c, int notify) {
+ int count = pubsubUnsubscribeAllChannelsInternal(c,notify,pubSubType);
+ return count;
+}
+
+/*
+ * Unsubscribe a client from all shard subscribed channels.
+ */
+int pubsubUnsubscribeShardAllChannels(client *c, int notify) {
+ int count = pubsubUnsubscribeAllChannelsInternal(c, notify, pubSubShardType);
return count;
}
+/*
+ * Unsubscribe a client from provided shard subscribed channel(s).
+ */
+void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count) {
+ for (unsigned int j = 0; j < count; j++) {
+ /* Remove the channel from server and from the clients
+ * subscribed to it as well as notify them. */
+ pubsubShardUnsubscribeAllClients(channels[j]);
+ }
+}
+
/* Unsubscribe from all the patterns. Return the number of patterns the
* client was subscribed from. */
int pubsubUnsubscribeAllPatterns(client *c, int notify) {
@@ -285,8 +441,10 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) {
return count;
}
-/* Publish a message */
-int pubsubPublishMessage(robj *channel, robj *message) {
+/*
+ * Publish a message to all the subscribers.
+ */
+int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) {
int receivers = 0;
dictEntry *de;
dictIterator *di;
@@ -294,7 +452,7 @@ int pubsubPublishMessage(robj *channel, robj *message) {
listIter li;
/* Send to clients listening for that channel */
- de = dictFind(server.pubsub_channels,channel);
+ de = dictFind(*type.serverPubSubChannels, channel);
if (de) {
list *list = dictGetVal(de);
listNode *ln;
@@ -308,6 +466,12 @@ int pubsubPublishMessage(robj *channel, robj *message) {
receivers++;
}
}
+
+ if (type.shard) {
+ /* Shard pubsub ignores patterns. */
+ return receivers;
+ }
+
/* Send to clients listening to matching channels */
di = dictGetIterator(server.pubsub_patterns);
if (di) {
@@ -334,6 +498,17 @@ int pubsubPublishMessage(robj *channel, robj *message) {
return receivers;
}
+/* Publish a message to all the subscribers. */
+int pubsubPublishMessage(robj *channel, robj *message) {
+ return pubsubPublishMessageInternal(channel,message,pubSubType);
+}
+
+/* Publish a shard message to all the subscribers. */
+int pubsubPublishMessageShard(robj *channel, robj *message) {
+ return pubsubPublishMessageInternal(channel, message, pubSubShardType);
+}
+
+
/*-----------------------------------------------------------------------------
* Pubsub commands implementation
*----------------------------------------------------------------------------*/
@@ -352,13 +527,12 @@ void subscribeCommand(client *c) {
addReplyError(c, "SUBSCRIBE isn't allowed for a DENY BLOCKING client");
return;
}
-
for (j = 1; j < c->argc; j++)
- pubsubSubscribeChannel(c,c->argv[j]);
+ pubsubSubscribeChannel(c,c->argv[j],pubSubType);
c->flags |= CLIENT_PUBSUB;
}
-/* UNSUBSCRIBE [channel [channel ...]] */
+/* UNSUBSCRIBE [channel ...] */
void unsubscribeCommand(client *c) {
if (c->argc == 1) {
pubsubUnsubscribeAllChannels(c,1);
@@ -366,9 +540,9 @@ void unsubscribeCommand(client *c) {
int j;
for (j = 1; j < c->argc; j++)
- pubsubUnsubscribeChannel(c,c->argv[j],1);
+ pubsubUnsubscribeChannel(c,c->argv[j],1,pubSubType);
}
- if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB;
+ if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB;
}
/* PSUBSCRIBE pattern [pattern ...] */
@@ -401,7 +575,7 @@ void punsubscribeCommand(client *c) {
for (j = 1; j < c->argc; j++)
pubsubUnsubscribePattern(c,c->argv[j],1);
}
- if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB;
+ if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB;
}
/* PUBLISH <channel> <message> */
@@ -429,7 +603,11 @@ void pubsubCommand(client *c) {
" Return number of subscriptions to patterns.",
"NUMSUB [<channel> ...]",
" Return the number of subscribers for the specified channels, excluding",
-" pattern subscriptions(default: no channels).",
+" pattern subscriptions(default: no channels)."
+"SHARDCHANNELS [<pattern>]",
+" Return the currently active shard level channels matching a <pattern> (default: '*').",
+"SHARDNUMSUB [<channel> ...]",
+" Return the number of subscribers for the specified shard level channel(s)",
NULL
};
addReplyHelp(c, help);
@@ -438,25 +616,7 @@ NULL
{
/* PUBSUB CHANNELS [<pattern>] */
sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr;
- dictIterator *di = dictGetIterator(server.pubsub_channels);
- dictEntry *de;
- long mblen = 0;
- void *replylen;
-
- replylen = addReplyDeferredLen(c);
- while((de = dictNext(di)) != NULL) {
- robj *cobj = dictGetKey(de);
- sds channel = cobj->ptr;
-
- if (!pat || stringmatchlen(pat, sdslen(pat),
- channel, sdslen(channel),0))
- {
- addReplyBulk(c,cobj);
- mblen++;
- }
- }
- dictReleaseIterator(di);
- setDeferredArrayLen(c,replylen,mblen);
+ channelList(c, pat, server.pubsub_channels);
} else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) {
/* PUBSUB NUMSUB [Channel_1 ... Channel_N] */
int j;
@@ -471,7 +631,93 @@ NULL
} else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) {
/* PUBSUB NUMPAT */
addReplyLongLong(c,dictSize(server.pubsub_patterns));
+ } else if (!strcasecmp(c->argv[1]->ptr,"shardchannels") &&
+ (c->argc == 2 || c->argc == 3))
+ {
+ /* PUBSUB SHARDCHANNELS */
+ sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr;
+ channelList(c,pat,server.pubsubshard_channels);
+ } else if (!strcasecmp(c->argv[1]->ptr,"shardnumsub") && c->argc >= 2) {
+ /* PUBSUB SHARDNUMSUB [Channel_1 ... Channel_N] */
+ int j;
+
+ addReplyArrayLen(c, (c->argc-2)*2);
+ for (j = 2; j < c->argc; j++) {
+ list *l = dictFetchValue(server.pubsubshard_channels, c->argv[j]);
+
+ addReplyBulk(c,c->argv[j]);
+ addReplyLongLong(c,l ? listLength(l) : 0);
+ }
} else {
addReplySubcommandSyntaxError(c);
}
}
+
+void channelList(client *c, sds pat, dict *pubsub_channels) {
+ dictIterator *di = dictGetIterator(pubsub_channels);
+ dictEntry *de;
+ long mblen = 0;
+ void *replylen;
+
+ replylen = addReplyDeferredLen(c);
+ while((de = dictNext(di)) != NULL) {
+ robj *cobj = dictGetKey(de);
+ sds channel = cobj->ptr;
+
+ if (!pat || stringmatchlen(pat, sdslen(pat),
+ channel, sdslen(channel),0))
+ {
+ addReplyBulk(c,cobj);
+ mblen++;
+ }
+ }
+ dictReleaseIterator(di);
+ setDeferredArrayLen(c,replylen,mblen);
+}
+
+/* SPUBLISH <channel> <message> */
+void spublishCommand(client *c) {
+ int receivers = pubsubPublishMessageInternal(c->argv[1], c->argv[2], pubSubShardType);
+ if (server.cluster_enabled) {
+ clusterPropagatePublishShard(c->argv[1], c->argv[2]);
+ } else {
+ forceCommandPropagation(c,PROPAGATE_REPL);
+ }
+ addReplyLongLong(c,receivers);
+}
+
+/* SSUBSCRIBE channel [channel ...] */
+void ssubscribeCommand(client *c) {
+ if (c->flags & CLIENT_DENY_BLOCKING) {
+ /* A client that has CLIENT_DENY_BLOCKING flag on
+ * expect a reply per command and so can not execute subscribe. */
+ addReplyError(c, "SSUBSCRIBE isn't allowed for a DENY BLOCKING client");
+ return;
+ }
+
+ for (int j = 1; j < c->argc; j++) {
+ /* A channel is only considered to be added, if a
+ * subscriber exists for it. And if a subscriber
+ * already exists the slotToChannel doesn't needs
+ * to be incremented. */
+ if (server.cluster_enabled &
+ (dictFind(*pubSubShardType.serverPubSubChannels, c->argv[j]) == NULL)) {
+ slotToChannelAdd(c->argv[j]->ptr);
+ }
+ pubsubSubscribeChannel(c, c->argv[j], pubSubShardType);
+ }
+ c->flags |= CLIENT_PUBSUB;
+}
+
+
+/* SUNSUBSCRIBE [channel ...] */
+void sunsubscribeCommand(client *c) {
+ if (c->argc == 1) {
+ pubsubUnsubscribeShardAllChannels(c, 1);
+ } else {
+ for (int j = 1; j < c->argc; j++) {
+ pubsubUnsubscribeChannel(c, c->argv[j], 1, pubSubShardType);
+ }
+ }
+ if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB;
+}