This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new b81434f38 feat(replication): implement WAIT without timeout support (#3047)
b81434f38 is described below

commit b81434f38b470df8a65bf9f987fdcbd7c04cf07a
Author: Zhixin Wen <[email protected]>
AuthorDate: Mon Jul 14 22:24:02 2025 -0700

    feat(replication): implement WAIT without timeout support (#3047)
    
    Co-authored-by: hulk <[email protected]>
    Co-authored-by: Zhixin Wen <[email protected]>
    Co-authored-by: Twice <[email protected]>
    Co-authored-by: Twice <[email protected]>
---
 src/cluster/replication.cc          |   4 ++
 src/commands/cmd_replication.cc     |  46 ++++++++++++-
 src/server/server.cc                |  58 +++++++++++++++++
 src/server/server.h                 |  21 ++++++
 src/server/worker.cc                |   1 +
 tests/gocase/unit/wait/wait_test.go | 124 ++++++++++++++++++++++++++++++++++++
 6 files changed, 253 insertions(+), 1 deletion(-)

diff --git a/src/cluster/replication.cc b/src/cluster/replication.cc
index e032b87be..40b860a41 100644
--- a/src/cluster/replication.cc
+++ b/src/cluster/replication.cc
@@ -159,6 +159,10 @@ void FeedSlaveThread::loop() {
     }
     curr_seq = batch.sequence + batch.writeBatchPtr->Count();
     next_repl_seq_.store(curr_seq);
+
+    // Wake up any WAIT connections that might be waiting for this sequence
+    srv_->WakeupWaitConnections(curr_seq);
+
     while (!IsStopped() && !srv_->storage->WALHasNewData(curr_seq)) {
       usleep(yield_microseconds);
       checkLivenessIfNeed();
diff --git a/src/commands/cmd_replication.cc b/src/commands/cmd_replication.cc
index 2f47a3403..65fd9bbcc 100644
--- a/src/commands/cmd_replication.cc
+++ b/src/commands/cmd_replication.cc
@@ -344,10 +344,54 @@ class CommandDBName : public Commander {
   }
 };
 
+class CommandWait : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    auto num_replicas_result = ParseInt<int64_t>(args[1], 10);
+    if (!num_replicas_result || *num_replicas_result <= 0) {
+      return {Status::RedisParseErr, "numreplicas should be a positive integer"};
+    }
+
+    num_replicas_ = *num_replicas_result;
+
+    return Commander::Parse(args);
+  }
+
+  Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
+    // Only master can execute WAIT command
+    if (srv->IsSlave()) {
+      return {Status::RedisExecErr, "WAIT command can only be executed on master"};
+    }
+
+    // Get current sequence number
+    auto current_seq = srv->storage->LatestSeqNumber();
+
+    // Check if we already have enough replicas at the current sequence
+    size_t reached_replicas = srv->GetReplicasReachedSequence(current_seq);
+
+    // If we already have enough replicas, return immediately
+    if (reached_replicas >= num_replicas_) {
+      *output = redis::Integer(reached_replicas);
+      return Status::OK();
+    }
+
+    // Block the connection and wait for replicas to catch up
+    srv->BlockOnWait(conn, current_seq, num_replicas_);
+
+    // The connection will be woken up by WakeupWaitConnections when enough replicas
+    // have reached the target sequence
+    return {Status::BlockingCmd};
+  }
+
+ private:
+  uint64_t num_replicas_ = 0;
+};
+
 REDIS_REGISTER_COMMANDS(Replication, MakeCmdAttr<CommandReplConf>("replconf", -3, "read-only no-script", NO_KEY),
                         MakeCmdAttr<CommandPSync>("psync", -2, "read-only no-multi no-script", NO_KEY),
                         MakeCmdAttr<CommandFetchMeta>("_fetch_meta", 1, "read-only no-multi no-script", NO_KEY),
                         MakeCmdAttr<CommandFetchFile>("_fetch_file", 2, "read-only no-multi no-script", NO_KEY),
-                        MakeCmdAttr<CommandDBName>("_db_name", 1, "read-only no-multi", NO_KEY), )
+                        MakeCmdAttr<CommandDBName>("_db_name", 1, "read-only no-multi", NO_KEY),
+                        MakeCmdAttr<CommandWait>("wait", 2, "read-only no-multi no-script blocking", NO_KEY), )
 
 }  // namespace redis
diff --git a/src/server/server.cc b/src/server/server.cc
index 02755d003..aa594bcd8 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -44,6 +44,7 @@
 #include "fmt/format.h"
 #include "logging.h"
 #include "redis_connection.h"
+#include "redis_reply.h"
 #include "rocksdb/version.h"
 #include "storage/compaction_checker.h"
 #include "storage/redis_db.h"
