summaryrefslogtreecommitdiff
path: root/src/rdb.c
diff options
context:
space:
mode:
authorOran Agra <oran@redislabs.com>2020-11-02 09:35:37 +0200
committerOran Agra <oran@redislabs.com>2020-12-06 14:54:34 +0200
commit3716950cfc389c0f7ed13fac5bd205173c2d8189 (patch)
tree7e9e73ac7ffda406e85a19c1b365a1a13deab81d /src/rdb.c
parent5b44631397787a65327fcab77f7df37862286ed9 (diff)
downloadredis-3716950cfc389c0f7ed13fac5bd205173c2d8189.tar.gz
Sanitize dump payload: validate no duplicate records in hash/zset/intset
If RESTORE passes successfully with full sanitization, we can't affort to crash later on assertion due to duplicate records in a hash when converting it form ziplist to dict. This means that when doing full sanitization, we must make sure there are no duplicate records in any of the collections.
Diffstat (limited to 'src/rdb.c')
-rw-r--r--src/rdb.c181
1 files changed, 129 insertions, 52 deletions
diff --git a/src/rdb.c b/src/rdb.c
index e88cabb5d..e8501ee50 100644
--- a/src/rdb.c
+++ b/src/rdb.c
@@ -44,7 +44,7 @@
#include <sys/param.h>
/* This macro is called when the internal RDB structure is corrupt */
-#define rdbExitReportCorruptRDB(...) rdbReportError(1, __LINE__,__VA_ARGS__)
+#define rdbReportCorruptRDB(...) rdbReportError(1, __LINE__,__VA_ARGS__)
/* This macro is called when RDB read failed (possibly a short read) */
#define rdbReportReadError(...) rdbReportError(0, __LINE__,__VA_ARGS__)
@@ -228,7 +228,7 @@ int rdbLoadLenByRef(rio *rdb, int *isencoded, uint64_t *lenptr) {
if (rioRead(rdb,&len,8) == 0) return -1;
*lenptr = ntohu64(len);
} else {
- rdbExitReportCorruptRDB(
+ rdbReportCorruptRDB(
"Unknown length encoding %d in rdbLoadLen()",type);
return -1; /* Never reached. */
}
@@ -296,7 +296,7 @@ void *rdbLoadIntegerObject(rio *rdb, int enctype, int flags, size_t *lenptr) {
v = enc[0]|(enc[1]<<8)|(enc[2]<<16)|(enc[3]<<24);
val = (int32_t)v;
} else {
- rdbExitReportCorruptRDB("Unknown RDB integer encoding type %d",enctype);
+ rdbReportCorruptRDB("Unknown RDB integer encoding type %d",enctype);
return NULL; /* Never reached. */
}
if (plain || sds) {
@@ -400,7 +400,7 @@ void *rdbLoadLzfStringObject(rio *rdb, int flags, size_t *lenptr) {
/* Load the compressed representation and uncompress it to target. */
if (rioRead(rdb,c,clen) == 0) goto err;
if (lzf_decompress(c,clen,val,len) != len) {
- rdbExitReportCorruptRDB("Invalid LZF compressed string");
+ rdbReportCorruptRDB("Invalid LZF compressed string");
goto err;
}
zfree(c);
@@ -516,7 +516,7 @@ void *rdbGenericLoadStringObject(rio *rdb, int flags, size_t *lenptr) {
case RDB_ENC_LZF:
return rdbLoadLzfStringObject(rdb,flags,lenptr);
default:
- rdbExitReportCorruptRDB("Unknown RDB string encoding type %llu",len);
+ rdbReportCorruptRDB("Unknown RDB string encoding type %llu",len);
return NULL;
}
}
@@ -1467,26 +1467,26 @@ robj *rdbLoadCheckModuleValue(rio *rdb, char *modulename) {
{
uint64_t len;
if (rdbLoadLenByRef(rdb,NULL,&len) == -1) {
- rdbExitReportCorruptRDB(
+ rdbReportCorruptRDB(
"Error reading integer from module %s value", modulename);
}
} else if (opcode == RDB_MODULE_OPCODE_STRING) {
robj *o = rdbGenericLoadStringObject(rdb,RDB_LOAD_NONE,NULL);
if (o == NULL) {
- rdbExitReportCorruptRDB(
+ rdbReportCorruptRDB(
"Error reading string from module %s value", modulename);
}
decrRefCount(o);
} else if (opcode == RDB_MODULE_OPCODE_FLOAT) {
float val;
if (rdbLoadBinaryFloatValue(rdb,&val) == -1) {
- rdbExitReportCorruptRDB(
+ rdbReportCorruptRDB(
"Error reading float from module %s value", modulename);
}
} else if (opcode == RDB_MODULE_OPCODE_DOUBLE) {
double val;
if (rdbLoadBinaryDoubleValue(rdb,&val) == -1) {
- rdbExitReportCorruptRDB(
+ rdbReportCorruptRDB(
"Error reading double from module %s value", modulename);
}
}
@@ -1564,7 +1564,14 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
if (o->encoding == OBJ_ENCODING_INTSET) {
/* Fetch integer value from element. */
if (isSdsRepresentableAsLongLong(sdsele,&llval) == C_OK) {
- o->ptr = intsetAdd(o->ptr,llval,NULL);
+ uint8_t success;
+ o->ptr = intsetAdd(o->ptr,llval,&success);
+ if (!success) {
+ rdbReportCorruptRDB("Duplicate set members detected");
+ decrRefCount(o);
+ sdsfree(sdsele);
+ return NULL;
+ }
} else {
setTypeConvert(o,OBJ_ENCODING_HT);
dictExpand(o->ptr,len);
@@ -1575,7 +1582,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
* to a regular hash table encoded set. */
if (o->encoding == OBJ_ENCODING_HT) {
if (dictAdd((dict*)o->ptr,sdsele,NULL) != DICT_OK) {
- rdbExitReportCorruptRDB("Duplicate set members detected");
+ rdbReportCorruptRDB("Duplicate set members detected");
decrRefCount(o);
sdsfree(sdsele);
return NULL;
@@ -1626,7 +1633,12 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
if (sdslen(sdsele) > maxelelen) maxelelen = sdslen(sdsele);
znode = zslInsert(zs->zsl,score,sdsele);
- dictAdd(zs->dict,sdsele,&znode->score);
+ if (dictAdd(zs->dict,sdsele,&znode->score) != DICT_OK) {
+ rdbReportCorruptRDB("Duplicate zset fields detected");
+ decrRefCount(o);
+ sdsfree(sdsele);
+ return NULL;
+ }
}
/* Convert *after* loading, since sorted sets are not stored ordered. */
@@ -1637,15 +1649,24 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
uint64_t len;
int ret;
sds field, value;
+ dict *dupSearchDict = NULL;
len = rdbLoadLen(rdb, NULL);
if (len == RDB_LENERR) return NULL;
o = createHashObject();
- /* Too many entries? Use a hash table. */
+ /* Too many entries? Use a hash table right from the start. */
if (len > server.hash_max_ziplist_entries)
hashTypeConvert(o, OBJ_ENCODING_HT);
+ else if (deep_integrity_validation) {
+ /* In this mode, we need to guarantee that the server won't crash
+ * later when the ziplist is converted to a dict.
+ * Create a set (dict with no values) to for a dup search.
+ * We can dismiss it as soon as we convert the ziplist to a hash. */
+ dupSearchDict = dictCreate(&hashDictType, NULL);
+ }
+
/* Load every field and value into the ziplist */
while (o->encoding == OBJ_ENCODING_ZIPLIST && len > 0) {
@@ -1653,14 +1674,29 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
/* Load raw strings */
if ((field = rdbGenericLoadStringObject(rdb,RDB_LOAD_SDS,NULL)) == NULL) {
decrRefCount(o);
+ if (dupSearchDict) dictRelease(dupSearchDict);
return NULL;
}
if ((value = rdbGenericLoadStringObject(rdb,RDB_LOAD_SDS,NULL)) == NULL) {
sdsfree(field);
decrRefCount(o);
+ if (dupSearchDict) dictRelease(dupSearchDict);
return NULL;
}
+ if (dupSearchDict) {
+ sds field_dup = sdsdup(field);
+ if (dictAdd(dupSearchDict, field_dup, NULL) != DICT_OK) {
+ rdbReportCorruptRDB("Hash with dup elements");
+ dictRelease(dupSearchDict);
+ decrRefCount(o);
+ sdsfree(field_dup);
+ sdsfree(field);
+ sdsfree(value);
+ return NULL;
+ }
+ }
+
/* Add pair to ziplist */
o->ptr = ziplistPush(o->ptr, (unsigned char*)field,
sdslen(field), ZIPLIST_TAIL);
@@ -1680,6 +1716,13 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
sdsfree(value);
}
+ if (dupSearchDict) {
+ /* We no longer need this, from now on the entries are added
+ * to a dict so the check is performed implicitly. */
+ dictRelease(dupSearchDict);
+ dupSearchDict = NULL;
+ }
+
if (o->encoding == OBJ_ENCODING_HT && len > DICT_HT_INITIAL_SIZE)
dictExpand(o->ptr,len);
@@ -1700,7 +1743,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
/* Add pair to hash table */
ret = dictAdd((dict*)o->ptr, field, value);
if (ret == DICT_ERR) {
- rdbExitReportCorruptRDB("Duplicate hash fields detected");
+ rdbReportCorruptRDB("Duplicate hash fields detected");
sdsfree(value);
sdsfree(field);
decrRefCount(o);
@@ -1725,8 +1768,8 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
return NULL;
}
if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
- if (!ziplistValidateIntegrity(zl, encoded_len, deep_integrity_validation)) {
- rdbExitReportCorruptRDB("Ziplist integrity check failed.");
+ if (!ziplistValidateIntegrity(zl, encoded_len, deep_integrity_validation, NULL, NULL)) {
+ rdbReportCorruptRDB("Ziplist integrity check failed.");
decrRefCount(o);
zfree(zl);
return NULL;
@@ -1743,28 +1786,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
unsigned char *encoded =
rdbGenericLoadStringObject(rdb,RDB_LOAD_PLAIN,&encoded_len);
if (encoded == NULL) return NULL;
- if (rdbtype == RDB_TYPE_HASH_ZIPMAP) {
- /* Since we don't keep zipmaps anymore, the rdb loading for these
- * is O(n) anyway, use `deep` validation. */
- if (!zipmapValidateIntegrity(encoded, encoded_len, 1)) {
- rdbExitReportCorruptRDB("Zipmap integrity check failed.");
- zfree(encoded);
- return NULL;
- }
- } else if (rdbtype == RDB_TYPE_SET_INTSET) {
- if (!intsetValidateIntegrity(encoded, encoded_len)) {
- rdbExitReportCorruptRDB("Intset integrity check failed.");
- zfree(encoded);
- return NULL;
- }
- } else { /* ziplist */
- if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
- if (!ziplistValidateIntegrity(encoded, encoded_len, deep_integrity_validation)) {
- rdbExitReportCorruptRDB("Ziplist integrity check failed.");
- zfree(encoded);
- return NULL;
- }
- }
+
o = createObject(OBJ_STRING,encoded); /* Obj type fixed below. */
/* Fix the object encoding, and make sure to convert the encoded
@@ -1775,6 +1797,15 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
* converted. */
switch(rdbtype) {
case RDB_TYPE_HASH_ZIPMAP:
+ /* Since we don't keep zipmaps anymore, the rdb loading for these
+ * is O(n) anyway, use `deep` validation. */
+ if (!zipmapValidateIntegrity(encoded, encoded_len, 1)) {
+ rdbReportCorruptRDB("Zipmap integrity check failed.");
+ zfree(encoded);
+ o->ptr = NULL;
+ decrRefCount(o);
+ return NULL;
+ }
/* Convert to ziplist encoded hash. This must be deprecated
* when loading dumps created by Redis 2.4 gets deprecated. */
{
@@ -1783,14 +1814,28 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
unsigned char *fstr, *vstr;
unsigned int flen, vlen;
unsigned int maxlen = 0;
+ dict *dupSearchDict = dictCreate(&hashDictType, NULL);
while ((zi = zipmapNext(zi, &fstr, &flen, &vstr, &vlen)) != NULL) {
if (flen > maxlen) maxlen = flen;
if (vlen > maxlen) maxlen = vlen;
zl = ziplistPush(zl, fstr, flen, ZIPLIST_TAIL);
zl = ziplistPush(zl, vstr, vlen, ZIPLIST_TAIL);
+
+ /* search for duplicate records */
+ sds field = sdsnewlen(fstr, flen);
+ if (dictAdd(dupSearchDict, field, NULL) != DICT_OK) {
+ rdbReportCorruptRDB("Hash zipmap with dup elements");
+ dictRelease(dupSearchDict);
+ sdsfree(field);
+ zfree(encoded);
+ o->ptr = NULL;
+ decrRefCount(o);
+ return NULL;
+ }
}
+ dictRelease(dupSearchDict);
zfree(o->ptr);
o->ptr = zl;
o->type = OBJ_HASH;
@@ -1804,23 +1849,55 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
}
break;
case RDB_TYPE_LIST_ZIPLIST:
+ if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
+ if (!ziplistValidateIntegrity(encoded, encoded_len, deep_integrity_validation, NULL, NULL)) {
+ rdbReportCorruptRDB("List ziplist integrity check failed.");
+ zfree(encoded);
+ o->ptr = NULL;
+ decrRefCount(o);
+ return NULL;
+ }
o->type = OBJ_LIST;
o->encoding = OBJ_ENCODING_ZIPLIST;
listTypeConvert(o,OBJ_ENCODING_QUICKLIST);
break;
case RDB_TYPE_SET_INTSET:
+ if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
+ if (!intsetValidateIntegrity(encoded, encoded_len, deep_integrity_validation)) {
+ rdbReportCorruptRDB("Intset integrity check failed.");
+ zfree(encoded);
+ o->ptr = NULL;
+ decrRefCount(o);
+ return NULL;
+ }
o->type = OBJ_SET;
o->encoding = OBJ_ENCODING_INTSET;
if (intsetLen(o->ptr) > server.set_max_intset_entries)
setTypeConvert(o,OBJ_ENCODING_HT);
break;
case RDB_TYPE_ZSET_ZIPLIST:
+ if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
+ if (!zsetZiplistValidateIntegrity(encoded, encoded_len, deep_integrity_validation)) {
+ rdbReportCorruptRDB("Zset ziplist integrity check failed.");
+ zfree(encoded);
+ o->ptr = NULL;
+ decrRefCount(o);
+ return NULL;
+ }
o->type = OBJ_ZSET;
o->encoding = OBJ_ENCODING_ZIPLIST;
if (zsetLength(o) > server.zset_max_ziplist_entries)
zsetConvert(o,OBJ_ENCODING_SKIPLIST);
break;
case RDB_TYPE_HASH_ZIPLIST:
+ if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
+ if (!hashZiplistValidateIntegrity(encoded, encoded_len, deep_integrity_validation)) {
+ rdbReportCorruptRDB("Hash ziplist integrity check failed.");
+ zfree(encoded);
+ o->ptr = NULL;
+ decrRefCount(o);
+ return NULL;
+ }
o->type = OBJ_HASH;
o->encoding = OBJ_ENCODING_ZIPLIST;
if (hashTypeLength(o) > server.hash_max_ziplist_entries)
@@ -1828,7 +1905,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
break;
default:
/* totally unreachable */
- rdbExitReportCorruptRDB("Unknown RDB encoding type %d",rdbtype);
+ rdbReportCorruptRDB("Unknown RDB encoding type %d",rdbtype);
break;
}
} else if (rdbtype == RDB_TYPE_STREAM_LISTPACKS) {
@@ -1852,7 +1929,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
return NULL;
}
if (sdslen(nodekey) != sizeof(streamID)) {
- rdbExitReportCorruptRDB("Stream node key entry is not the "
+ rdbReportCorruptRDB("Stream node key entry is not the "
"size of a stream ID");
sdsfree(nodekey);
decrRefCount(o);
@@ -1871,7 +1948,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
}
if (deep_integrity_validation) server.stat_dump_payload_sanitizations++;
if (!streamValidateListpackIntegrity(lp, lp_size, deep_integrity_validation)) {
- rdbExitReportCorruptRDB("Stream listpack integrity check failed.");
+ rdbReportCorruptRDB("Stream listpack integrity check failed.");
sdsfree(nodekey);
decrRefCount(o);
zfree(lp);
@@ -1883,7 +1960,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
/* Serialized listpacks should never be empty, since on
* deletion we should remove the radix tree key if the
* resulting listpack is empty. */
- rdbExitReportCorruptRDB("Empty listpack inside stream");
+ rdbReportCorruptRDB("Empty listpack inside stream");
sdsfree(nodekey);
decrRefCount(o);
zfree(lp);
@@ -1895,7 +1972,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
(unsigned char*)nodekey,sizeof(streamID),lp,NULL);
sdsfree(nodekey);
if (!retval) {
- rdbExitReportCorruptRDB("Listpack re-added with existing key");
+ rdbReportCorruptRDB("Listpack re-added with existing key");
decrRefCount(o);
zfree(lp);
return NULL;
@@ -1945,7 +2022,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
streamCG *cgroup = streamCreateCG(s,cgname,sdslen(cgname),&cg_id);
if (cgroup == NULL) {
- rdbExitReportCorruptRDB("Duplicated consumer group name %s",
+ rdbReportCorruptRDB("Duplicated consumer group name %s",
cgname);
decrRefCount(o);
sdsfree(cgname);
@@ -1981,7 +2058,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
return NULL;
}
if (!raxInsert(cgroup->pel,rawid,sizeof(rawid),nack,NULL)) {
- rdbExitReportCorruptRDB("Duplicated global PEL entry "
+ rdbReportCorruptRDB("Duplicated global PEL entry "
"loading stream consumer group");
decrRefCount(o);
streamFreeNACK(nack);
@@ -2034,7 +2111,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
}
streamNACK *nack = raxFind(cgroup->pel,rawid,sizeof(rawid));
if (nack == raxNotFound) {
- rdbExitReportCorruptRDB("Consumer entry not found in "
+ rdbReportCorruptRDB("Consumer entry not found in "
"group global PEL");
decrRefCount(o);
return NULL;
@@ -2045,7 +2122,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
* NACK structure also in the consumer-specific PEL. */
nack->consumer = consumer;
if (!raxInsert(consumer->pel,rawid,sizeof(rawid),nack,NULL)) {
- rdbExitReportCorruptRDB("Duplicated consumer PEL entry "
+ rdbReportCorruptRDB("Duplicated consumer PEL entry "
" loading a stream consumer "
"group");
decrRefCount(o);
@@ -2070,7 +2147,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
if (mt == NULL) {
moduleTypeNameByID(name,moduleid);
- rdbExitReportCorruptRDB("The RDB file contains module data I can't load: no matching module '%s'", name);
+ rdbReportCorruptRDB("The RDB file contains module data I can't load: no matching module '%s'", name);
return NULL;
}
RedisModuleIO io;
@@ -2097,7 +2174,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
return NULL;
}
if (eof != RDB_MODULE_OPCODE_EOF) {
- rdbExitReportCorruptRDB("The RDB file contains module data for the module '%s' that is not terminated by the proper module value EOF marker", name);
+ rdbReportCorruptRDB("The RDB file contains module data for the module '%s' that is not terminated by the proper module value EOF marker", name);
if (ptr) {
o = createModuleObject(mt,ptr); /* creating just in order to easily destroy */
decrRefCount(o);
@@ -2108,7 +2185,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
if (ptr == NULL) {
moduleTypeNameByID(name,moduleid);
- rdbExitReportCorruptRDB("The RDB file contains module data for the module type '%s', that the responsible module is not able to load. Check for modules log above for additional clues.", name);
+ rdbReportCorruptRDB("The RDB file contains module data for the module type '%s', that the responsible module is not able to load. Check for modules log above for additional clues.", name);
return NULL;
}
o = createModuleObject(mt,ptr);
@@ -2327,7 +2404,7 @@ int rdbLoadRio(rio *rdb, int rdbflags, rdbSaveInfo *rsi) {
} else if (!strcasecmp(auxkey->ptr,"lua")) {
/* Load the script back in memory. */
if (luaCreateFunction(NULL,server.lua,auxval) == NULL) {
- rdbExitReportCorruptRDB(
+ rdbReportCorruptRDB(
"Can't load Lua script from RDB file! "
"BODY: %s", (char*)auxval->ptr);
}
@@ -2492,7 +2569,7 @@ int rdbLoadRio(rio *rdb, int rdbflags, rdbSaveInfo *rsi) {
"got (%llx). Aborting now.",
(unsigned long long)expected,
(unsigned long long)cksum);
- rdbExitReportCorruptRDB("RDB CRC error");
+ rdbReportCorruptRDB("RDB CRC error");
return C_ERR;
}
}