summaryrefslogtreecommitdiff
path: root/src/t_zset.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/t_zset.c')
-rw-r--r--src/t_zset.c269
1 files changed, 265 insertions, 4 deletions
diff --git a/src/t_zset.c b/src/t_zset.c
index 6851ac86c..b55fc169e 100644
--- a/src/t_zset.c
+++ b/src/t_zset.c
@@ -721,20 +721,26 @@ zskiplistNode *zslLastInLexRange(zskiplist *zsl, zlexrangespec *range) {
* Ziplist-backed sorted set API
*----------------------------------------------------------------------------*/
+double zzlStrtod(unsigned char *vstr, unsigned int vlen) {
+ char buf[128];
+ if (vlen > sizeof(buf))
+ vlen = sizeof(buf);
+ memcpy(buf,vstr,vlen);
+ buf[vlen] = '\0';
+ return strtod(buf,NULL);
+ }
+
double zzlGetScore(unsigned char *sptr) {
unsigned char *vstr;
unsigned int vlen;
long long vlong;
- char buf[128];
double score;
serverAssert(sptr != NULL);
serverAssert(ziplistGet(sptr,&vstr,&vlen,&vlong));
if (vstr) {
- memcpy(buf,vstr,vlen);
- buf[vlen] = '\0';
- score = strtod(buf,NULL);
+ score = zzlStrtod(vstr,vlen);
} else {
score = vlong;
}
@@ -1653,6 +1659,48 @@ int zsetZiplistValidateIntegrity(unsigned char *zl, size_t size, int deep) {
return ret;
}
+/* Create a new sds string from the ziplist entry. */
+sds zsetSdsFromZiplistEntry(ziplistEntry *e) {
+ return e->sval ? sdsnewlen(e->sval, e->slen) : sdsfromlonglong(e->lval);
+}
+
+/* Reply with bulk string from the ziplist entry. */
+void zsetReplyFromZiplistEntry(client *c, ziplistEntry *e) {
+ if (e->sval)
+ addReplyBulkCBuffer(c, e->sval, e->slen);
+ else
+ addReplyBulkLongLong(c, e->lval);
+}
+
+
+/* Return random element from a non empty zset.
+ * 'key' and 'val' will be set to hold the element.
+ * The memory in `key` is not to be freed or modified by the caller.
+ * 'score' can be NULL in which case it's not extracted. */
+void zsetTypeRandomElement(robj *zsetobj, unsigned long zsetsize, ziplistEntry *key, double *score) {
+ if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) {
+ zset *zs = zsetobj->ptr;
+ dictEntry *de = dictGetFairRandomKey(zs->dict);
+ sds s = dictGetKey(de);
+ key->sval = (unsigned char*)s;
+ key->slen = sdslen(s);
+ if (score)
+ *score = *(double*)dictGetVal(de);
+ } else if (zsetobj->encoding == OBJ_ENCODING_ZIPLIST) {
+ ziplistEntry val;
+ ziplistRandomPair(zsetobj->ptr, zsetsize, key, &val);
+ if (score) {
+ if (val.sval) {
+ *score = zzlStrtod(val.sval,val.slen);
+ } else {
+ *score = (double)val.lval;
+ }
+ }
+ } else {
+ serverPanic("Unknown zset encoding");
+ }
+}
+
/*-----------------------------------------------------------------------------
* Sorted set commands
*----------------------------------------------------------------------------*/
@@ -3907,3 +3955,216 @@ void bzpopminCommand(client *c) {
void bzpopmaxCommand(client *c) {
blockingGenericZpopCommand(c,ZSET_MAX);
}
+
+/* How many times bigger should be the zset compared to the requested size
+ * for us to not use the "remove elements" strategy? Read later in the
+ * implementation for more info. */
+#define ZRANDMEMBER_SUB_STRATEGY_MUL 3
+
+void zrandmemberWithCountCommand(client *c, long l, int withscores) {
+ unsigned long count, size;
+ int uniq = 1;
+ robj *zsetobj;
+
+ if ((zsetobj = lookupKeyReadOrReply(c, c->argv[1], shared.null[c->resp]))
+ == NULL || checkType(c, zsetobj, OBJ_ZSET)) return;
+ size = zsetLength(zsetobj);
+
+ if(l >= 0) {
+ count = (unsigned long) l;
+ } else {
+ count = -l;
+ uniq = 0;
+ }
+
+ /* If count is zero, serve it ASAP to avoid special cases later. */
+ if (count == 0) {
+ addReply(c,shared.emptyarray);
+ return;
+ }
+
+ /* CASE 1: The count was negative, so the extraction method is just:
+ * "return N random elements" sampling the whole set every time.
+ * This case is trivial and can be served without auxiliary data
+ * structures. This case is the only one that also needs to return the
+ * elements in random order. */
+ if (!uniq || count == 1) {
+ if (withscores && c->resp == 2)
+ addReplyArrayLen(c, count*2);
+ else
+ addReplyArrayLen(c, count);
+ if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) {
+ zset *zs = zsetobj->ptr;
+ while (count--) {
+ dictEntry *de = dictGetFairRandomKey(zs->dict);
+ sds key = dictGetKey(de);
+ if (withscores && c->resp > 2)
+ addReplyArrayLen(c,2);
+ addReplyBulkCBuffer(c, key, sdslen(key));
+ if (withscores)
+ addReplyDouble(c, dictGetDoubleVal(de));
+ }
+ } else if (zsetobj->encoding == OBJ_ENCODING_ZIPLIST) {
+ ziplistEntry *keys, *vals = NULL;
+ keys = zmalloc(sizeof(ziplistEntry)*count);
+ if (withscores)
+ vals = zmalloc(sizeof(ziplistEntry)*count);
+ ziplistRandomPairs(zsetobj->ptr, count, keys, vals);
+ for (unsigned long i = 0; i < count; i++) {
+ if (withscores && c->resp > 2)
+ addReplyArrayLen(c,2);
+ if (keys[i].sval)
+ addReplyBulkCBuffer(c, keys[i].sval, keys[i].slen);
+ else
+ addReplyBulkLongLong(c, keys[i].lval);
+ if (withscores) {
+ if (vals[i].sval) {
+ addReplyDouble(c, zzlStrtod(vals[i].sval,vals[i].slen));
+ } else
+ addReplyDouble(c, vals[i].lval);
+ }
+ }
+ zfree(keys);
+ zfree(vals);
+ }
+ return;
+ }
+
+ zsetopsrc src;
+ zsetopval zval;
+ src.subject = zsetobj;
+ src.type = zsetobj->type;
+ src.encoding = zsetobj->encoding;
+ zuiInitIterator(&src);
+ memset(&zval, 0, sizeof(zval));
+
+ /* Initiate reply count, RESP3 responds with nested array, RESP2 with flat one. */
+ long reply_size = count < size ? count : size;
+ if (withscores && c->resp == 2)
+ addReplyArrayLen(c, reply_size*2);
+ else
+ addReplyArrayLen(c, reply_size);
+
+ /* CASE 2:
+ * The number of requested elements is greater than the number of
+ * elements inside the zset: simply return the whole zset. */
+ if (count >= size) {
+ while (zuiNext(&src, &zval)) {
+ if (withscores && c->resp > 2)
+ addReplyArrayLen(c,2);
+ addReplyBulkSds(c, zuiNewSdsFromValue(&zval));
+ if (withscores)
+ addReplyDouble(c, zval.score);
+ }
+ return;
+ }
+
+ /* CASE 3:
+ * The number of elements inside the zset is not greater than
+ * ZRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements.
+ * In this case we create a dict from scratch with all the elements, and
+ * subtract random elements to reach the requested number of elements.
+ *
+ * This is done because if the number of requested elements is just
+ * a bit less than the number of elements in the set, the natural approach
+ * used into CASE 4 is highly inefficient. */
+ if (count*ZRANDMEMBER_SUB_STRATEGY_MUL > size) {
+ dict *d = dictCreate(&sdsReplyDictType, NULL);
+ /* Add all the elements into the temporary dictionary. */
+ while (zuiNext(&src, &zval)) {
+ sds key = zuiNewSdsFromValue(&zval);
+ dictEntry *de = dictAddRaw(d, key, NULL);
+ serverAssert(de);
+ if (withscores)
+ dictSetDoubleVal(de, zval.score);
+ }
+ serverAssert(dictSize(d) == size);
+
+ /* Remove random elements to reach the right count. */
+ while (size > count) {
+ dictEntry *de;
+ de = dictGetRandomKey(d);
+ dictUnlink(d,dictGetKey(de));
+ sdsfree(dictGetKey(de));
+ dictFreeUnlinkedEntry(d,de);
+ size--;
+ }
+
+ /* Reply with what's in the dict and release memory */
+ dictIterator *di;
+ dictEntry *de;
+ di = dictGetIterator(d);
+ while ((de = dictNext(di)) != NULL) {
+ if (withscores && c->resp > 2)
+ addReplyArrayLen(c,2);
+ addReplyBulkSds(c, dictGetKey(de));
+ if (withscores)
+ addReplyDouble(c, dictGetDoubleVal(de));
+ }
+
+ dictReleaseIterator(di);
+ dictRelease(d);
+ }
+
+ /* CASE 4: We have a big zset compared to the requested number of elements.
+ * In this case we can simply get random elements from the zset and add
+ * to the temporary set, trying to eventually get enough unique elements
+ * to reach the specified count. */
+ else {
+ unsigned long added = 0;
+ dict *d = dictCreate(&hashDictType, NULL);
+
+ while (added < count) {
+ ziplistEntry key;
+ double score;
+ zsetTypeRandomElement(zsetobj, size, &key, withscores ? &score: NULL);
+
+ /* Try to add the object to the dictionary. If it already exists
+ * free it, otherwise increment the number of objects we have
+ * in the result dictionary. */
+ sds skey = zsetSdsFromZiplistEntry(&key);
+ if (dictAdd(d,skey,NULL) != DICT_OK) {
+ sdsfree(skey);
+ continue;
+ }
+ added++;
+
+ if (withscores && c->resp > 2)
+ addReplyArrayLen(c,2);
+ zsetReplyFromZiplistEntry(c, &key);
+ if (withscores)
+ addReplyDouble(c, score);
+ }
+
+ /* Release memory */
+ dictRelease(d);
+ }
+}
+
+/* ZRANDMEMBER [<count> WITHSCORES] */
+void zrandmemberCommand(client *c) {
+ long l;
+ int withscores = 0;
+ robj *zset;
+ ziplistEntry ele;
+
+ if (c->argc >= 3) {
+ if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
+ if (c->argc > 4 || (c->argc == 4 && strcasecmp(c->argv[3]->ptr,"withscores"))) {
+ addReplyErrorObject(c,shared.syntaxerr);
+ return;
+ } else if (c->argc == 4)
+ withscores = 1;
+ zrandmemberWithCountCommand(c, l, withscores);
+ return;
+ }
+
+ /* Handle variant without <count> argument. Reply with simple bulk string */
+ if ((zset = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]))== NULL ||
+ checkType(c,zset,OBJ_ZSET)) {
+ return;
+ }
+
+ zsetTypeRandomElement(zset, zsetLength(zset), &ele,NULL);
+ zsetReplyFromZiplistEntry(c,&ele);
+}