@@ -700,6 +701,63 @@ void Server::OnEntryAddedToStream(const std::string &ns, const std::string &key,
   }
 }
 
+void Server::BlockOnWait(redis::Connection *conn, rocksdb::SequenceNumber target_seq, uint64_t num_replicas) {
+  std::lock_guard<std::mutex> guard(wait_contexts_mu_);
+
+  wait_contexts_.emplace_back(conn, target_seq, num_replicas);
+  IncrBlockedClientNum();
+}
+
+void Server::WakeupWaitConnections(rocksdb::SequenceNumber seq) {
+  std::lock_guard<std::mutex> guard(wait_contexts_mu_);
+
+  for (auto it = wait_contexts_.begin(); it != wait_contexts_.end();) {
+    // Check if target sequence is reached
+    if (seq >= it->target_seq) {
+      // Count how many replicas have reached the target sequence
+      size_t reached_replicas = GetReplicasReachedSequence(it->target_seq);
+
+      // If enough replicas have reached the target sequence, wake up the connection
+      if (reached_replicas >= it->num_replicas) {
+        // Send the response with the number of replicas that have reached the target sequence
+        it->conn->Reply(redis::Integer(reached_replicas));
+
+        auto s = it->conn->Owner()->EnableWriteEvent(it->conn->GetFD());
+        if (!s.IsOK()) {
+          error("[server] Failed to enable write event on WAIT connection {}: {}", it->conn->GetFD(), s.Msg());
+        }
+        it = wait_contexts_.erase(it);
+        DecrBlockedClientNum();
+        continue;
+      }
+    }
+
+    ++it;
+  }
+}
+
+void Server::CleanupWaitConnection(redis::Connection *conn) {
+  std::lock_guard<std::mutex> guard(wait_contexts_mu_);
+
+  auto it = std::find_if(wait_contexts_.begin(), wait_contexts_.end(),
+                         [conn](const auto &context) { return context.conn == conn; });
+  if (it != wait_contexts_.end()) {
+    wait_contexts_.erase(it);
+    DecrBlockedClientNum();
+  }
+}
+
+size_t Server::GetReplicasReachedSequence(rocksdb::SequenceNumber target_seq) {
+  std::lock_guard<std::mutex> slave_guard(slave_threads_mu_);
+  size_t reached_replicas = 0;
+  for (const auto &slave : slave_threads_) {
+    if (!slave->IsStopped() && slave->GetCurrentReplSeq() >= target_seq) {
+      reached_replicas++;
+    }
+  }
+  return reached_replicas;
+}
+
 void Server::updateCachedTime() { unix_time_secs.store(util::GetTimeStamp()); }
 
 int Server::IncrClientNum() {
diff --git a/src/server/server.h b/src/server/server.h
index eb3632bb5..b25b6fbd1 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -43,6 +43,7 @@
 #include "cluster/slot_import.h"
 #include "cluster/slot_migrate.h"
 #include "commands/commander.h"
+#include "common/time_util.h"
 #include "lua.hpp"
 #include "memory_profiler.h"
 #include "namespace.h"
@@ -229,6 +230,14 @@ class Server {
   void WakeupBlockingConns(const std::string &key, size_t n_conns);
   void OnEntryAddedToStream(const std::string &ns, const std::string &key, const redis::StreamEntryID &entry_id);
 
+  // WAIT command infrastructure
+  void BlockOnWait(redis::Connection *conn, rocksdb::SequenceNumber target_seq, uint64_t num_replicas);
+  void WakeupWaitConnections(rocksdb::SequenceNumber seq);
+  void CleanupWaitConnection(redis::Connection *conn);
+
+  // Helper methods for WAIT command
+  size_t GetReplicasReachedSequence(rocksdb::SequenceNumber target_seq);
+
   size_t GetReplicaCount() {
     slave_threads_mu_.lock();
     auto replica_count = slave_threads_.size();
@@ -404,6 +413,18 @@ class Server {
   std::mutex blocked_stream_consumers_mu_;
   std::map<std::string, std::set<std::shared_ptr<StreamConsumer>>> blocked_stream_consumers_;
 
+  // WAIT command blocking infrastructure
+  struct WaitContext {
+    redis::Connection *conn;
+    rocksdb::SequenceNumber target_seq;
+    uint64_t num_replicas;
+
+    WaitContext(redis::Connection *c, rocksdb::SequenceNumber seq, uint64_t replicas)
+        : conn(c), target_seq(seq), num_replicas(replicas) {}
+  };
+  std::list<WaitContext> wait_contexts_;
+  std::mutex wait_contexts_mu_;
+
   // threads
   std::shared_mutex works_concurrency_rw_lock_;
   std::thread cron_thread_;
diff --git a/src/server/worker.cc b/src/server/worker.cc
index e65467f2c..dcba89c11 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -426,6 +426,7 @@ void Worker::FreeConnection(redis::Connection *conn) {
 
   removeConnection(conn->GetFD());
   srv->ResetWatchedKeys(conn);
+  srv->CleanupWaitConnection(conn);
   if (rate_limit_group_) {
     bufferevent_remove_from_rate_limit_group(conn->GetBufferEvent());
   }
diff --git a/tests/gocase/unit/wait/wait_test.go b/tests/gocase/unit/wait/wait_test.go
new file mode 100644
index 000000000..1030fc8a7
--- /dev/null
+++ b/tests/gocase/unit/wait/wait_test.go
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package wait
+
+import (
+       "context"
+       "strings"
+       "testing"
+       "time"
+
+       "github.com/apache/kvrocks/tests/gocase/util"
+       "github.com/stretchr/testify/require"
+)
+
+func TestWaitCommand(t *testing.T) {
+       // Start master server
+       masterSrv := util.StartServer(t, map[string]string{})
+       defer masterSrv.Close()
+
+       // Start slave server
+       slaveSrv := util.StartServer(t, map[string]string{})
+       defer slaveSrv.Close()
+
+       ctx := context.Background()
+       masterRdb := masterSrv.NewClient()
+       defer func() { require.NoError(t, masterRdb.Close()) }()
+
+       slaveRdb := slaveSrv.NewClient()
+       defer func() { require.NoError(t, slaveRdb.Close()) }()
+
+       // Set up replication
+       util.SlaveOf(t, slaveRdb, masterSrv)
+       util.WaitForSync(t, slaveRdb)
+
+       t.Run("WAIT with negative number should return error", func(t *testing.T) {
+               result := masterRdb.Do(ctx, "WAIT", "-1")
+               require.Error(t, result.Err())
+               require.Contains(t, result.Err().Error(), "numreplicas should be a positive integer")
+       })
+
+       t.Run("WAIT with invalid arguments should return error", func(t *testing.T) {
+               result := masterRdb.Do(ctx, "WAIT")
+               require.Error(t, result.Err())
+               require.Contains(t, result.Err().Error(), "wrong number of arguments")
+
+               result = masterRdb.Do(ctx, "WAIT", "1", "1000")
+               require.Error(t, result.Err())
+               require.Contains(t, result.Err().Error(), "wrong number of arguments")
+       })
+
+       t.Run("WAIT should not block indefinitely", func(t *testing.T) {
+               // Start a goroutine to execute WAIT
+               done := make(chan bool, 1)
+               go func() {
+                       require.NoError(t, masterRdb.Do(ctx, "SET", "k1", "v1").Err())
+                       require.NoError(t, masterRdb.Do(ctx, "WAIT", "1").Err())
+                       done <- true
+               }()
+
+               // Wait for the command to complete (should be immediate)
+               select {
+               case <-done:
+                       // Success - command completed immediately
+               case <-time.After(5 * time.Second):
+                       t.Fatal("WAIT command blocked indefinitely")
+               }
+       })
+
+       t.Run("WAIT should block until enough replicas acknowledge", func(t *testing.T) {
+               // Disconnect the slave
+               require.NoError(t, slaveRdb.Do(ctx, "SLAVEOF", "NO", "ONE").Err())
+
+               // Master removes the slave from the replication list periodically,
+               // so we need to wait for the master to detect the disconnection
+               require.Eventually(t, func() bool {
+                       info := masterRdb.Info(ctx, "replication").Val()
+                       return !strings.Contains(info, "connected_slaves:1")
+               }, 50*time.Second, 100*time.Millisecond)
+
+               // Start a goroutine to execute WAIT
+               done := make(chan bool, 1)
+               go func() {
+                       require.NoError(t, masterRdb.Do(ctx, "SET", "k1", "v1").Err())
+                       require.NoError(t, masterRdb.Do(ctx, "WAIT", "1").Err())
+                       done <- true
+               }()
+
+               select {
+               case <-done:
+                       t.Fatal("WAIT command did not block")
+               case <-time.After(1 * time.Second):
+                       // Success - command blocked
+               }
+
+               // Reconnect the slave
+               util.SlaveOf(t, slaveRdb, masterSrv)
+               util.WaitForSync(t, slaveRdb)
+
+               // Now WAIT should complete
+               select {
+               case <-done:
+                       // Success - command completed after replica connected
+               case <-time.After(5 * time.Second):
+                       t.Fatal("WAIT command did not complete after replica connected")
+               }
+       })
+}

Reply via email to