This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 3afb62535 feat(scripting): support strict key-accessing mode for lua
scripting (#3139)
3afb62535 is described below
commit 3afb6253567c19c7b1695da6f6ae47cf6dbce48e
Author: Twice <[email protected]>
AuthorDate: Sat Aug 23 13:36:09 2025 +0800
feat(scripting): support strict key-accessing mode for lua scripting (#3139)
Accessing undeclared keys in lua scripting may lead to unexpected
behavior in the current design of kvrocks (also in redis, refer to
https://redis.io/docs/latest/commands/eval/), so in this PR we add a new
option `lua-strict-key-accessing` to prevent users from accessing undeclared
keys in lua, e.g.
```
EVAL "return redis.call('set', 'a', 1)" 0 // ERROR!
EVAL "return redis.call('set', KEYS[1], 1)" 1 a // ok
```
This check applies to write commands only (read-only commands are not
checked) and is performed in both lua scripting and lua functions.
---
kvrocks.conf | 13 ++++++++++
src/config/config.cc | 1 +
src/config/config.h | 2 ++
src/storage/scripting.cc | 36 ++++++++++++++++++++++++---
src/storage/scripting.h | 5 ++--
tests/gocase/unit/scripting/function_test.go | 34 +++++++++++++++++++++++++
tests/gocase/unit/scripting/scripting_test.go | 31 +++++++++++++++++++++++
7 files changed, 117 insertions(+), 5 deletions(-)
diff --git a/kvrocks.conf b/kvrocks.conf
index 67819c274..b56934345 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -420,6 +420,19 @@ txn-context-enabled no
# Default: disabled
# histogram-bucket-boundaries
10,20,40,60,80,100,150,250,350,500,750,1000,1500,2000,4000,8000
+# Whether the strict key-accessing mode of lua scripting is enabled.
+#
+# If enabled, the lua script will abort and report an error
+# if a write command in it accesses a key that is not declared in
+# the script's `KEYS` table or the function's `keys` argument.
+#
+# Note that accessing undeclared keys may lead to unexpected behavior,
+# so this option ensures that scripts only write keys that are
+# explicitly declared. Read-only commands are not checked.
+#
+# Default: no
+lua-strict-key-accessing no
+
################################## TLS ###################################
# By default, TLS/SSL is disabled, i.e. `tls-port` is set to 0.
diff --git a/src/config/config.cc b/src/config/config.cc
index f17db7829..fb0f68f67 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -243,6 +243,7 @@ Config::Config() {
{"txn-context-enabled", true, new YesNoField(&txn_context_enabled,
false)},
{"skip-block-cache-deallocation-on-close", false, new
YesNoField(&skip_block_cache_deallocation_on_close, false)},
{"histogram-bucket-boundaries", true, new
StringField(&histogram_bucket_boundaries_str_, "")},
+ {"lua-strict-key-accessing", false, new
YesNoField(&lua_strict_key_accessing, false)},
/* rocksdb options */
{"rocksdb.compression", false,
diff --git a/src/config/config.h b/src/config/config.h
index 578843c39..30527eca6 100644
--- a/src/config/config.h
+++ b/src/config/config.h
@@ -193,6 +193,8 @@ struct Config {
bool skip_block_cache_deallocation_on_close = false;
+ bool lua_strict_key_accessing = false;
+
std::vector<double> histogram_bucket_boundaries;
struct RocksDB {
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index b5f865ac2..357b9ae26 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -426,6 +426,9 @@ Status FunctionCall(redis::Connection *conn,
engine::Context *ctx, const std::st
SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
+ // save keys on the registry to perform the key-accessing check
+ SaveOnRegistry(lua, REGISTRY_KEYS_NAME, &keys);
+
PushArray(lua, keys);
PushArray(lua, argv);
if (lua_pcall(lua, 2, 1, -4)) {
@@ -437,6 +440,7 @@ Status FunctionCall(redis::Connection *conn,
engine::Context *ctx, const std::st
lua_pop(lua, 2);
}
+ RemoveFromRegistry(lua, REGISTRY_KEYS_NAME);
RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
/* Call the Lua garbage collector from time to time to avoid a
@@ -704,6 +708,9 @@ Status EvalGenericCommand(redis::Connection *conn,
engine::Context *ctx, const s
SetGlobalArray(lua, "KEYS", keys);
SetGlobalArray(lua, "ARGV", argv);
+ // save keys on the registry to perform the key-accessing check
+ SaveOnRegistry(lua, REGISTRY_KEYS_NAME, &keys);
+
if (lua_pcall(lua, 0, 1, -2)) {
auto msg = fmt::format("running script (call to {}): {}", funcname,
lua_tostring(lua, -1));
*output = redis::Error({Status::NotOK, msg});
@@ -719,6 +726,7 @@ Status EvalGenericCommand(redis::Connection *conn,
engine::Context *ctx, const s
lua_setglobal(lua, "KEYS");
lua_pushnil(lua);
lua_setglobal(lua, "ARGV");
+ RemoveFromRegistry(lua, REGISTRY_KEYS_NAME);
RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
/* Call the Lua garbage collector from time to time to avoid a
@@ -814,6 +822,28 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
return raise_error ? RaiseError(lua) : 1;
}
+ // if it is a write command and strict mode is enabled,
+ // we need to check if the input keys are all in allowed keys
+ if (config->lua_strict_key_accessing && !(cmd_flags & redis::kCmdReadOnly)) {
+ auto allowed_keys = GetFromRegistry<const std::vector<std::string>>(lua,
REGISTRY_KEYS_NAME);
+
+ attributes->ForEachKeyRange(
+ [&](const std::vector<std::string> &args, const redis::CommandKeyRange
&range) {
+ range.ForEachKey(
+ [&](const std::string &key) {
+ if (std::find(allowed_keys->begin(), allowed_keys->end(), key)
== allowed_keys->end()) {
+ PushError(lua, fmt::format("Script attempted to access key
'{}' which is not in the allowed keys "
+ "(lua-strict-key-accessing)",
+ key)
+ .c_str());
+ RaiseError(lua);
+ }
+ },
+ args);
+ },
+ args);
+ }
+
if (config->cluster_enabled) {
if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) {
PushError(lua, "Can not run script on cluster, 'no-cluster' flag is
set");
@@ -1582,9 +1612,9 @@ Status CreateFunction(Server *srv, const std::string
&body, std::string *sha, lu
libname = shebang_split_sv.substr(shebang_libname_prefix.size());
if (libname.empty() ||
std::any_of(libname.begin(), libname.end(), [](char v) { return
!std::isalnum(v) && v != '_'; })) {
- return {
- Status::NotOK,
- "Library names can only contain letters, numbers, or underscores(_)
and must be at least one character long"};
+ return {Status::NotOK,
+ "Library names can only contain letters, numbers, or
underscores(_) and must be at least one "
+ "character long"};
}
found_libname = true;
}
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index c28ab6689..bf03f3e10 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -39,7 +39,8 @@ inline constexpr const char
REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_re
inline constexpr const char REDIS_FUNCTION_LIBNAME[] =
"REDIS_FUNCTION_LIBNAME";
inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] =
"REDIS_FUNCTION_NEEDSTORE";
inline constexpr const char REDIS_FUNCTION_LIBRARIES[] =
"REDIS_FUNCTION_LIBRARIES";
-inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX";
+inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] =
"__SCRIPT_RUN_CTX";
+inline constexpr const char REGISTRY_KEYS_NAME[] = "__CURRENT_KEYS";
namespace lua {
@@ -165,7 +166,7 @@ template <typename T>
void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) {
lua_pushstring(lua, name);
if (ptr) {
- lua_pushlightuserdata(lua, ptr);
+ lua_pushlightuserdata(lua, (void *)ptr);
} else {
lua_pushnil(lua);
}
diff --git a/tests/gocase/unit/scripting/function_test.go
b/tests/gocase/unit/scripting/function_test.go
index 98c92fe1a..5d307b488 100644
--- a/tests/gocase/unit/scripting/function_test.go
+++ b/tests/gocase/unit/scripting/function_test.go
@@ -581,3 +581,37 @@ func TestFunctionScriptFlags(t *testing.T) {
util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access
a non local key in a cluster node script")
})
}
+
+func TestFunctionInStrictMode(t *testing.T) { // FCALL under lua-strict-key-accessing: writes must target declared keys
+ srv := util.StartServer(t, map[string]string{})
+ defer srv.Close()
+
+ ctx := context.Background()
+ rdb := srv.NewClient()
+ defer func() { require.NoError(t, rdb.Close()) }()
+
+ t.Run("Accessing undeclared keys in strict mode", func(t *testing.T) {
+ require.NoError(t, rdb.FunctionLoad(ctx, `#!lua name=tmplib
+ redis.register_function('set1', function(keys, args)
+ return redis.call('set', keys[1], args[1])
+ end)
+ redis.register_function('set2', function(keys, args)
+ return redis.call('set', args[1], args[2])
+ end)
+ `).Err()) // set1 writes keys[1]; set2 writes args[1], which may be undeclared
+
+ require.NoError(t, rdb.ConfigSet(ctx, "lua-strict-key-accessing", "yes").Err())
+ // set2 writes "x" while "x" is not among the declared keys: rejected.
+ util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "set2", 0, "x",
"1").Err(), ".*'x'.*not in the allowed keys.*")
+ util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "set2", 1, "y", "x",
"1").Err(), ".*'x'.*not in the allowed keys.*")
+ // Writing a key that is passed in keys is allowed.
+ require.NoError(t, rdb.Do(ctx, "FCALL", "set2", 1, "x", "x",
"1").Err())
+ require.NoError(t, rdb.Do(ctx, "FCALL", "set2", 2, "x", "y",
"x", "1").Err())
+ require.NoError(t, rdb.Do(ctx, "FCALL", "set1", 1, "x",
"1").Err())
+ // Once strict mode is disabled, undeclared keys are accepted again.
+ require.NoError(t, rdb.ConfigSet(ctx, "lua-strict-key-accessing", "no").Err())
+
+ require.NoError(t, rdb.Do(ctx, "FCALL", "set2", 0, "x",
"1").Err())
+ require.NoError(t, rdb.Do(ctx, "FCALL", "set1", 1, "x",
"1").Err())
+ })
+}
diff --git a/tests/gocase/unit/scripting/scripting_test.go
b/tests/gocase/unit/scripting/scripting_test.go
index 0c886dfd5..ce606a04f 100644
--- a/tests/gocase/unit/scripting/scripting_test.go
+++ b/tests/gocase/unit/scripting/scripting_test.go
@@ -894,3 +894,34 @@ func TestEvalScriptFlags(t *testing.T) {
})
}
+
+func TestEvalScriptInStrictMode(t *testing.T) { // EVAL under lua-strict-key-accessing: writes must target KEYS
+ srv := util.StartServer(t, map[string]string{})
+ defer srv.Close()
+
+ ctx := context.Background()
+ rdb := srv.NewClient()
+ defer func() { require.NoError(t, rdb.Close()) }()
+
+ t.Run("Accessing undeclared keys in strict mode", func(t *testing.T) {
+ require.NoError(t, rdb.ConfigSet(ctx, "lua-strict-key-accessing", "yes").Err())
+ // Writes to keys missing from KEYS are rejected, however the key name is built.
+ util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set',
'a', 1)", []string{}).Err(), ".*'a'.*not in the allowed keys.*")
+ util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set',
ARGV[1], 1)", []string{}, "a").Err(), ".*'a'.*not in the allowed keys.*")
+ util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set',
'b', 1)", []string{"a"}).Err(), ".*'b'.*not in the allowed keys.*")
+ util.ErrorRegexp(t, rdb.Eval(ctx, "return redis.call('set',
KEYS[1]..'b', 1)", []string{"a"}).Err(), ".*'ab'.*not in the allowed keys.*")
+ // Writes to keys present in KEYS are allowed, however the key name is built.
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('set',
KEYS[1], 1)", []string{"a"}).Err())
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('set',
KEYS[2], 1)", []string{"a", "b"}).Err())
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 'a',
1)", []string{"a", "b"}).Err())
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 'b',
1)", []string{"a", "b"}).Err())
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('set',
ARGV[1]..ARGV[2], 1)", []string{"ab"}, "a", "b").Err())
+
+ // read-only commands are not checked, so reading an undeclared key succeeds
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('get',
'a')", []string{}).Err())
+
+ require.NoError(t, rdb.ConfigSet(ctx, "lua-strict-key-accessing", "no").Err())
+ // Once strict mode is disabled, undeclared keys are accepted again.
+ require.NoError(t, rdb.Eval(ctx, "return redis.call('set', 'a',
1)", []string{}).Err())
+ })
+}