diff options
Diffstat (limited to 'src/t_zset.c')
-rw-r--r-- | src/t_zset.c | 299 |
1 files changed, 286 insertions, 13 deletions
diff --git a/src/t_zset.c b/src/t_zset.c index 3d63c41c6..b55fc169e 100644 --- a/src/t_zset.c +++ b/src/t_zset.c @@ -721,20 +721,26 @@ zskiplistNode *zslLastInLexRange(zskiplist *zsl, zlexrangespec *range) { * Ziplist-backed sorted set API *----------------------------------------------------------------------------*/ +double zzlStrtod(unsigned char *vstr, unsigned int vlen) { + char buf[128]; + if (vlen > sizeof(buf)) + vlen = sizeof(buf); + memcpy(buf,vstr,vlen); + buf[vlen] = '\0'; + return strtod(buf,NULL); + } + double zzlGetScore(unsigned char *sptr) { unsigned char *vstr; unsigned int vlen; long long vlong; - char buf[128]; double score; serverAssert(sptr != NULL); serverAssert(ziplistGet(sptr,&vstr,&vlen,&vlong)); if (vstr) { - memcpy(buf,vstr,vlen); - buf[vlen] = '\0'; - score = strtod(buf,NULL); + score = zzlStrtod(vstr,vlen); } else { score = vlong; } @@ -1653,6 +1659,48 @@ int zsetZiplistValidateIntegrity(unsigned char *zl, size_t size, int deep) { return ret; } +/* Create a new sds string from the ziplist entry. */ +sds zsetSdsFromZiplistEntry(ziplistEntry *e) { + return e->sval ? sdsnewlen(e->sval, e->slen) : sdsfromlonglong(e->lval); +} + +/* Reply with bulk string from the ziplist entry. */ +void zsetReplyFromZiplistEntry(client *c, ziplistEntry *e) { + if (e->sval) + addReplyBulkCBuffer(c, e->sval, e->slen); + else + addReplyBulkLongLong(c, e->lval); +} + + +/* Return random element from a non empty zset. + * 'key' and 'val' will be set to hold the element. + * The memory in `key` is not to be freed or modified by the caller. + * 'score' can be NULL in which case it's not extracted. */ +void zsetTypeRandomElement(robj *zsetobj, unsigned long zsetsize, ziplistEntry *key, double *score) { + if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) { + zset *zs = zsetobj->ptr; + dictEntry *de = dictGetFairRandomKey(zs->dict); + sds s = dictGetKey(de); + key->sval = (unsigned char*)s; + key->slen = sdslen(s); + if (score) + *score = *(double*)dictGetVal(de); + } else if (zsetobj->encoding == OBJ_ENCODING_ZIPLIST) { + ziplistEntry val; + ziplistRandomPair(zsetobj->ptr, zsetsize, key, &val); + if (score) { + if (val.sval) { + *score = zzlStrtod(val.sval,val.slen); + } else { + *score = (double)val.lval; + } + } + } else { + serverPanic("Unknown zset encoding"); + } +} + /*----------------------------------------------------------------------------- * Sorted set commands *----------------------------------------------------------------------------*/ @@ -2543,7 +2591,9 @@ void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, in /* read keys to be used for input */ src = zcalloc(sizeof(zsetopsrc) * setnum); for (i = 0, j = numkeysIndex+1; i < setnum; i++, j++) { - robj *obj = lookupKeyWrite(c->db,c->argv[j]); + robj *obj = dstkey ? + lookupKeyWrite(c->db,c->argv[j]) : + lookupKeyRead(c->db,c->argv[j]); if (obj != NULL) { if (obj->type != OBJ_ZSET && obj->type != OBJ_SET) { zfree(src); @@ -2749,6 +2799,9 @@ void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, in unsigned long length = dstzset->zsl->length; zskiplist *zsl = dstzset->zsl; zskiplistNode *zn = zsl->header->level[0].forward; + /* In case of WITHSCORES, respond with a single array in RESP2, and + * nested arrays in RESP3. We can't use a map response type since the + * client library needs to know to respect the order. */ if (withscores && c->resp == 2) addReplyArrayLen(c, length*2); else @@ -2866,6 +2919,9 @@ static void zrangeResultEmitLongLongToClient(zrange_result_handler *handler, static void zrangeResultFinalizeClient(zrange_result_handler *handler, size_t result_count) { + /* In case of WITHSCORES, respond with a single array in RESP2, and + * nested arrays in RESP3. We can't use a map response type since the + * client library needs to know to respect the order. */ if (handler->withscores && (handler->client->resp == 2)) { result_count *= 2; } @@ -3071,8 +3127,8 @@ void zrevrangeCommand(client *c) { /* This command implements ZRANGEBYSCORE, ZREVRANGEBYSCORE. */ void genericZrangebyscoreCommand(zrange_result_handler *handler, - zrangespec *range, robj *zobj, int withscores, long offset, - long limit, int reverse) { + zrangespec *range, robj *zobj, long offset, long limit, + int reverse) { client *c = handler->client; unsigned long rangelen = 0; @@ -3172,8 +3228,7 @@ void genericZrangebyscoreCommand(zrange_result_handler *handler, } rangelen++; - handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), - ((withscores) ? ln->score : ln->score)); + handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score); /* Move to next node */ if (reverse) { @@ -3605,11 +3660,16 @@ void zrangeGenericCommand(zrange_result_handler *handler, int argc_start, int st } /* Step 3: Lookup the key and get the range. */ - if (((zobj = lookupKeyReadOrReply(c, key, shared.emptyarray)) == NULL) - || checkType(c, zobj, OBJ_ZSET)) { + zobj = handler->dstkey ? + lookupKeyWrite(c->db,key) : + lookupKeyRead(c->db,key); + if (zobj == NULL) { + addReply(c,shared.emptyarray); goto cleanup; } + if (checkType(c,zobj,OBJ_ZSET)) goto cleanup; + /* Step 4: Pass this to the command-specific handler. */ switch (rangetype) { case ZRANGE_AUTO: @@ -3619,8 +3679,8 @@ void zrangeGenericCommand(zrange_result_handler *handler, int argc_start, int st break; case ZRANGE_SCORE: - genericZrangebyscoreCommand(handler, &range, zobj, opt_withscores || store, - opt_offset, opt_limit, direction == ZRANGE_DIRECTION_REVERSE); + genericZrangebyscoreCommand(handler, &range, zobj, opt_offset, + opt_limit, direction == ZRANGE_DIRECTION_REVERSE); break; case ZRANGE_LEX: @@ -3895,3 +3955,216 @@ void bzpopminCommand(client *c) { void bzpopmaxCommand(client *c) { blockingGenericZpopCommand(c,ZSET_MAX); } + +/* How many times bigger should be the zset compared to the requested size + * for us to not use the "remove elements" strategy? Read later in the + * implementation for more info. */ +#define ZRANDMEMBER_SUB_STRATEGY_MUL 3 + +void zrandmemberWithCountCommand(client *c, long l, int withscores) { + unsigned long count, size; + int uniq = 1; + robj *zsetobj; + + if ((zsetobj = lookupKeyReadOrReply(c, c->argv[1], shared.null[c->resp])) + == NULL || checkType(c, zsetobj, OBJ_ZSET)) return; + size = zsetLength(zsetobj); + + if(l >= 0) { + count = (unsigned long) l; + } else { + count = -l; + uniq = 0; + } + + /* If count is zero, serve it ASAP to avoid special cases later. */ + if (count == 0) { + addReply(c,shared.emptyarray); + return; + } + + /* CASE 1: The count was negative, so the extraction method is just: + * "return N random elements" sampling the whole set every time. + * This case is trivial and can be served without auxiliary data + * structures. This case is the only one that also needs to return the + * elements in random order. */ + if (!uniq || count == 1) { + if (withscores && c->resp == 2) + addReplyArrayLen(c, count*2); + else + addReplyArrayLen(c, count); + if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) { + zset *zs = zsetobj->ptr; + while (count--) { + dictEntry *de = dictGetFairRandomKey(zs->dict); + sds key = dictGetKey(de); + if (withscores && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkCBuffer(c, key, sdslen(key)); + if (withscores) + addReplyDouble(c, dictGetDoubleVal(de)); + } + } else if (zsetobj->encoding == OBJ_ENCODING_ZIPLIST) { + ziplistEntry *keys, *vals = NULL; + keys = zmalloc(sizeof(ziplistEntry)*count); + if (withscores) + vals = zmalloc(sizeof(ziplistEntry)*count); + ziplistRandomPairs(zsetobj->ptr, count, keys, vals); + for (unsigned long i = 0; i < count; i++) { + if (withscores && c->resp > 2) + addReplyArrayLen(c,2); + if (keys[i].sval) + addReplyBulkCBuffer(c, keys[i].sval, keys[i].slen); + else + addReplyBulkLongLong(c, keys[i].lval); + if (withscores) { + if (vals[i].sval) { + addReplyDouble(c, zzlStrtod(vals[i].sval,vals[i].slen)); + } else + addReplyDouble(c, vals[i].lval); + } + } + zfree(keys); + zfree(vals); + } + return; + } + + zsetopsrc src; + zsetopval zval; + src.subject = zsetobj; + src.type = zsetobj->type; + src.encoding = zsetobj->encoding; + zuiInitIterator(&src); + memset(&zval, 0, sizeof(zval)); + + /* Initiate reply count, RESP3 responds with nested array, RESP2 with flat one. */ + long reply_size = count < size ? count : size; + if (withscores && c->resp == 2) + addReplyArrayLen(c, reply_size*2); + else + addReplyArrayLen(c, reply_size); + + /* CASE 2: + * The number of requested elements is greater than the number of + * elements inside the zset: simply return the whole zset. */ + if (count >= size) { + while (zuiNext(&src, &zval)) { + if (withscores && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkSds(c, zuiNewSdsFromValue(&zval)); + if (withscores) + addReplyDouble(c, zval.score); + } + return; + } + + /* CASE 3: + * The number of elements inside the zset is not greater than + * ZRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements. + * In this case we create a dict from scratch with all the elements, and + * subtract random elements to reach the requested number of elements. + * + * This is done because if the number of requested elements is just + * a bit less than the number of elements in the set, the natural approach + * used into CASE 4 is highly inefficient. */ + if (count*ZRANDMEMBER_SUB_STRATEGY_MUL > size) { + dict *d = dictCreate(&sdsReplyDictType, NULL); + /* Add all the elements into the temporary dictionary. */ + while (zuiNext(&src, &zval)) { + sds key = zuiNewSdsFromValue(&zval); + dictEntry *de = dictAddRaw(d, key, NULL); + serverAssert(de); + if (withscores) + dictSetDoubleVal(de, zval.score); + } + serverAssert(dictSize(d) == size); + + /* Remove random elements to reach the right count. */ + while (size > count) { + dictEntry *de; + de = dictGetRandomKey(d); + dictUnlink(d,dictGetKey(de)); + sdsfree(dictGetKey(de)); + dictFreeUnlinkedEntry(d,de); + size--; + } + + /* Reply with what's in the dict and release memory */ + dictIterator *di; + dictEntry *de; + di = dictGetIterator(d); + while ((de = dictNext(di)) != NULL) { + if (withscores && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkSds(c, dictGetKey(de)); + if (withscores) + addReplyDouble(c, dictGetDoubleVal(de)); + } + + dictReleaseIterator(di); + dictRelease(d); + } + + /* CASE 4: We have a big zset compared to the requested number of elements. + * In this case we can simply get random elements from the zset and add + * to the temporary set, trying to eventually get enough unique elements + * to reach the specified count. */ + else { + unsigned long added = 0; + dict *d = dictCreate(&hashDictType, NULL); + + while (added < count) { + ziplistEntry key; + double score; + zsetTypeRandomElement(zsetobj, size, &key, withscores ? &score: NULL); + + /* Try to add the object to the dictionary. If it already exists + * free it, otherwise increment the number of objects we have + * in the result dictionary. */ + sds skey = zsetSdsFromZiplistEntry(&key); + if (dictAdd(d,skey,NULL) != DICT_OK) { + sdsfree(skey); + continue; + } + added++; + + if (withscores && c->resp > 2) + addReplyArrayLen(c,2); + zsetReplyFromZiplistEntry(c, &key); + if (withscores) + addReplyDouble(c, score); + } + + /* Release memory */ + dictRelease(d); + } +} + +/* ZRANDMEMBER [<count> WITHSCORES] */ +void zrandmemberCommand(client *c) { + long l; + int withscores = 0; + robj *zset; + ziplistEntry ele; + + if (c->argc >= 3) { + if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return; + if (c->argc > 4 || (c->argc == 4 && strcasecmp(c->argv[3]->ptr,"withscores"))) { + addReplyErrorObject(c,shared.syntaxerr); + return; + } else if (c->argc == 4) + withscores = 1; + zrandmemberWithCountCommand(c, l, withscores); + return; + } + + /* Handle variant without <count> argument. Reply with simple bulk string */ + if ((zset = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]))== NULL || + checkType(c,zset,OBJ_ZSET)) { + return; + } + + zsetTypeRandomElement(zset, zsetLength(zset), &ele,NULL); + zsetReplyFromZiplistEntry(c,&ele); +} |