This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 3afb62535 feat(scripting): support strict key-accessing mode for lua 
scripting (#3139)
3afb62535 is described below

commit 3afb6253567c19c7b1695da6f6ae47cf6dbce48e
Author: Twice <[email protected]>
AuthorDate: Sat Aug 23 13:36:09 2025 +0800

    feat(scripting): support strict key-accessing mode for lua scripting (#3139)
    
    Accessing undeclared keys in lua scripting may lead to unexpected
    behavior in the current design of kvrocks (also in redis, refer to
    https://redis.io/docs/latest/commands/eval/), so in this PR we add a new
    option `lua-strict-key-accessing` to prevent users from accessing undeclared
    keys in lua, e.g.
    
    ```
    EVAL "return redis.call('set', 'a', 1)" // ERROR!
    
    EVAL "return redis.call('set', KEYS[1], 1)" 1 a // ok
    ```
    
    This check is performed in both lua scripting and lua functions.
---
 kvrocks.conf                                  | 13 ++++++++++
 src/config/config.cc                          |  1 +
 src/config/config.h                           |  2 ++
 src/storage/scripting.cc                      | 36 ++++++++++++++++++++++++---
 src/storage/scripting.h                       |  5 ++--
 tests/gocase/unit/scripting/function_test.go  | 34 +++++++++++++++++++++++++
 tests/gocase/unit/scripting/scripting_test.go | 31 +++++++++++++++++++++++
 7 files changed, 117 insertions(+), 5 deletions(-)

diff --git a/kvrocks.conf b/kvrocks.conf
index 67819c274..b56934345 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -420,6 +420,19 @@ txn-context-enabled no
 # Default: disabled
 # histogram-bucket-boundaries  
10,20,40,60,80,100,150,250,350,500,750,1000,1500,2000,4000,8000
 
+# Whether the strict key-accessing mode of lua scripting is enabled.
+#
+# If enabled, the lua script will abort and report errors
+# if it tries to access keys that are not declared in
+# the script's `KEYS` table or the function's `keys` argument.
+#
+# Note that accessing undeclared keys may lead to unexpected behavior,
+# so this option is to ensure that scripts only access keys
+# that are explicitly declared.
+#
+# Default: no
+lua-strict-key-accessing no
+
 ################################## TLS ###################################
 
 # By default, TLS/SSL is disabled, i.e. `tls-port` is set to 0.
diff --git a/src/config/config.cc b/src/config/config.cc
index f17db7829..fb0f68f67 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -243,6 +243,7 @@ Config::Config() {
       {"txn-context-enabled", true, new YesNoField(&txn_context_enabled, 
false)},
       {"skip-block-cache-deallocation-on-close", false, new 
YesNoField(&skip_block_cache_deallocation_on_close, false)},
       {"histogram-bucket-boundaries", true, new 
StringField(&histogram_bucket_boundaries_str_, "")},
+      {"lua-strict-key-accessing", false, new 
YesNoField(&lua_strict_key_accessing, false)},
 
       /* rocksdb options */
       {"rocksdb.compression", false,
diff --git a/src/config/config.h b/src/config/config.h
index 578843c39..30527eca6 100644
--- a/src/config/config.h
+++ b/src/config/config.h
@@ -193,6 +193,8 @@ struct Config {
 
   bool skip_block_cache_deallocation_on_close = false;
 
+  bool lua_strict_key_accessing = false;
+
   std::vector<double> histogram_bucket_boundaries;
 
   struct RocksDB {
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index b5f865ac2..357b9ae26 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -426,6 +426,9 @@ Status FunctionCall(redis::Connection *conn, 
engine::Context *ctx, const std::st
 
   SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
 
+  // save keys on the registry to perform the key-touching check
+  SaveOnRegistry(lua, REGISTRY_KEYS_NAME, &keys);
+
   PushArray(lua, keys);
   PushArray(lua, argv);
   if (lua_pcall(lua, 2, 1, -4)) {
@@ -437,6 +440,7 @@ Status FunctionCall(redis::Connection *conn, 
engine::Context *ctx, const std::st
     lua_pop(lua, 2);
   }
 
+  RemoveFromRegistry(lua, REGISTRY_KEYS_NAME);
   RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
 
   /* Call the Lua garbage collector from time to time to avoid a
@@ -704,6 +708,9 @@ Status EvalGenericCommand(redis::Connection *conn, 
engine::Context *ctx, const s
   SetGlobalArray(lua, "KEYS", keys);
   SetGlobalArray(lua, "ARGV", argv);
 
+  // save keys on the registry to perform the key-touching check
+  SaveOnRegistry(lua, REGISTRY_KEYS_NAME, &keys);
+
   if (lua_pcall(lua, 0, 1, -2)) {
     auto msg = fmt::format("running script (call to {}): {}", funcname, 
lua_tostring(lua, -1));
     *output = redis::Error({Status::NotOK, msg});
@@ -719,6 +726,7 @@ Status EvalGenericCommand(redis::Connection *conn, 
engine::Context *ctx, const s
   lua_setglobal(lua, "KEYS");
   lua_pushnil(lua);
   lua_setglobal(lua, "ARGV");
+  RemoveFromRegistry(lua, REGISTRY_KEYS_NAME);
   RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
 
   /* Call the Lua garbage collector from time to time to avoid a
@@ -814,6 +822,28 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
     return raise_error ? RaiseError(lua) : 1;
   }
 
+  // if it is a write command and strict mode is enabled,
+  // we need to check if the input keys are all in allowed keys
+  if (config->lua_strict_key_accessing && !(cmd_flags & redis::kCmdReadOnly)) {
+    auto allowed_keys = GetFromRegistry<const std::vector<std::string>>(lua, 
REGISTRY_KEYS_NAME);
+
+    attributes->ForEachKeyRange(
+        [&](const std::vector<std::string> &args, const redis::CommandKeyRange 
&range) {
+          range.ForEachKey(
+              [&](const std::string &key) {
+                if (std::find(allowed_keys->begin(), allowed_keys->end(), key) 
== allowed_keys->end()) {
+                  PushError(lua, fmt::format("Script attempted to access key 
'{}' which is not in the allowed keys "
+                                             "(lua-strict-key-accessing)",
+                                             key)
+                                     .c_str());
+                  RaiseError(lua);
+                }
+              },
+              args);
+        },
+        args);
+  }
+
   if (config->cluster_enabled) {
     if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) {
       PushError(lua, "Can not run script on cluster, 'no-cluster' flag is 
set");
@@ -1582,9 +1612,9 @@ Status CreateFunction(Server *srv, const std::string 
&body, std::string *sha, lu
     libname = shebang_split_sv.substr(shebang_libname_prefix.size());
     if (libname.empty() ||
         std::any_of(libname.begin(), libname.end(), [](char v) { return 
!std::isalnum(v) && v != '_'; })) {
-      return {
-          Status::NotOK,
-          "Library names can only contain letters, numbers, or underscores(_) 
and must be at least one character long"};
+      return {Status::NotOK,
+              "Library names can only contain letters, numbers, or 
underscores(_) and must be at least one "
+              "character long"};
     }
     found_libname = true;
   }
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index c28ab6689..bf03f3e10 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -39,7 +39,8 @@ inline constexpr const char 
REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_re
 inline constexpr const char REDIS_FUNCTION_LIBNAME[] = 
"REDIS_FUNCTION_LIBNAME";
 inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = 
"REDIS_FUNCTION_NEEDSTORE";
 inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = 
"REDIS_FUNCTION_LIBRARIES";
-inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX";
+inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = 
"__SCRIPT_RUN_CTX";
+inline constexpr const char REGISTRY_KEYS_NAME[] = "__CURRENT_KEYS";
 
 namespace lua {
 
@@ -165,7 +166,7 @@ template <typename T>
 void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) {
   lua_pushstring(lua, name);
   if (ptr) {
-    lua_pushlightuserdata(lua, ptr);
+    lua_pushlightuserdata(lua, (void *)ptr);
   } else {
     lua_pushnil(lua);
   }
diff --git a/tests/gocase/unit/scripting/function_test.go 
b/tests/gocase/unit/scripting/function_test.go
index 98c92fe1a..5d307b488 100644
--- a/tests/gocase/unit/scripting/function_test.go
+++ b/tests/gocase/unit/scripting/function_test.go
@@ -581,3 +581,37 @@ func TestFunctionScriptFlags(t *testing.T) {
                util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access 
a non local key in a cluster node script")
        })
 }
+
+func TestFunctionInStrictMode(t *testing.T) {
+       srv := util.StartServer(t, map[string]string{})
+       defer srv.Close()
+
+       ctx := context.Background()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       t.Run("Accessing undeclared keys in strict mode", func(t *testing.T) {
+               rdb.FunctionLoad(ctx, `#!lua name=tmplib
+                       redis.register_function('set1', function(keys, args)
+                               return redis.call('set', keys[1], args[1])
+                       end)
+                       redis.register_function('set2', function(keys, args)
+                               return redis.call('set', args[1], args[2])
+                       end)
+               `)
+
+               rdb.ConfigSet(ctx, "lua-strict-key-accessing", "yes")
+
+               util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "set2", 0, "x", 
"1").Err(), ".*'x'.*not in the allowed keys.*")
+               util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "set2", 1, "y", "x", 
"1").Err(), ".*'x'.*not in the allowed keys.*")
+
+               require.NoError(t, rdb.Do(ctx, "FCALL", "set2", 1, "x", "x", 
"1").Err())
+               require.NoError(t, rdb.Do(ctx, "FCALL", "set2", 2, "x", "y", 
"x", "1").Err())
+               require.NoError(t, rdb.Do(ctx, "FCALL", "set1", 1, "x", 
"1").Err())
+
+               rdb.ConfigSet(ctx, "lua-strict-key-accessing", "no")
+
+               require.NoError(t, rdb.Do(ctx, "FCALL", "set2", 0, "x", 
"1").Err())
+               require.NoError(t, rdb.Do(ctx, "FCALL", "set1", 1, "x", 
"1").Err())
+       })
+}
diff --git a/tests/gocase/unit/scripting/scripting_test.go 
b/tests/gocase/unit/scripting/scripting_test.go
index 0c886dfd5..ce606a04f 100644
--- a/tests/gocase/unit/scripting/scripting_test.go
+++ b/tests/gocase/unit/scripting/scripting_test.go
@@ -894,3 +894,34 @@ func TestEvalScriptFlags(t *testing.T) {
 
        })
 }
+
+func TestEvalScriptInStrictMode(t *testing.T) {
+       srv := util.StartServer(t, map[string]string{})
+       defer srv.Close()
+
+       ctx := context.Background()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       t.Run("Accessing undeclared keys in strict mode", func(t *testing.T) {
+               rdb.ConfigSet(ctx, "lua-strict-key-accessing", "yes")
+
+               util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set', 
'a', 1)", []string{}).Err(), ".*'a'.*not in the allowed keys.*")
+               util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set', 
ARGV[1], 1)", []string{}, "a").Err(), ".*'a'.*not in the allowed keys.*")
+               util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set', 
'b', 1)", []string{"a"}).Err(), ".*'b'.*not in the allowed keys.*")
+               util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set', 
KEYS[1]..'b', 1)", []string{"a"}).Err(), ".*'ab'.*not in the allowed keys.*")
+
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 
KEYS[1], 1)", []string{"a"}).Err())
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 
KEYS[2], 1)", []string{"a", "b"}).Err())
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 'a', 
1)", []string{"a", "b"}).Err())
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 'b', 
1)", []string{"a", "b"}).Err())
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 
ARGV[1]..ARGV[2], 1)", []string{"ab"}, "a", "b").Err())
+
+               // read-only commands are allowed as an exception
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('get', 
'a')", []string{}).Err())
+
+               rdb.ConfigSet(ctx, "lua-strict-key-accessing", "no")
+
+               require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 'a', 
1)", []string{}).Err())
+       })
+}

Reply via email to