This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 6fcc38fc8 fix(protocol): inline mode should allow the quoted string
(#2873)
6fcc38fc8 is described below
commit 6fcc38fc8c3486de4159eeb0609afc3572516c37
Author: hulk <[email protected]>
AuthorDate: Tue Apr 15 09:11:27 2025 +0800
fix(protocol): inline mode should allow the quoted string (#2873)
---
src/common/string_util.cc | 101 ++++++++++++++++++++++++++++
src/common/string_util.h | 1 +
src/server/redis_request.cc | 6 +-
tests/cppunit/string_util_test.cc | 60 +++++++++++++++++
tests/gocase/unit/protocol/protocol_test.go | 20 ++++++
5 files changed, 187 insertions(+), 1 deletion(-)
diff --git a/src/common/string_util.cc b/src/common/string_util.cc
index c1e6f7e3f..f05e286fc 100644
--- a/src/common/string_util.cc
+++ b/src/common/string_util.cc
@@ -276,6 +276,107 @@ std::pair<std::string, std::string>
SplitGlob(std::string_view glob) {
return {prefix, ""};
}
// Returns the numeric value (0-15) of a single hexadecimal digit, accepting
// '0'-'9', 'a'-'f' and 'A'-'F'. Any other character yields 0, so callers that
// need to distinguish invalid input must validate it (e.g. with isxdigit) first.
static int HexDigitToInt(const char c) {
  if ('0' <= c && c <= '9') return c - '0';
  if ('a' <= c && c <= 'f') return 10 + (c - 'a');
  if ('A' <= c && c <= 'F') return 10 + (c - 'A');
  return 0;
}
+
+StatusOr<std::vector<std::string>> SplitArguments(std::string_view in) {
+ std::vector<std::string> arguments;
+ std::string current_string;
+
+ enum State { NORMAL, DOUBLE_QUOTED, SINGLE_QUOTED, ESCAPE } state = NORMAL;
+
+ bool done = false;
+ for (size_t i = 0; i < in.size() && !done; i++) {
+ const auto c = in[i];
+ switch (state) {
+ case NORMAL:
+ if (std::isspace(c)) {
+ if (!current_string.empty()) {
+ arguments.emplace_back(std::move(current_string));
+ current_string.clear();
+ }
+ } else if (c == '\r' || c == '\n' || c == '\t') {
+ done = true;
+ } else if (c == '"') {
+ state = DOUBLE_QUOTED;
+ } else if (c == '\'') {
+ state = SINGLE_QUOTED;
+ } else {
+ current_string.push_back(c);
+ }
+ break;
+ case SINGLE_QUOTED:
+ if (c == '\\' && (i + 1) < in.size() && in[i + 1] == '\'') {
+ current_string.push_back('\'');
+ i++;
+ } else if (c == '\'') {
+ //
+ if (i + 1 < in.size() && !std::isspace(in[i + 1])) {
+ return {Status::NotOK, "the closed single quote must be followed
by a space"};
+ }
+ state = NORMAL;
+ } else {
+ current_string.push_back(c);
+ }
+ break;
+ case DOUBLE_QUOTED:
+ if (c == '\\') {
+ state = ESCAPE;
+ } else if (c == '"') {
+ if (i + 1 < in.size() && !std::isspace(in[i + 1])) {
+ return {Status::NotOK, "the closed double quote must be followed
by a space"};
+ }
+ state = NORMAL;
+ } else {
+ current_string.push_back(c);
+ }
+ break;
+ case ESCAPE:
+ // It's the hex digit after the \x
+ if (c == 'x' && (i + 2) < in.size() && std::isxdigit(in[i + 1]) &&
std::isxdigit(in[i + 2])) {
+ // Convert the hex digit to a char
+ auto hex_byte = static_cast<char>(HexDigitToInt(in[i + 1]) * 16 |
HexDigitToInt(in[i + 2]));
+ current_string.push_back(hex_byte);
+ i += 2;
+ } else if (c == '"' || c == '\'' || c == '\\') {
+ current_string.push_back(c);
+ } else if (c == 'n') {
+ current_string.push_back('\n');
+ } else if (c == 'r') {
+ current_string.push_back('\r');
+ } else if (c == 't') {
+ current_string.push_back('\t');
+ } else if (c == 'b') {
+ current_string.push_back('\b');
+ } else if (c == 'a') {
+ current_string.push_back('\a');
+ } else {
+ current_string.push_back(c);
+ }
+ state = DOUBLE_QUOTED;
+ break;
+ }
+ }
+ if (state == DOUBLE_QUOTED || state == SINGLE_QUOTED) {
+ return {Status::NotOK, "unclosed quote string"};
+ }
+ if (state == ESCAPE) {
+ return {Status::NotOK, "unexpected trailing escape character"};
+ }
+ if (!current_string.empty()) {
+ arguments.emplace_back(std::move(current_string));
+ }
+ return arguments;
+}
+
std::vector<std::string> RegexMatch(const std::string &str, const std::string
®ex) {
std::regex base_regex(regex);
std::smatch pieces_match;
diff --git a/src/common/string_util.h b/src/common/string_util.h
index fe66be5a2..32c0f4b8e 100644
--- a/src/common/string_util.h
+++ b/src/common/string_util.h
@@ -50,6 +50,7 @@ Iter FindICase(Iter begin, Iter end, std::string_view
expected) {
Status ValidateGlob(std::string_view glob);
bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case
= false);
std::pair<std::string, std::string> SplitGlob(std::string_view glob);
+StatusOr<std::vector<std::string>> SplitArguments(std::string_view in);
std::vector<std::string> RegexMatch(const std::string &str, const std::string
®ex);
std::string StringToHex(std::string_view input);
diff --git a/src/server/redis_request.cc b/src/server/redis_request.cc
index f1e2eb993..f729e85c2 100644
--- a/src/server/redis_request.cc
+++ b/src/server/redis_request.cc
@@ -86,7 +86,11 @@ Status Request::Tokenize(evbuffer *input) {
return {Status::NotOK, "Protocol error: invalid bulk length"};
}
- tokens_ = util::Split(std::string(line.get(), line.length), " \t");
+ auto arguments = util::SplitArguments(line.get());
+ if (!arguments.IsOK()) {
+ return {Status::NotOK, "Protocol error: " + arguments.Msg()};
+ }
+ tokens_ = std::move(arguments.GetValue());
if (tokens_.empty()) continue;
commands_.emplace_back(std::move(tokens_));
state_ = ArrayLen;
diff --git a/tests/cppunit/string_util_test.cc
b/tests/cppunit/string_util_test.cc
index 1d24cf594..d31a99615 100644
--- a/tests/cppunit/string_util_test.cc
+++ b/tests/cppunit/string_util_test.cc
@@ -268,3 +268,63 @@ TEST(StringUtil, RegexMatchExtractSSTFile) {
ASSERT_TRUE(match_results[1] == "/000038.sst");
}
}
+
+TEST(StringUtil, SplitArguments) {
+ std::map<std::string, std::vector<std::string>> valid_cases = {
+ // With ' ' only
+ {"a b c", {"a", "b", "c"}},
+ // Other whitespace characters should work
+ {"a\tb\nc\fd", {"a", "b", "c", "d"}},
+
+ // With double quote escape characters
+ {R"(hello "a b" c)", {"hello", "a b", "c"}},
+ // With single quote escape characters
+ {R"('a b' c)", {"a b", "c"}},
+ // With both single and double quote escape characters
+ {R"(a 'b c' " d e ")", {"a", "b c", " d e "}},
+ // With both single and double quote escape characters
+ {R"(a " b c " 'd e')", {"a", " b c ", "d e"}},
+
+ // With the single quote escape characters
+ {R"('a\' b' c)", {"a' b", "c"}},
+ {R"('a\n\t\r\'b' c)", {R"(a\n\t\r'b)", "c"}},
+
+ // With the double quote escape characters
+ {R"("a\"b" c)", {"a\"b", "c"}},
+ {R"("a\n\t\qb\g" c)", {"a\n\tqbg", "c"}},
+
+ // Escape with the hex digits
+ {R"(\x61 \x62 \x63)", {R"(\x61)", R"(\x62)", R"(\x63)"}},
+ {R"("a \x61\x62" "\x63")", {"a ab", "c"}},
+ // '\' will be removed from '\xT0' because it's not v alid hex digit and
a valid escape sequence
+ {R"("a \xT0\x62" "\x63")", {R"(a xT0b)", "c"}},
+ {R"("a b\x6Fc" "d\x63e")", {"a boc", "dce"}},
+
+ };
+ for (const auto &item : valid_cases) {
+ const std::string &input = item.first;
+ const std::vector<std::string> &expected = item.second;
+ auto result = util::SplitArguments(input);
+ ASSERT_TRUE(result.IsOK());
+ ASSERT_EQ(result.GetValue(), expected);
+ }
+
+ // invalid cases
+ std::map<std::string, std::string> invalid_cases = {
+ {R"(a "b c)", "unclosed quote string"},
+ {R"(a 'b c)", "unclosed quote string"},
+ {R"(a "b' c)", "unclosed quote string"},
+ {R"(a 'b" c)", "unclosed quote string"},
+ {R"(a b 'c\)", "unclosed quote string"},
+ {R"(a b "c\)", "unexpected trailing escape character"},
+ {R"(a b "c"d)", "the closed double quote must be followed by a space"},
+ {R"(a 'b'c)", "the closed single quote must be followed by a space"},
+ };
+ for (const auto &item : invalid_cases) {
+ const std::string &input = item.first;
+ const std::string &expected_error = item.second;
+ auto result = util::SplitArguments(input);
+ ASSERT_FALSE(result.IsOK());
+ ASSERT_EQ(result.Msg(), expected_error);
+ }
+}
diff --git a/tests/gocase/unit/protocol/protocol_test.go
b/tests/gocase/unit/protocol/protocol_test.go
index 6be669bb8..a533bfde9 100644
--- a/tests/gocase/unit/protocol/protocol_test.go
+++ b/tests/gocase/unit/protocol/protocol_test.go
@@ -114,6 +114,26 @@ func TestProtocolNetwork(t *testing.T) {
c.MustRead(t, "+OK")
})
+ t.Run("inline protocol with quoted string", func(t *testing.T) {
+ c := srv.NewTCPClient()
+ LF := "\n"
+ defer func() { require.NoError(t, c.Close()) }()
+ require.NoError(t, c.Write("RPUSH my_list a 'b c' d"+LF))
+ c.MustRead(t, ":3")
+ require.NoError(t, c.Write(`RPUSH my_list "foo \x61\x62"`+LF))
+ c.MustRead(t, ":4")
+ require.NoError(t, c.Write(`RPUSH my_list "bar \"\g\t\n\q"`+LF))
+ c.MustRead(t, ":5")
+ require.NoError(t, c.Write(`RPUSH my_list ' a b' "c d e " `+LF))
+ c.MustRead(t, ":7")
+
+ rdb := srv.NewClient()
+ defer func() { require.NoError(t, rdb.Close()) }()
+ values, err := rdb.LRange(context.Background(), "my_list", 0,
-1).Result()
+ require.NoError(t, err)
+ require.Equal(t, []string{"a", "b c", "d", "foo ab", "bar
\"g\t\nq", " a b", "c d e "}, values)
+ })
+
t.Run("mix LF/CRLF protocol separator", func(t *testing.T) {
c := srv.NewTCPClient()
defer func() { require.NoError(t, c.Close()) }()