From 1b03836c0efec87eada6787f78e7a6c80cb6f70a Mon Sep 17 00:00:00 2001
From: antirez <antirez@gmail.com>
Date: Tue, 5 Jan 2010 14:25:56 -0500
Subject: [PATCH] A first fix for SET key overwrite

---
 redis.c | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/redis.c b/redis.c
index 1c0517e7..55a9ff9c 100644
--- a/redis.c
+++ b/redis.c
@@ -469,6 +469,7 @@ static robj *getDecodedObject(robj *o);
 static int removeExpire(redisDb *db, robj *key);
 static int expireIfNeeded(redisDb *db, robj *key);
 static int deleteIfVolatile(redisDb *db, robj *key);
+static int deleteIfSwapped(redisDb *db, robj *key);
 static int deleteKey(redisDb *db, robj *key);
 static time_t getExpire(redisDb *db, robj *key);
 static int setExpire(redisDb *db, robj *key, time_t when);
@@ -3264,6 +3265,12 @@ static void setGenericCommand(redisClient *c, int nx) {
     retval = dictAdd(c->db->dict,c->argv[1],c->argv[2]);
     if (retval == DICT_ERR) {
         if (!nx) {
+            /* If the key is about a swapped value, we want a new key object
+             * to overwrite the old. So we delete the old key in the database.
+             * This will also make sure that swap pages about the old object
+             * will be marked as free. */
+            if (deleteIfSwapped(c->db,c->argv[1]))
+                incrRefCount(c->argv[1]);
             dictReplace(c->db->dict,c->argv[1],c->argv[2]);
             incrRefCount(c->argv[2]);
         } else {
@@ -7032,6 +7039,19 @@ static int vmCanSwapOut(void) {
     return (server.bgsavechildpid == -1 && server.bgrewritechildpid == -1);
 }
 
+/* Delete a key if swapped. Returns 1 if the key was found, was swapped
+ * and was deleted. Otherwise 0 is returned. */
+static int deleteIfSwapped(redisDb *db, robj *key) {
+    dictEntry *de;
+    robj *foundkey;
+
+    if ((de = dictFind(db->dict,key)) == NULL) return 0;
+    foundkey = dictGetEntryKey(de);
+    if (foundkey->storage == REDIS_VM_MEMORY) return 0;
+    deleteKey(db,key);
+    return 1;
+}
+
 /* ================================= Debugging ============================== */
 
 static void debugCommand(redisClient *c) {