diff --git a/src/dict.c b/src/dict.c index 2346f5be..26d0b1ff 100644 --- a/src/dict.c +++ b/src/dict.c @@ -505,6 +505,24 @@ void *dictFetchValue(dict *d, const void *key) { return he ? dictGetVal(he) : NULL; } +/* A fingerprint is a 64 bit number that represents the state of the dictionary + * at a given time, it's just a few dict properties xored together. + * When an unsafe iterator is initialized, we get the dict fingerprint, and check + * the fingerprint again when the iterator is released. + * If the two fingerprints are different it means that the user of the iterator + * performed forbidden operations against the dictionary while iterating. */ +long long dictFingerprint(dict *d) { + long long fingerprint = 0; + + fingerprint ^= (long long) d->ht[0].table; + fingerprint ^= (long long) d->ht[0].size; + fingerprint ^= (long long) d->ht[0].used; + fingerprint ^= (long long) d->ht[1].table; + fingerprint ^= (long long) d->ht[1].size; + fingerprint ^= (long long) d->ht[1].used; + return fingerprint; +} + dictIterator *dictGetIterator(dict *d) { dictIterator *iter = zmalloc(sizeof(*iter)); @@ -530,8 +548,12 @@ dictEntry *dictNext(dictIterator *iter) while (1) { if (iter->entry == NULL) { dictht *ht = &iter->d->ht[iter->table]; - if (iter->safe && iter->index == -1 && iter->table == 0) - iter->d->iterators++; + if (iter->index == -1 && iter->table == 0) { + if (iter->safe) + iter->d->iterators++; + else + iter->fingerprint = dictFingerprint(iter->d); + } iter->index++; if (iter->index >= (signed) ht->size) { if (dictIsRehashing(iter->d) && iter->table == 0) { @@ -558,8 +580,12 @@ dictEntry *dictNext(dictIterator *iter) void dictReleaseIterator(dictIterator *iter) { - if (iter->safe && !(iter->index == -1 && iter->table == 0)) - iter->d->iterators--; + if (!(iter->index == -1 && iter->table == 0)) { + if (iter->safe) + iter->d->iterators--; + else + assert(iter->fingerprint == dictFingerprint(iter->d)); + } zfree(iter); } diff --git a/src/dict.h b/src/dict.h index 3a311f17..4d750ae8 100644 --- a/src/dict.h +++ b/src/dict.h @@ -88,6 +88,7 @@ typedef struct dictIterator { dict *d; int table, index, safe; dictEntry *entry, *nextEntry; + long long fingerprint; /* unsafe iterator fingerprint for misuse detection */ } dictIterator; /* This is the initial size of every hash table */