diff options
Diffstat (limited to 'src/pubsub.c')
-rw-r--r-- | src/pubsub.c | 259 |
1 files changed, 259 insertions, 0 deletions
diff --git a/src/pubsub.c b/src/pubsub.c new file mode 100644 index 000000000..c9f5f310e --- /dev/null +++ b/src/pubsub.c @@ -0,0 +1,259 @@ +#include "redis.h" + +void freePubsubPattern(void *p) { + pubsubPattern *pat = p; + + decrRefCount(pat->pattern); + zfree(pat); +} + +int listMatchPubsubPattern(void *a, void *b) { + pubsubPattern *pa = a, *pb = b; + + return (pa->client == pb->client) && + (equalStringObjects(pa->pattern,pb->pattern)); +} + +/* Subscribe a client to a channel. Returns 1 if the operation succeeded, or + * 0 if the client was already subscribed to that channel. */ +int pubsubSubscribeChannel(redisClient *c, robj *channel) { + struct dictEntry *de; + list *clients = NULL; + int retval = 0; + + /* Add the channel to the client -> channels hash table */ + if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) { + retval = 1; + incrRefCount(channel); + /* Add the client to the channel -> list of clients hash table */ + de = dictFind(server.pubsub_channels,channel); + if (de == NULL) { + clients = listCreate(); + dictAdd(server.pubsub_channels,channel,clients); + incrRefCount(channel); + } else { + clients = dictGetEntryVal(de); + } + listAddNodeTail(clients,c); + } + /* Notify the client */ + addReply(c,shared.mbulk3); + addReply(c,shared.subscribebulk); + addReplyBulk(c,channel); + addReplyLongLong(c,dictSize(c->pubsub_channels)+listLength(c->pubsub_patterns)); + return retval; +} + +/* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or + * 0 if the client was not subscribed to the specified channel. */ +int pubsubUnsubscribeChannel(redisClient *c, robj *channel, int notify) { + struct dictEntry *de; + list *clients; + listNode *ln; + int retval = 0; + + /* Remove the channel from the client -> channels hash table */ + incrRefCount(channel); /* channel may be just a pointer to the same object + we have in the hash tables. Protect it... */ + if (dictDelete(c->pubsub_channels,channel) == DICT_OK) { + retval = 1; + /* Remove the client from the channel -> clients list hash table */ + de = dictFind(server.pubsub_channels,channel); + redisAssert(de != NULL); + clients = dictGetEntryVal(de); + ln = listSearchKey(clients,c); + redisAssert(ln != NULL); + listDelNode(clients,ln); + if (listLength(clients) == 0) { + /* Free the list and associated hash entry at all if this was + * the latest client, so that it will be possible to abuse + * Redis PUBSUB creating millions of channels. */ + dictDelete(server.pubsub_channels,channel); + } + } + /* Notify the client */ + if (notify) { + addReply(c,shared.mbulk3); + addReply(c,shared.unsubscribebulk); + addReplyBulk(c,channel); + addReplyLongLong(c,dictSize(c->pubsub_channels)+ + listLength(c->pubsub_patterns)); + + } + decrRefCount(channel); /* it is finally safe to release it */ + return retval; +} + +/* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the clinet was already subscribed to that pattern. */ +int pubsubSubscribePattern(redisClient *c, robj *pattern) { + int retval = 0; + + if (listSearchKey(c->pubsub_patterns,pattern) == NULL) { + retval = 1; + pubsubPattern *pat; + listAddNodeTail(c->pubsub_patterns,pattern); + incrRefCount(pattern); + pat = zmalloc(sizeof(*pat)); + pat->pattern = getDecodedObject(pattern); + pat->client = c; + listAddNodeTail(server.pubsub_patterns,pat); + } + /* Notify the client */ + addReply(c,shared.mbulk3); + addReply(c,shared.psubscribebulk); + addReplyBulk(c,pattern); + addReplyLongLong(c,dictSize(c->pubsub_channels)+listLength(c->pubsub_patterns)); + return retval; +} + +/* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or + * 0 if the client was not subscribed to the specified channel. */ +int pubsubUnsubscribePattern(redisClient *c, robj *pattern, int notify) { + listNode *ln; + pubsubPattern pat; + int retval = 0; + + incrRefCount(pattern); /* Protect the object. May be the same we remove */ + if ((ln = listSearchKey(c->pubsub_patterns,pattern)) != NULL) { + retval = 1; + listDelNode(c->pubsub_patterns,ln); + pat.client = c; + pat.pattern = pattern; + ln = listSearchKey(server.pubsub_patterns,&pat); + listDelNode(server.pubsub_patterns,ln); + } + /* Notify the client */ + if (notify) { + addReply(c,shared.mbulk3); + addReply(c,shared.punsubscribebulk); + addReplyBulk(c,pattern); + addReplyLongLong(c,dictSize(c->pubsub_channels)+ + listLength(c->pubsub_patterns)); + } + decrRefCount(pattern); + return retval; +} + +/* Unsubscribe from all the channels. Return the number of channels the + * client was subscribed from. */ +int pubsubUnsubscribeAllChannels(redisClient *c, int notify) { + dictIterator *di = dictGetIterator(c->pubsub_channels); + dictEntry *de; + int count = 0; + + while((de = dictNext(di)) != NULL) { + robj *channel = dictGetEntryKey(de); + + count += pubsubUnsubscribeChannel(c,channel,notify); + } + dictReleaseIterator(di); + return count; +} + +/* Unsubscribe from all the patterns. Return the number of patterns the + * client was subscribed from. */ +int pubsubUnsubscribeAllPatterns(redisClient *c, int notify) { + listNode *ln; + listIter li; + int count = 0; + + listRewind(c->pubsub_patterns,&li); + while ((ln = listNext(&li)) != NULL) { + robj *pattern = ln->value; + + count += pubsubUnsubscribePattern(c,pattern,notify); + } + return count; +} + +/* Publish a message */ +int pubsubPublishMessage(robj *channel, robj *message) { + int receivers = 0; + struct dictEntry *de; + listNode *ln; + listIter li; + + /* Send to clients listening for that channel */ + de = dictFind(server.pubsub_channels,channel); + if (de) { + list *list = dictGetEntryVal(de); + listNode *ln; + listIter li; + + listRewind(list,&li); + while ((ln = listNext(&li)) != NULL) { + redisClient *c = ln->value; + + addReply(c,shared.mbulk3); + addReply(c,shared.messagebulk); + addReplyBulk(c,channel); + addReplyBulk(c,message); + receivers++; + } + } + /* Send to clients listening to matching channels */ + if (listLength(server.pubsub_patterns)) { + listRewind(server.pubsub_patterns,&li); + channel = getDecodedObject(channel); + while ((ln = listNext(&li)) != NULL) { + pubsubPattern *pat = ln->value; + + if (stringmatchlen((char*)pat->pattern->ptr, + sdslen(pat->pattern->ptr), + (char*)channel->ptr, + sdslen(channel->ptr),0)) { + addReply(pat->client,shared.mbulk4); + addReply(pat->client,shared.pmessagebulk); + addReplyBulk(pat->client,pat->pattern); + addReplyBulk(pat->client,channel); + addReplyBulk(pat->client,message); + receivers++; + } + } + decrRefCount(channel); + } + return receivers; +} + +void subscribeCommand(redisClient *c) { + int j; + + for (j = 1; j < c->argc; j++) + pubsubSubscribeChannel(c,c->argv[j]); +} + +void unsubscribeCommand(redisClient *c) { + if (c->argc == 1) { + pubsubUnsubscribeAllChannels(c,1); + return; + } else { + int j; + + for (j = 1; j < c->argc; j++) + pubsubUnsubscribeChannel(c,c->argv[j],1); + } +} + +void psubscribeCommand(redisClient *c) { + int j; + + for (j = 1; j < c->argc; j++) + pubsubSubscribePattern(c,c->argv[j]); +} + +void punsubscribeCommand(redisClient *c) { + if (c->argc == 1) { + pubsubUnsubscribeAllPatterns(c,1); + return; + } else { + int j; + + for (j = 1; j < c->argc; j++) + pubsubUnsubscribePattern(c,c->argv[j],1); + } +} + +void publishCommand(redisClient *c) { + int receivers = pubsubPublishMessage(c->argv[1],c->argv[2]); + addReplyLongLong(c,receivers); +} |