This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new b81434f38 feat(replication): implement WAIT without timeout support
(#3047)
b81434f38 is described below
commit b81434f38b470df8a65bf9f987fdcbd7c04cf07a
Author: Zhixin Wen <[email protected]>
AuthorDate: Mon Jul 14 22:24:02 2025 -0700
feat(replication): implement WAIT without timeout support (#3047)
Co-authored-by: hulk <[email protected]>
Co-authored-by: Zhixin Wen <[email protected]>
Co-authored-by: Twice <[email protected]>
Co-authored-by: Twice <[email protected]>
---
src/cluster/replication.cc | 4 ++
src/commands/cmd_replication.cc | 46 ++++++++++++-
src/server/server.cc | 58 +++++++++++++++++
src/server/server.h | 21 ++++++
src/server/worker.cc | 1 +
tests/gocase/unit/wait/wait_test.go | 124 ++++++++++++++++++++++++++++++++++++
6 files changed, 253 insertions(+), 1 deletion(-)
diff --git a/src/cluster/replication.cc b/src/cluster/replication.cc
index e032b87be..40b860a41 100644
--- a/src/cluster/replication.cc
+++ b/src/cluster/replication.cc
@@ -159,6 +159,10 @@ void FeedSlaveThread::loop() {
}
curr_seq = batch.sequence + batch.writeBatchPtr->Count();
next_repl_seq_.store(curr_seq);
+
+ // Wake up any WAIT connections that might be waiting for this sequence
+ srv_->WakeupWaitConnections(curr_seq);
+
while (!IsStopped() && !srv_->storage->WALHasNewData(curr_seq)) {
usleep(yield_microseconds);
checkLivenessIfNeed();
diff --git a/src/commands/cmd_replication.cc b/src/commands/cmd_replication.cc
index 2f47a3403..65fd9bbcc 100644
--- a/src/commands/cmd_replication.cc
+++ b/src/commands/cmd_replication.cc
@@ -344,10 +344,54 @@ class CommandDBName : public Commander {
}
};
+class CommandWait : public Commander {
+ public:
+ Status Parse(const std::vector<std::string> &args) override {
+ auto num_replicas_result = ParseInt<int64_t>(args[1], 10);
+ if (!num_replicas_result || *num_replicas_result <= 0) {
+ return {Status::RedisParseErr, "numreplicas should be a positive
integer"};
+ }
+
+ num_replicas_ = *num_replicas_result;
+
+ return Commander::Parse(args);
+ }
+
+ Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv,
Connection *conn, std::string *output) override {
+ // Only master can execute WAIT command
+ if (srv->IsSlave()) {
+ return {Status::RedisExecErr, "WAIT command can only be executed on
master"};
+ }
+
+ // Get current sequence number
+ auto current_seq = srv->storage->LatestSeqNumber();
+
+ // Check if we already have enough replicas at the current sequence
+ size_t reached_replicas = srv->GetReplicasReachedSequence(current_seq);
+
+ // If we already have enough replicas, return immediately
+ if (reached_replicas >= num_replicas_) {
+ *output = redis::Integer(reached_replicas);
+ return Status::OK();
+ }
+
+ // Block the connection and wait for replicas to catch up
+ srv->BlockOnWait(conn, current_seq, num_replicas_);
+
+ // The connection will be woken up by WakeupWaitConnections when enough
replicas
+ // have reached the target sequence
+ return {Status::BlockingCmd};
+ }
+
+ private:
+ uint64_t num_replicas_ = 0;
+};
+
REDIS_REGISTER_COMMANDS(Replication, MakeCmdAttr<CommandReplConf>("replconf",
-3, "read-only no-script", NO_KEY),
MakeCmdAttr<CommandPSync>("psync", -2, "read-only
no-multi no-script", NO_KEY),
MakeCmdAttr<CommandFetchMeta>("_fetch_meta", 1,
"read-only no-multi no-script", NO_KEY),
MakeCmdAttr<CommandFetchFile>("_fetch_file", 2,
"read-only no-multi no-script", NO_KEY),
- MakeCmdAttr<CommandDBName>("_db_name", 1, "read-only
no-multi", NO_KEY), )
+ MakeCmdAttr<CommandDBName>("_db_name", 1, "read-only
no-multi", NO_KEY),
+ MakeCmdAttr<CommandWait>("wait", 2, "read-only
no-multi no-script blocking", NO_KEY), )
} // namespace redis
diff --git a/src/server/server.cc b/src/server/server.cc
index 02755d003..aa594bcd8 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -44,6 +44,7 @@
#include "fmt/format.h"
#include "logging.h"
#include "redis_connection.h"
+#include "redis_reply.h"
#include "rocksdb/version.h"
#include "storage/compaction_checker.h"
#include "storage/redis_db.h"
@@ -700,6 +701,63 @@ void Server::OnEntryAddedToStream(const std::string &ns,
const std::string &key,
}
}
+void Server::BlockOnWait(redis::Connection *conn, rocksdb::SequenceNumber
target_seq, uint64_t num_replicas) {
+ std::lock_guard<std::mutex> guard(wait_contexts_mu_);
+
+ wait_contexts_.emplace_back(conn, target_seq, num_replicas);
+ IncrBlockedClientNum();
+}
+
+void Server::WakeupWaitConnections(rocksdb::SequenceNumber seq) {
+ std::lock_guard<std::mutex> guard(wait_contexts_mu_);
+
+ for (auto it = wait_contexts_.begin(); it != wait_contexts_.end();) {
+ // Check if target sequence is reached
+ if (seq >= it->target_seq) {
+ // Count how many replicas have reached the target sequence
+ size_t reached_replicas = GetReplicasReachedSequence(it->target_seq);
+
+ // If enough replicas have reached the target sequence, wake up the
connection
+ if (reached_replicas >= it->num_replicas) {
+ // Send the response with the number of replicas that have reached the
target sequence
+ it->conn->Reply(redis::Integer(reached_replicas));
+
+ auto s = it->conn->Owner()->EnableWriteEvent(it->conn->GetFD());
+ if (!s.IsOK()) {
+ error("[server] Failed to enable write event on WAIT connection {}:
{}", it->conn->GetFD(), s.Msg());
+ }
+ it = wait_contexts_.erase(it);
+ DecrBlockedClientNum();
+ continue;
+ }
+ }
+
+ ++it;
+ }
+}
+
+void Server::CleanupWaitConnection(redis::Connection *conn) {
+ std::lock_guard<std::mutex> guard(wait_contexts_mu_);
+
+ auto it = std::find_if(wait_contexts_.begin(), wait_contexts_.end(),
+ [conn](const auto &context) { return context.conn ==
conn; });
+ if (it != wait_contexts_.end()) {
+ wait_contexts_.erase(it);
+ DecrBlockedClientNum();
+ }
+}
+
+size_t Server::GetReplicasReachedSequence(rocksdb::SequenceNumber target_seq) {
+ std::lock_guard<std::mutex> slave_guard(slave_threads_mu_);
+ size_t reached_replicas = 0;
+ for (const auto &slave : slave_threads_) {
+ if (!slave->IsStopped() && slave->GetCurrentReplSeq() >= target_seq) {
+ reached_replicas++;
+ }
+ }
+ return reached_replicas;
+}
+
void Server::updateCachedTime() { unix_time_secs.store(util::GetTimeStamp()); }
int Server::IncrClientNum() {
diff --git a/src/server/server.h b/src/server/server.h
index eb3632bb5..b25b6fbd1 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -43,6 +43,7 @@
#include "cluster/slot_import.h"
#include "cluster/slot_migrate.h"
#include "commands/commander.h"
+#include "common/time_util.h"
#include "lua.hpp"
#include "memory_profiler.h"
#include "namespace.h"
@@ -229,6 +230,14 @@ class Server {
void WakeupBlockingConns(const std::string &key, size_t n_conns);
void OnEntryAddedToStream(const std::string &ns, const std::string &key,
const redis::StreamEntryID &entry_id);
+ // WAIT command infrastructure
+ void BlockOnWait(redis::Connection *conn, rocksdb::SequenceNumber
target_seq, uint64_t num_replicas);
+ void WakeupWaitConnections(rocksdb::SequenceNumber seq);
+ void CleanupWaitConnection(redis::Connection *conn);
+
+ // Helper methods for WAIT command
+ size_t GetReplicasReachedSequence(rocksdb::SequenceNumber target_seq);
+
size_t GetReplicaCount() {
slave_threads_mu_.lock();
auto replica_count = slave_threads_.size();
@@ -404,6 +413,18 @@ class Server {
std::mutex blocked_stream_consumers_mu_;
std::map<std::string, std::set<std::shared_ptr<StreamConsumer>>>
blocked_stream_consumers_;
+ // WAIT command blocking infrastructure
+ struct WaitContext {
+ redis::Connection *conn;
+ rocksdb::SequenceNumber target_seq;
+ uint64_t num_replicas;
+
+ WaitContext(redis::Connection *c, rocksdb::SequenceNumber seq, uint64_t
replicas)
+ : conn(c), target_seq(seq), num_replicas(replicas) {}
+ };
+ std::list<WaitContext> wait_contexts_;
+ std::mutex wait_contexts_mu_;
+
// threads
std::shared_mutex works_concurrency_rw_lock_;
std::thread cron_thread_;
diff --git a/src/server/worker.cc b/src/server/worker.cc
index e65467f2c..dcba89c11 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -426,6 +426,7 @@ void Worker::FreeConnection(redis::Connection *conn) {
removeConnection(conn->GetFD());
srv->ResetWatchedKeys(conn);
+ srv->CleanupWaitConnection(conn);
if (rate_limit_group_) {
bufferevent_remove_from_rate_limit_group(conn->GetBufferEvent());
}
diff --git a/tests/gocase/unit/wait/wait_test.go
b/tests/gocase/unit/wait/wait_test.go
new file mode 100644
index 000000000..1030fc8a7
--- /dev/null
+++ b/tests/gocase/unit/wait/wait_test.go
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package wait
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/apache/kvrocks/tests/gocase/util"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWaitCommand(t *testing.T) {
+ // Start master server
+ masterSrv := util.StartServer(t, map[string]string{})
+ defer masterSrv.Close()
+
+ // Start slave server
+ slaveSrv := util.StartServer(t, map[string]string{})
+ defer slaveSrv.Close()
+
+ ctx := context.Background()
+ masterRdb := masterSrv.NewClient()
+ defer func() { require.NoError(t, masterRdb.Close()) }()
+
+ slaveRdb := slaveSrv.NewClient()
+ defer func() { require.NoError(t, slaveRdb.Close()) }()
+
+ // Set up replication
+ util.SlaveOf(t, slaveRdb, masterSrv)
+ util.WaitForSync(t, slaveRdb)
+
+ t.Run("WAIT with negative number should return error", func(t
*testing.T) {
+ result := masterRdb.Do(ctx, "WAIT", "-1")
+ require.Error(t, result.Err())
+ require.Contains(t, result.Err().Error(), "numreplicas should
be a positive integer")
+ })
+
+ t.Run("WAIT with invalid arguments should return error", func(t
*testing.T) {
+ result := masterRdb.Do(ctx, "WAIT")
+ require.Error(t, result.Err())
+ require.Contains(t, result.Err().Error(), "wrong number of
arguments")
+
+ result = masterRdb.Do(ctx, "WAIT", "1", "1000")
+ require.Error(t, result.Err())
+ require.Contains(t, result.Err().Error(), "wrong number of
arguments")
+ })
+
+ t.Run("WAIT should not block indefinitely", func(t *testing.T) {
+ // Start a goroutine to execute WAIT
+ done := make(chan bool, 1)
+ go func() {
+ require.NoError(t, masterRdb.Do(ctx, "SET", "k1",
"v1").Err())
+ require.NoError(t, masterRdb.Do(ctx, "WAIT", "1").Err())
+ done <- true
+ }()
+
+ // Wait for the command to complete (should be immediate)
+ select {
+ case <-done:
+ // Success - command completed immediately
+ case <-time.After(5 * time.Second):
+ t.Fatal("WAIT command blocked indefinitely")
+ }
+ })
+
+ t.Run("WAIT should block until enough replicas acknowledge", func(t
*testing.T) {
+ // Disconnect the slave
+ require.NoError(t, slaveRdb.Do(ctx, "SLAVEOF", "NO",
"ONE").Err())
+
+ // The master removes the slave from the replication list
periodically
+ // so we need to wait for the master to detect the disconnection
+ require.Eventually(t, func() bool {
+ info := masterRdb.Info(ctx, "replication").Val()
+ return !strings.Contains(info, "connected_slaves:1")
+ }, 50*time.Second, 100*time.Millisecond)
+
+ // Start a goroutine to execute WAIT
+ done := make(chan bool, 1)
+ go func() {
+ require.NoError(t, masterRdb.Do(ctx, "SET", "k1",
"v1").Err())
+ require.NoError(t, masterRdb.Do(ctx, "WAIT", "1").Err())
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ t.Fatal("WAIT command did not block")
+ case <-time.After(1 * time.Second):
+ // Success - command blocked
+ }
+
+ // Reconnect the slave
+ util.SlaveOf(t, slaveRdb, masterSrv)
+ util.WaitForSync(t, slaveRdb)
+
+ // Now WAIT should complete
+ select {
+ case <-done:
+ // Success - command completed after replica connected
+ case <-time.After(5 * time.Second):
+ t.Fatal("WAIT command did not complete after replica
connected")
+ }
+ })
+}