From eeffcf380fcd3e3a0b2f650e24df8338a529642d Mon Sep 17 00:00:00 2001
From: antirez <antirez@gmail.com>
Date: Fri, 6 May 2011 17:21:27 +0200
Subject: [PATCH] Lua scripts max execution time

---
 redis.conf      |  7 +++++++
 src/config.c    | 10 ++++++++++
 src/redis.c     |  1 +
 src/redis.h     |  5 +++++
 src/scripting.c | 18 ++++++++++++++++++
 5 files changed, 41 insertions(+)

diff --git a/redis.conf b/redis.conf
index e7a01eec..f962b970 100644
--- a/redis.conf
+++ b/redis.conf
@@ -292,6 +292,13 @@ appendfsync everysec
 # "no" that is the safest pick from the point of view of durability.
 no-appendfsync-on-rewrite no
 
+################################ LUA SCRIPTING  ###############################
+
+# Max execution time of a Lua script in milliseconds.
+# This prevents that a programming error generating an infinite loop will block
+# your server forever. Set it to 0 or a negative value for unlimited execution.
+lua-time-limit 60000
+
 #################################### DISK STORE ###############################
 
 # When disk store is active Redis works as an on-disk database, where memory
diff --git a/src/config.c b/src/config.c
index 98fdb15d..d4608559 100644
--- a/src/config.c
+++ b/src/config.c
@@ -296,6 +296,8 @@ void loadServerConfig(char *filename) {
         } else if (!strcasecmp(argv[0],"cluster-config-file") && argc == 2) {
             zfree(server.cluster.configfile);
             server.cluster.configfile = zstrdup(argv[1]);
+        } else if (!strcasecmp(argv[0],"lua-time-limit") && argc == 2) {
+            server.lua_time_limit = strtoll(argv[1],NULL,10);
         } else {
             err = "Bad directive or wrong number of arguments"; goto loaderr;
         }
@@ -460,6 +462,9 @@ void configSetCommand(redisClient *c) {
     } else if (!strcasecmp(c->argv[2]->ptr,"zset-max-ziplist-value")) {
         if (getLongLongFromObject(o,&ll) == REDIS_ERR || ll < 0) goto badfmt;
         server.zset_max_ziplist_value = ll;
+    } else if (!strcasecmp(c->argv[2]->ptr,"lua-time-limit")) {
+        if (getLongLongFromObject(o,&ll) == REDIS_ERR || ll < 0) goto badfmt;
+        server.lua_time_limit = ll;
     } else {
         addReplyErrorFormat(c,"Unsupported CONFIG parameter: %s",
             (char*)c->argv[2]->ptr);
@@ -621,6 +626,11 @@ void configGetCommand(redisClient *c) {
         addReplyBulkLongLong(c,server.zset_max_ziplist_value);
         matches++;
     }
+    if (stringmatch(pattern,"lua-time-limit",0)) {
+        addReplyBulkCString(c,"lua-time-limit");
+        addReplyBulkLongLong(c,server.lua_time_limit);
+        matches++;
+    }
     setDeferredMultiBulkLength(c,replylen,matches*2);
 }
 
diff --git a/src/redis.c b/src/redis.c
index 39c4ba92..72477356 100644
--- a/src/redis.c
+++ b/src/redis.c
@@ -858,6 +858,7 @@ void initServerConfig() {
     server.cache_flush_delay = 0;
     server.cluster_enabled = 0;
     server.cluster.configfile = zstrdup("nodes.conf");
+    server.lua_time_limit = REDIS_LUA_TIME_LIMIT;
 
     updateLRUClock();
     resetServerSaveParams();
diff --git a/src/redis.h b/src/redis.h
index d9609991..50d669ae 100644
--- a/src/redis.h
+++ b/src/redis.h
@@ -225,6 +225,9 @@
 #define REDIS_BGSAVE_THREAD_DONE_OK 2
 #define REDIS_BGSAVE_THREAD_DONE_ERR 3
 
+/* Scripting */
+#define REDIS_LUA_TIME_LIMIT 60000 /* milliseconds */
+
 /* We can print the stacktrace, so our assert is defined this way: */
 #define redisAssert(_e) ((_e)?(void)0 : (_redisAssert(#_e,__FILE__,__LINE__),_exit(1)))
 #define redisPanic(_e) _redisPanic(#_e,__FILE__,__LINE__),_exit(1)
@@ -659,6 +662,8 @@ struct redisServer {
     /* Scripting */
     lua_State *lua;
     redisClient *lua_client;
+    long long lua_time_limit;
+    long long lua_time_start;
 };
 
 typedef struct pubsubPattern {
diff --git a/src/scripting.c b/src/scripting.c
index ad00123a..b4297e60 100644
--- a/src/scripting.c
+++ b/src/scripting.c
@@ -199,6 +199,19 @@ int luaRedisCommand(lua_State *lua) {
     return 1;
 }
 
+void luaMaskCountHook(lua_State *lua, lua_Debug *ar) {
+    long long elapsed;
+    REDIS_NOTUSED(ar);
+
+    if (server.lua_time_limit <= 0) return;
+    elapsed = (ustime()/1000) - server.lua_time_start;
+    if (elapsed >= server.lua_time_limit) {
+        lua_pushstring(lua,"Script aborted for max execution time...");
+        lua_error(lua);
+        redisLog(REDIS_NOTICE,"Lua script aborted for max execution time after %lld milliseconds of running time",elapsed);
+    }
+}
+
 void scriptingInit(void) {
     lua_State *lua = lua_open();
     luaL_openlibs(lua);
@@ -212,6 +225,10 @@ void scriptingInit(void) {
     server.lua_client = createClient(-1);
     server.lua_client->flags |= REDIS_LUA_CLIENT;
 
+    /* Set an hook in order to be able to stop the script execution if it
+     * is running for too much time. */
+    lua_sethook(lua,luaMaskCountHook,LUA_MASKCOUNT,10000);
+
     server.lua = lua;
 }
 
@@ -375,6 +392,7 @@ void evalCommand(redisClient *c) {
     /* At this point whatever this script was never seen before or if it was
      * already defined, we can call it. We have zero arguments and expect
      * a single return value. */
+    server.lua_time_start = ustime()/1000;
     if (lua_pcall(lua,0,1,0)) {
         selectDb(c,server.lua_client->db->id); /* set DB ID from Lua client */
         addReplyErrorFormat(c,"Error running script (call to %s): %s\n",