This is an automated email from the ASF dual-hosted git repository.

chenBright pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git


The following commit(s) were added to refs/heads/master by this push:
     new eb31fa52 Make _state atomic to prevent concurrent _read_buf mutation on TCP fallback (#3347)
eb31fa52 is described below

commit eb31fa5260c4bf43dd615fb575e50ad45a0fbbfc
Author: Bright Chen <[email protected]>
AuthorDate: Thu Jun 25 10:51:43 2026 +0800

    Make _state atomic to prevent concurrent _read_buf mutation on TCP fallback (#3347)
    
    RdmaEndpoint::_state was a plain enum, written by the handshake bthread
    and read concurrently by the event-dispatching thread (OnNewDataFromTcp).
    This is a data race, and on a weak memory model it can let the two threads
    concurrently mutate _socket->_read_buf.
    
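    Make _state a butil::atomic<State>:
    - The stores that end the handshake bthread in FALLBACK_TCP or
      ESTABLISHED use release, and the loads in OnNewDataFromTcp,
      TryReadOnTcp and HandleCompletion use acquire. Data published before
      those states (the magic bytes put back into _read_buf, and the RDMA
      window/resource setup before ESTABLISHED) is therefore visible to the
      reader.
    
    - All other stores use relaxed. These are the non-terminal handshake
      transitions, FAILED, UNINIT, and the FALLBACK_TCP stores made when
      RDMA is unavailable and no handshake bthread is started
      (StartConnect, OnNewDataFromTcp). GetStateStr() also reads with
      relaxed.
    
    - Also in this change:
      - The client's HELLO ACK flag is now derived from
        `_rdma_state == RDMA_ON` (previously `!= RDMA_OFF`). This matches
        the check that later selects ESTABLISHED.
      - TryReadOnTcp() is no longer declared `inline` in the header.
      - Some continuation lines are re-indented.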
---
 src/brpc/rdma/rdma_endpoint.cpp | 95 ++++++++++++++++++++++-------------------
 src/brpc/rdma/rdma_endpoint.h   |  4 +-
 2 files changed, 53 insertions(+), 46 deletions(-)

diff --git a/src/brpc/rdma/rdma_endpoint.cpp b/src/brpc/rdma/rdma_endpoint.cpp
index 658c7a2f..a5016dc7 100644
--- a/src/brpc/rdma/rdma_endpoint.cpp
+++ b/src/brpc/rdma/rdma_endpoint.cpp
@@ -161,7 +161,7 @@ RdmaEndpoint::~RdmaEndpoint() {
 void RdmaEndpoint::Reset() {
     DeallocateResources();
 
-    _state = UNINIT;
+    _state.store(UNINIT, butil::memory_order_relaxed);
     _resource = NULL;
     _send_cq_events = 0;
     _recv_cq_events = 0;
@@ -195,7 +195,8 @@ void RdmaConnect::StartConnect(const Socket* socket,
         return;
     }
     if (!IsRdmaAvailable()) {
-        rdma_transport->_rdma_ep->_state = RdmaEndpoint::FALLBACK_TCP;
+        rdma_transport->_rdma_ep->_state.store(RdmaEndpoint::FALLBACK_TCP,
+                                               butil::memory_order_relaxed);
         rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
         done(0, data);
         return;
@@ -206,7 +207,8 @@ void RdmaConnect::StartConnect(const Socket* socket,
     bthread_attr_t attr = BTHREAD_ATTR_NORMAL;
     bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtClient");
     if (bthread_start_background(&tid, &attr,
-                RdmaEndpoint::ProcessHandshakeAtClient, 
rdma_transport->_rdma_ep) < 0) {
+                                 RdmaEndpoint::ProcessHandshakeAtClient,
+                                 rdma_transport->_rdma_ep) < 0) {
         LOG(FATAL) << "Fail to start handshake bthread";
         Run();
     } else {
@@ -230,7 +232,7 @@ static void TryReadOnTcpDuringRdmaEst(Socket* s) {
                 const int saved_errno = errno;
                 PLOG(WARNING) << "Fail to read from " << s;
                 s->SetFailed(saved_errno, "Fail to read from %s: %s",
-                        s->description().c_str(), berror(saved_errno));
+                             s->description().c_str(), berror(saved_errno));
                 return;
             }
             if (!s->MoreReadEvents(&progress)) {
@@ -255,22 +257,22 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) {
 
     int progress = Socket::PROGRESS_INIT;
     while (true) {
-        if (ep->_state == UNINIT) {
+        State state = ep->_state.load(butil::memory_order_acquire);
+        if (state == UNINIT) {
             if (!m->CreatedByConnect()) {
                 if (!IsRdmaAvailable()) {
-                    ep->_state = FALLBACK_TCP;
                     rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
+                    ep->_state.store(FALLBACK_TCP, 
butil::memory_order_relaxed);
                     continue;
                 }
                 bthread_t tid;
-                ep->_state = S_HELLO_WAIT;
+                ep->_state.store(S_HELLO_WAIT, butil::memory_order_relaxed);
                 SocketUniquePtr s;
                 m->ReAddress(&s);
                 bthread_attr_t attr = BTHREAD_ATTR_NORMAL;
                 bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtServer");
-                if (bthread_start_background(&tid, &attr,
-                            ProcessHandshakeAtServer, ep) < 0) {
-                    ep->_state = UNINIT;
+                if (bthread_start_background(&tid, &attr, 
ProcessHandshakeAtServer, ep) < 0) {
+                    ep->_state.store(UNINIT, butil::memory_order_relaxed);
                     LOG(FATAL) << "Fail to start handshake bthread";
                 } else {
                     s.release();
@@ -280,13 +282,13 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) {
                 // starts handshake. This will be handled by client handshake.
                 // Ignore the exception here.
             }
-        } else if (ep->_state < ESTABLISHED) {  // during handshake
+        } else if (state < ESTABLISHED) {  // during handshake
             ep->_read_butex->fetch_add(1, butil::memory_order_release);
             bthread::butex_wake(ep->_read_butex);
-        } else if (ep->_state == FALLBACK_TCP){  // handshake finishes
+        } else if (state == FALLBACK_TCP){  // handshake finishes
             InputMessenger::OnNewMessages(m);
             return;
-        } else if (ep->_state == ESTABLISHED) {
+        } else if (state == ESTABLISHED) {
             TryReadOnTcpDuringRdmaEst(ep->_socket);
             return;
         }
@@ -422,9 +424,10 @@ int RdmaEndpoint::WriteToFd(butil::IOBuf* data) {
 
 inline void RdmaEndpoint::TryReadOnTcp() {
     if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) {
-        if (_state == FALLBACK_TCP) {
+        State state = _state.load(butil::memory_order_acquire);
+        if (state == FALLBACK_TCP) {
             InputMessenger::OnNewMessages(_socket);
-        } else if (_state == ESTABLISHED) {
+        } else if (state == ESTABLISHED) {
             TryReadOnTcpDuringRdmaEst(_socket);
         }
     }
@@ -475,28 +478,28 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
     ep->_handshake_version = handshake->ProtocolVersion();
 
     // First initialize CQ and QP resources.
-    ep->_state = C_ALLOC_QPCQ;
+    ep->_state.store(C_ALLOC_QPCQ, butil::memory_order_relaxed);
     if (ep->AllocateResources() < 0) {
         LOG(WARNING) << "Fallback to tcp:" << s->description();
         rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
-        ep->_state = FALLBACK_TCP;
+        ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
         return NULL;
     }
 
     // Send hello message to server
-    ep->_state = C_HELLO_SEND;
+    ep->_state.store(C_HELLO_SEND, butil::memory_order_relaxed);
     if (handshake->SendLocalHello() < 0) {
         int saved_errno = errno;
         PLOG(WARNING) << "Fail to send hello message to server:"
                       << s->description();
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
 
     // Receive and parse remote hello.
-    ep->_state = C_HELLO_WAIT;
+    ep->_state.store(C_HELLO_WAIT, butil::memory_order_relaxed);
     ParsedHello remote{};
     bool negotiated = false;
     if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) {
@@ -505,7 +508,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
                       << s->description();
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
 
@@ -515,7 +518,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
         rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
     } else {
         ep->ApplyRemoteHello(remote);
-        ep->_state = C_BRINGUP_QP;
+        ep->_state.store(C_BRINGUP_QP, butil::memory_order_relaxed);
         if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) {
             LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
                          << s->description();
@@ -526,8 +529,9 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
     }
 
     // Send ACK message to server
-    ep->_state = C_ACK_SEND;
-    uint32_t flags = rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF ? 
HELLO_ACK_RDMA_OK : 0;
+    ep->_state.store(C_ACK_SEND, butil::memory_order_relaxed);
+    bool rdma_on = rdma_transport->_rdma_state == RdmaTransport::RDMA_ON;
+    uint32_t flags = rdma_on ? HELLO_ACK_RDMA_OK : 0;
     uint32_t flags_be = butil::HostToNet32(flags);
     if (ep->WriteToFd(&flags_be, HELLO_ACK_LEN) < 0) {
         int saved_errno = errno;
@@ -535,17 +539,17 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
                       << s->description();
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
 
     if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) {
-        ep->_state = ESTABLISHED;
+        ep->_state.store(ESTABLISHED, butil::memory_order_release);
         LOG_IF(INFO, FLAGS_rdma_trace_verbose)
             << "Client handshake ends (use rdma v" << ep->_handshake_version
             << ") on " << s->description();
     } else {
-        ep->_state = FALLBACK_TCP;
+        ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
         LOG_IF(INFO, FLAGS_rdma_trace_verbose)
             << "Client handshake ends (use tcp) on " << s->description();
     }
@@ -578,7 +582,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
     LOG_IF(INFO, FLAGS_rdma_trace_verbose)
         << "Start handshake on " << s->description();
 
-    ep->_state = S_HELLO_WAIT;
+    ep->_state.store(S_HELLO_WAIT, butil::memory_order_relaxed);
     uint8_t magic[MAGIC_STR_LEN];
     if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) {
         int saved_errno = errno;
@@ -586,7 +590,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
                       << s->description() << " " << s->_remote_side;
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
 
@@ -598,8 +602,11 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
             << s->description();
         // We need to copy data read back to _socket->_read_buf.
         s->_read_buf.append(magic, MAGIC_STR_LEN);
-        ep->_state = FALLBACK_TCP;
         rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
+        // Use release memory order to publish the magic bytes appended
+        // above to whoever reads `_state == FALLBACK_TCP` (the event
+        // thread in OnNewDataFromTcp).
+        ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
         ep->TryReadOnTcp();
         return NULL;
     }
@@ -614,7 +621,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
                       << s->description();
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
 
@@ -624,13 +631,13 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
         rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
     } else {
         ep->ApplyRemoteHello(remote);
-        ep->_state = S_ALLOC_QPCQ;
+        ep->_state.store(S_ALLOC_QPCQ, butil::memory_order_relaxed);
         if (ep->AllocateResources() < 0) {
             LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:"
                          << s->description();
             rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
         } else {
-            ep->_state = S_BRINGUP_QP;
+            ep->_state.store(S_BRINGUP_QP, butil::memory_order_relaxed);
             if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) {
                 LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
                              << s->description();
@@ -639,18 +646,18 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
         }
     }
 
-    ep->_state = S_HELLO_SEND;
+    ep->_state.store(S_HELLO_SEND, butil::memory_order_relaxed);
     if (handshake->SendLocalHello() < 0) {
         int saved_errno = errno;
         PLOG(WARNING) << "Fail to send Hello Message to client:"
                       << s->description();
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
 
-    ep->_state = S_ACK_WAIT;
+    ep->_state.store(S_ACK_WAIT, butil::memory_order_relaxed);
     uint32_t flags_be = 0;
     if (ep->ReadFromFd(&flags_be, HELLO_ACK_LEN) < 0) {
         int saved_errno = errno;
@@ -658,12 +665,11 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
                       << s->description();
         s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: 
%s",
                      s->description().c_str(), berror(saved_errno));
-        ep->_state = FAILED;
+        ep->_state.store(FAILED, butil::memory_order_relaxed);
         return NULL;
     }
     uint32_t flags = butil::NetToHost32(flags_be);
     bool client_ack_ok = (flags & HELLO_ACK_RDMA_OK) != 0;
-
     if (client_ack_ok) {
         if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) {
             // Client asked for RDMA but we are falling back: protocol
@@ -673,17 +679,17 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
                          << "RDMA_OFF state: " << s->description();
             s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s",
                          s->description().c_str(), berror(EPROTO));
-            ep->_state = FAILED;
+            ep->_state.store(FAILED, butil::memory_order_relaxed);
             return NULL;
         }
         rdma_transport->_rdma_state = RdmaTransport::RDMA_ON;
-        ep->_state = ESTABLISHED;
+        ep->_state.store(ESTABLISHED, butil::memory_order_release);
         LOG_IF(INFO, FLAGS_rdma_trace_verbose)
             << "Server handshake ends (use rdma v" << ep->_handshake_version
             << ") on " << s->description();
     } else {
         rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
-        ep->_state = FALLBACK_TCP;
+        ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
         LOG_IF(INFO, FLAGS_rdma_trace_verbose)
             << "Server handshake ends (use tcp) on " << s->description();
     }
@@ -712,7 +718,8 @@ private:
     // blocks or first max_len bytes.
     // Return: the bytes included in the sglist, or -1 if failed
     ssize_t cut_into_sglist_and_iobuf(ibv_sge* sglist, size_t* sge_index,
-            butil::IOBuf* to, size_t max_sge, size_t max_len) {
+                                      butil::IOBuf* to, size_t max_sge,
+                                      size_t max_len) {
         size_t len = 0;
         while (*sge_index < max_sge) {
             if (len == max_len || _ref_num() == 0) {
@@ -967,7 +974,7 @@ ssize_t RdmaEndpoint::HandleCompletion(ibv_wc& wc) {
             if (wc.byte_len < (uint32_t)FLAGS_rdma_zerocopy_min_size) {
                 zerocopy = false;
             }
-            CHECK(_state != FALLBACK_TCP);
+            CHECK_NE(_state.load(butil::memory_order_acquire), FALLBACK_TCP);
             if (zerocopy) {
                 _rbuf[_rq_received].cutn(&_socket->_read_buf, wc.byte_len);
             } else {
@@ -1586,7 +1593,7 @@ void RdmaEndpoint::PollCq(Socket* m) {
 }
 
 std::string RdmaEndpoint::GetStateStr() const {
-    switch (_state) {
+    switch (_state.load(butil::memory_order_relaxed)) {
     case UNINIT: return "UNINIT";
     case C_ALLOC_QPCQ: return "C_ALLOC_QPCQ";
     case C_HELLO_SEND: return "C_HELLO_SEND";
diff --git a/src/brpc/rdma/rdma_endpoint.h b/src/brpc/rdma/rdma_endpoint.h
index 7b6652bc..41c33824 100644
--- a/src/brpc/rdma/rdma_endpoint.h
+++ b/src/brpc/rdma/rdma_endpoint.h
@@ -250,7 +250,7 @@ private:
     std::string GetStateStr() const;
 
     // Try to read data on TCP fd in _socket
-    inline void TryReadOnTcp();
+    void TryReadOnTcp();
 
     // Add cq socket id to poller
     void PollerAddCqSid();
@@ -262,7 +262,7 @@ private:
     Socket* _socket;
 
     // State of Handshake
-    State _state;
+    butil::atomic<State> _state;
 
     // Wire-level handshake protocol version (set by dispatch in
     // ProcessHandshakeAtClient/Server). Aligned with the protocol code:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to