diff options
Diffstat (limited to 'src/t_hash.c')
-rw-r--r-- | src/t_hash.c | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/src/t_hash.c b/src/t_hash.c index 51c7d6758..9f7540a72 100644 --- a/src/t_hash.c +++ b/src/t_hash.c @@ -598,6 +598,42 @@ int hashZiplistValidateIntegrity(unsigned char *zl, size_t size, int deep) { return ret; } +/* Create a new sds string from the ziplist entry. */ +sds hashSdsFromZiplistEntry(ziplistEntry *e) { + return e->sval ? sdsnewlen(e->sval, e->slen) : sdsfromlonglong(e->lval); +} + +/* Reply with bulk string from the ziplist entry. */ +void hashReplyFromZiplistEntry(client *c, ziplistEntry *e) { + if (e->sval) + addReplyBulkCBuffer(c, e->sval, e->slen); + else + addReplyBulkLongLong(c, e->lval); +} + +/* Return random element from a non empty hash. + * 'key' and 'val' will be set to hold the element. + * The memory in them is not to be freed or modified by the caller. + * 'val' can be NULL in which case it's not extracted. */ +void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, ziplistEntry *key, ziplistEntry *val) { + if (hashobj->encoding == OBJ_ENCODING_HT) { + dictEntry *de = dictGetFairRandomKey(hashobj->ptr); + sds s = dictGetKey(de); + key->sval = (unsigned char*)s; + key->slen = sdslen(s); + if (val) { + sds s = dictGetVal(de); + val->sval = (unsigned char*)s; + val->slen = sdslen(s); + } + } else if (hashobj->encoding == OBJ_ENCODING_ZIPLIST) { + ziplistRandomPair(hashobj->ptr, hashsize, key, val); + } else { + serverPanic("Unknown hash encoding"); + } +} + + /*----------------------------------------------------------------------------- * Hash type commands *----------------------------------------------------------------------------*/ @@ -922,3 +958,220 @@ void hscanCommand(client *c) { checkType(c,o,OBJ_HASH)) return; scanGenericCommand(c,o,cursor); } + +/* How many times bigger should be the hash compared to the requested size + * for us to not use the "remove elements" strategy? Read later in the + * implementation for more info. */ +#define HRANDFIELD_SUB_STRATEGY_MUL 3 + +void hrandfieldWithCountCommand(client *c, long l, int withvalues) { + unsigned long count, size; + int uniq = 1; + robj *hash; + + if ((hash = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp])) + == NULL || checkType(c,hash,OBJ_HASH)) return; + size = hashTypeLength(hash); + + if(l >= 0) { + count = (unsigned long) l; + } else { + count = -l; + uniq = 0; + } + + /* If count is zero, serve it ASAP to avoid special cases later. */ + if (count == 0) { + addReply(c,shared.emptyarray); + return; + } + + /* CASE 1: The count was negative, so the extraction method is just: + * "return N random elements" sampling the whole set every time. + * This case is trivial and can be served without auxiliary data + * structures. This case is the only one that also needs to return the + * elements in random order. */ + if (!uniq || count == 1) { + if (withvalues && c->resp == 2) + addReplyArrayLen(c, count*2); + else + addReplyArrayLen(c, count); + if (hash->encoding == OBJ_ENCODING_HT) { + sds key, value; + while (count--) { + dictEntry *de = dictGetRandomKey(hash->ptr); + key = dictGetKey(de); + value = dictGetVal(de); + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkCBuffer(c, key, sdslen(key)); + if (withvalues) + addReplyBulkCBuffer(c, value, sdslen(value)); + } + } else if (hash->encoding == OBJ_ENCODING_ZIPLIST) { + ziplistEntry *keys, *vals = NULL; + keys = zmalloc(sizeof(ziplistEntry)*count); + if (withvalues) + vals = zmalloc(sizeof(ziplistEntry)*count); + ziplistRandomPairs(hash->ptr, count, keys, vals); + for (unsigned long i = 0; i < count; i++) { + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + if (keys[i].sval) + addReplyBulkCBuffer(c, keys[i].sval, keys[i].slen); + else + addReplyBulkLongLong(c, keys[i].lval); + if (withvalues) { + if (vals[i].sval) + addReplyBulkCBuffer(c, vals[i].sval, vals[i].slen); + else + addReplyBulkLongLong(c, vals[i].lval); + } + } + zfree(keys); + zfree(vals); + } + return; + } + + /* Initiate reply count, RESP3 responds with nested array, RESP2 with flat one. */ + long reply_size = count < size ? count : size; + if (withvalues && c->resp == 2) + addReplyArrayLen(c, reply_size*2); + else + addReplyArrayLen(c, reply_size); + + /* CASE 2: + * The number of requested elements is greater than the number of + * elements inside the hash: simply return the whole hash. */ + if(count >= size) { + hashTypeIterator *hi = hashTypeInitIterator(hash); + while (hashTypeNext(hi) != C_ERR) { + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + addHashIteratorCursorToReply(c, hi, OBJ_HASH_KEY); + if (withvalues) + addHashIteratorCursorToReply(c, hi, OBJ_HASH_VALUE); + } + hashTypeReleaseIterator(hi); + return; + } + + /* CASE 3: + * The number of elements inside the hash is not greater than + * HRANDFIELD_SUB_STRATEGY_MUL times the number of requested elements. + * In this case we create a hash from scratch with all the elements, and + * subtract random elements to reach the requested number of elements. + * + * This is done because if the number of requested elements is just + * a bit less than the number of elements in the hash, the natural approach + * used into CASE 4 is highly inefficient. */ + if (count*HRANDFIELD_SUB_STRATEGY_MUL > size) { + dict *d = dictCreate(&sdsReplyDictType, NULL); + hashTypeIterator *hi = hashTypeInitIterator(hash); + + /* Add all the elements into the temporary dictionary. */ + while ((hashTypeNext(hi)) != C_ERR) { + int ret = DICT_ERR; + sds key, value = NULL; + + key = hashTypeCurrentObjectNewSds(hi,OBJ_HASH_KEY); + if (withvalues) + value = hashTypeCurrentObjectNewSds(hi,OBJ_HASH_VALUE); + ret = dictAdd(d, key, value); + + serverAssert(ret == DICT_OK); + } + serverAssert(dictSize(d) == size); + hashTypeReleaseIterator(hi); + + /* Remove random elements to reach the right count. */ + while (size > count) { + dictEntry *de; + de = dictGetRandomKey(d); + dictUnlink(d,dictGetKey(de)); + sdsfree(dictGetKey(de)); + sdsfree(dictGetVal(de)); + dictFreeUnlinkedEntry(d,de); + size--; + } + + /* Reply with what's in the dict and release memory */ + dictIterator *di; + dictEntry *de; + di = dictGetIterator(d); + while ((de = dictNext(di)) != NULL) { + sds key = dictGetKey(de); + sds value = dictGetVal(de); + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkSds(c, key); + if (withvalues) + addReplyBulkSds(c, value); + } + + dictReleaseIterator(di); + dictRelease(d); + } + + /* CASE 4: We have a big hash compared to the requested number of elements. + * In this case we can simply get random elements from the hash and add + * to the temporary hash, trying to eventually get enough unique elements + * to reach the specified count. */ + else { + unsigned long added = 0; + ziplistEntry key, value; + dict *d = dictCreate(&hashDictType, NULL); + while(added < count) { + hashTypeRandomElement(hash, size, &key, withvalues? &value : NULL); + + /* Try to add the object to the dictionary. If it already exists + * free it, otherwise increment the number of objects we have + * in the result dictionary. */ + sds skey = hashSdsFromZiplistEntry(&key); + if (dictAdd(d,skey,NULL) != DICT_OK) { + sdsfree(skey); + continue; + } + added++; + + /* We can reply right away, so that we don't need to store the value in the dict. */ + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + hashReplyFromZiplistEntry(c, &key); + if (withvalues) + hashReplyFromZiplistEntry(c, &value); + } + + /* Release memory */ + dictRelease(d); + } +} + +/* HRANDFIELD [<count> WITHVALUES] */ +void hrandfieldCommand(client *c) { + long l; + int withvalues = 0; + robj *hash; + ziplistEntry ele; + + if (c->argc >= 3) { + if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return; + if (c->argc > 4 || (c->argc == 4 && strcasecmp(c->argv[3]->ptr,"withvalues"))) { + addReplyErrorObject(c,shared.syntaxerr); + return; + } else if (c->argc == 4) + withvalues = 1; + hrandfieldWithCountCommand(c, l, withvalues); + return; + } + + /* Handle variant without <count> argument. Reply with simple bulk string */ + if ((hash = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]))== NULL || + checkType(c,hash,OBJ_HASH)) { + return; + } + + hashTypeRandomElement(hash,hashTypeLength(hash),&ele,NULL); + hashReplyFromZiplistEntry(c, &ele); +} |