This is an automated email from the ASF dual-hosted git repository.

wwbmmm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git


The following commit(s) were added to refs/heads/master by this push:
     new 337142f2 Support on_failed callback for streaming rpc (#2565)
337142f2 is described below

commit 337142f29ef1d6c67a43178f3f26fc6964c21270
Author: Bright Chen <chenguangmin...@foxmail.com>
AuthorDate: Mon Apr 8 11:06:17 2024 +0800

    Support on_failed callback for streaming rpc (#2565)
    
    * Support on_failed callback for streaming rpc
    
    * Call on_failed before on_closed
---
 .gitignore                             |   3 +
 src/brpc/controller.cpp                |   3 +-
 src/brpc/policy/baidu_rpc_protocol.cpp |  10 ++--
 src/brpc/socket.cpp                    |  10 ++--
 src/brpc/socket.h                      |   2 +-
 src/brpc/stream.cpp                    |  50 +++++++++++-----
 src/brpc/stream.h                      |   9 ++-
 src/brpc/stream_impl.h                 |  11 +++-
 test/brpc_streaming_rpc_unittest.cpp   | 101 +++++++++++++++++++++++++++------
 9 files changed, 151 insertions(+), 48 deletions(-)

diff --git a/.gitignore b/.gitignore
index bdd21009..9a60889a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,3 +44,6 @@ CTestTestfile.cmake
 /test/curl.out
 /test/out.txt
 /test/recordio_ref.io
+
+# Ignore protoc-gen-mcpack files
+/protoc-gen-mcpack*/
\ No newline at end of file
diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp
index 825c05cd..42f507fd 100644
--- a/src/brpc/controller.cpp
+++ b/src/brpc/controller.cpp
@@ -1387,7 +1387,8 @@ void Controller::HandleStreamConnection(Socket *host_socket) {
         }
     }
     if (FailedInline()) {
-        Stream::SetFailed(_request_stream);
+        Stream::SetFailed(_request_stream, _error_code,
+                          "%s", _error_text.c_str());
         if (_remote_stream_settings != NULL) {
             policy::SendStreamRst(host_socket,
                                   _remote_stream_settings->stream_id());
diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp
index b19fbb37..6fa17d6c 100644
--- a/src/brpc/policy/baidu_rpc_protocol.cpp
+++ b/src/brpc/policy/baidu_rpc_protocol.cpp
@@ -254,11 +254,13 @@ void SendRpcResponse(int64_t correlation_id,
                            accessor.remote_stream_settings()->stream_id(),
                            accessor.response_stream()) != 0) {
             const int errcode = errno;
-            PLOG_IF(WARNING, errcode != EPIPE) << "Fail to write into " << *sock;
-            cntl->SetFailed(errcode, "Fail to write into %s",
-                            sock->description().c_str());
+            std::string error_text = butil::string_printf(64, "Fail to write into %s",
+                                                          sock->description().c_str());
+            PLOG_IF(WARNING, errcode != EPIPE) << error_text;
+            cntl->SetFailed(errcode,  "%s", error_text.c_str());
             if(stream_ptr) {
-                ((Stream*)stream_ptr->conn())->Close();
+                ((Stream*)stream_ptr->conn())->Close(errcode, "%s",
+                                                     error_text.c_str());
             }
             return;
         }
diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp
index bf06a871..c2e3095d 100644
--- a/src/brpc/socket.cpp
+++ b/src/brpc/socket.cpp
@@ -986,7 +986,7 @@ int Socket::SetFailed(int error_code, const char* error_fmt, ...) {
                          &_id_wait_list, error_code, error_text,
                          &_id_wait_list_mutex));
 
-            ResetAllStreams();
+            ResetAllStreams(error_code, error_text);
             // _app_connect shouldn't be set to NULL in SetFailed otherwise
             // HC is always not supported.
             // FIXME: Design a better interface for AppConnect
@@ -2541,7 +2541,7 @@ int Socket::RemoveStream(StreamId stream_id) {
     return 0;
 }
 
-void Socket::ResetAllStreams() {
+void Socket::ResetAllStreams(int error_code, const std::string& error_text) {
     DCHECK(Failed());
     std::set<StreamId> saved_stream_set;
     _stream_mutex.lock();
@@ -2552,9 +2552,9 @@ void Socket::ResetAllStreams() {
         saved_stream_set.swap(*_stream_set);
     }
     _stream_mutex.unlock();
-    for (std::set<StreamId>::const_iterator 
-            it = saved_stream_set.begin(); it != saved_stream_set.end(); ++it) {
-        Stream::SetFailed(*it);
+    for (auto it = saved_stream_set.begin();
+            it != saved_stream_set.end(); ++it) {
+        Stream::SetFailed(*it, error_code, "%s", error_text.c_str());
     }
 }
 
diff --git a/src/brpc/socket.h b/src/brpc/socket.h
index 9d85aafa..faf6baac 100644
--- a/src/brpc/socket.h
+++ b/src/brpc/socket.h
@@ -706,7 +706,7 @@ friend void DereferenceSocket(Socket*);
     // broken socket.
     int AddStream(StreamId stream_id);
     int RemoveStream(StreamId stream_id);
-    void ResetAllStreams();
+    void ResetAllStreams(int error_code, const std::string& error_text);
 
     bool ValidFileDescriptor(int fd);
 
diff --git a/src/brpc/stream.cpp b/src/brpc/stream.cpp
index 6eb4413d..27f87c6f 100644
--- a/src/brpc/stream.cpp
+++ b/src/brpc/stream.cpp
@@ -44,6 +44,7 @@ Stream::Stream()
     , _fake_socket_weak_ref(NULL)
     , _connected(false)
     , _closed(false)
+    , _error_code(0)
     , _produced(0)
     , _remote_consumed(0)
     , _cur_buf_size(0)
@@ -74,6 +75,7 @@ int Stream::Create(const StreamOptions &options,
     s->_connected = false;
     s->_options = options;
     s->_closed = false;
+    s->_error_code = 0;
     s->_cur_buf_size = options.max_buf_size > 0 ? options.max_buf_size : 0;
     if (options.max_buf_size > 0 && options.min_buf_size > options.max_buf_size) {
         // set 0 if min_buf_size is invalid.
@@ -131,7 +133,7 @@ void Stream::BeforeRecycle(Socket *) {
     if (_host_socket) {
         _host_socket->RemoveStream(id());
     }
-    
+
     // The instance is to be deleted in the consumer thread
     bthread::execution_queue_stop(_consumer_queue);
 }
@@ -466,21 +468,22 @@ int Stream::OnReceived(const StreamFrameMeta& fm, butil::IOBuf *buf, Socket* soc
         if (!fm.has_continuation()) {
             butil::IOBuf *tmp = _pending_buf;
             _pending_buf = NULL;
-            if (bthread::execution_queue_execute(_consumer_queue, tmp) != 0) {
+            int rc = bthread::execution_queue_execute(_consumer_queue, tmp);
+            if (rc != 0) {
                 CHECK(false) << "Fail to push into channel";
                 delete tmp;
-                Close();
+                Close(rc, "Fail to push into channel");
             }
         }
         break;
     case FRAME_TYPE_RST:
         RPC_VLOG << "stream=" << id() << " received rst frame";
-        Close();
+        Close(ECONNRESET, "Received RST frame");
         break;
     case FRAME_TYPE_CLOSE:
         RPC_VLOG << "stream=" << id() << " received close frame";
         // TODO:: See the comments in Consume
-        Close();
+        Close(0, "Received CLOSE frame");
         break;
     case FRAME_TYPE_UNKNOWN:
         RPC_VLOG << "Received unknown frame";
@@ -530,15 +533,26 @@ int Stream::Consume(void *meta, bthread::TaskIterator<butil::IOBuf*>& iter) {
     Stream* s = (Stream*)meta;
     s->StopIdleTimer();
     if (iter.is_queue_stopped()) {
-        // indicating the queue was closed
+        scoped_ptr<Stream> recycled_stream(s);
+        // Indicating the queue was closed.
         if (s->_host_socket) {
             DereferenceSocket(s->_host_socket);
             s->_host_socket = NULL;
         }
         if (s->_options.handler != NULL) {
+            int error_code;
+            std::string error_text;
+            {
+                BAIDU_SCOPED_LOCK(s->_connect_mutex);
+                error_code = s->_error_code;
+                error_text = s->_error_text;
+            }
+            if (error_code != 0) {
+                // The stream is closed abnormally.
+                s->_options.handler->on_failed(s->id(), error_code, error_text);
+            }
             s->_options.handler->on_closed(s->id());
         }
-        delete s;
         return 0;
     }
     DEFINE_SMALL_ARRAY(butil::IOBuf*, buf_list, s->_options.messages_in_batch, 256);
@@ -630,7 +644,7 @@ void Stream::StopIdleTimer() {
     }
 }
 
-void Stream::Close() {
+void Stream::Close(int error_code, const char* reason_fmt, ...) {
     _fake_socket_weak_ref->SetFailed();
     bthread_mutex_lock(&_connect_mutex);
     if (_closed) {
@@ -638,6 +652,13 @@ void Stream::Close() {
         return;
     }
     _closed = true;
+    _error_code = error_code;
+
+    va_list ap;
+    va_start(ap, reason_fmt);
+    butil::string_vappendf(&_error_text, reason_fmt, ap);
+    va_end(ap);
+
     if (_connected) {
         bthread_mutex_unlock(&_connect_mutex);
         return;
@@ -647,14 +668,17 @@ void Stream::Close() {
     return TriggerOnConnectIfNeed();
 }
 
-int Stream::SetFailed(StreamId id) {
+int Stream::SetFailed(StreamId id, int error_code, const char* reason_fmt, ...) {
     SocketUniquePtr ptr;
     if (Socket::AddressFailedAsWell(id, &ptr) == -1) {
         // Don't care recycled stream
         return 0;
     }
     Stream* s = (Stream*)ptr->conn();
-    s->Close();
+    va_list ap;
+    va_start(ap, reason_fmt);
+    s->Close(error_code, reason_fmt, ap);
+    va_end(ap);
     return 0;
 }
 
@@ -665,13 +689,13 @@ void Stream::HandleRpcResponse(butil::IOBuf* response_buffer) {
     ParseResult pr = policy::ParseRpcMessage(response_buffer, NULL, true, NULL);
     if (!pr.is_ok()) {
         CHECK(false);
-        Close();
+        Close(EPROTO, "Fail to parse rpc response message");
         return;
     }
     InputMessageBase* msg = pr.message();
     if (msg == NULL) {
         CHECK(false);
-        Close();
+        Close(ENOMEM, "Message is NULL");
         return;
     }
     _host_socket->PostponeEOF();
@@ -730,7 +754,7 @@ int StreamWait(StreamId stream_id, const timespec* due_time) {
 }
 
 int StreamClose(StreamId stream_id) {
-    return Stream::SetFailed(stream_id);
+    return Stream::SetFailed(stream_id, 0, "Local close");
 }
 
 int StreamCreate(StreamId *request_stream, Controller &cntl,
diff --git a/src/brpc/stream.h b/src/brpc/stream.h
index 90965f37..f222ba09 100644
--- a/src/brpc/stream.h
+++ b/src/brpc/stream.h
@@ -44,7 +44,11 @@ public:
                                      butil::IOBuf *const messages[], 
                                      size_t size) = 0;
     virtual void on_idle_timeout(StreamId id) = 0;
-    virtual void on_closed(StreamId id) = 0; 
+    virtual void on_closed(StreamId id) = 0;
+    // `on_failed` will be called  before `on_closed`
+    // when the stream is closed abnormally.
+    virtual void on_failed(StreamId id, int error_code,
+                           const std::string& error_text) {}
 };
 
 struct StreamOptions {
@@ -82,8 +86,7 @@ struct StreamOptions {
     StreamInputHandler* handler;
 };
 
-struct StreamWriteOptions
-{
+struct StreamWriteOptions {
     StreamWriteOptions() : write_in_background(false) {}
 
     // Write message to socket in background thread.
diff --git a/src/brpc/stream_impl.h b/src/brpc/stream_impl.h
index f24b75a3..db92dd63 100644
--- a/src/brpc/stream_impl.h
+++ b/src/brpc/stream_impl.h
@@ -61,13 +61,16 @@ public:
                     const timespec *due_time);
     int Wait(const timespec* due_time);
     void FillSettings(StreamSettings *settings);
-    static int SetFailed(StreamId id);
-    void Close();
+    static int SetFailed(StreamId id, int error_code, const char* reason_fmt, ...)
+        __attribute__ ((__format__ (__printf__, 3, 4)));
+    void Close(int error_code, const char* reason_fmt, ...)
+        __attribute__ ((__format__ (__printf__, 3, 4)));
 
 private:
 friend void StreamWait(StreamId stream_id, const timespec *due_time,
-                void (*on_writable)(StreamId, void*, int), void *arg);
+                       void (*on_writable)(StreamId, void*, int), void *arg);
 friend class MessageBatcher;
+friend struct butil::DefaultDeleter<Stream>;
     Stream();
     ~Stream();
     int Init(const StreamOptions options);
@@ -111,6 +114,8 @@ friend class MessageBatcher;
     ConnectMeta         _connect_meta;
     bool                _connected;
     bool                _closed;
+    int                 _error_code;
+    std::string         _error_text;
     
     bthread_mutex_t _congestion_control_mutex;
     size_t _produced;
diff --git a/test/brpc_streaming_rpc_unittest.cpp b/test/brpc_streaming_rpc_unittest.cpp
index f7e62c81..df6a37d8 100644
--- a/test/brpc_streaming_rpc_unittest.cpp
+++ b/test/brpc_streaming_rpc_unittest.cpp
@@ -20,11 +20,12 @@
 // Date: 2015/10/22 16:28:44
 
 #include <gtest/gtest.h>
-
 #include "brpc/server.h"
+
 #include "brpc/controller.h"
 #include "brpc/channel.h"
 #include "brpc/stream_impl.h"
+#include "brpc/policy/streaming_rpc_protocol.h"
 #include "echo.pb.h"
 
 class AfterAcceptStream {
@@ -69,10 +70,11 @@ private:
 
 class StreamingRpcTest : public testing::Test {
 protected:
-    test::EchoRequest request;
-    test::EchoResponse response;
     void SetUp() { request.set_message("hello world"); }
     void TearDown() {}
+
+    test::EchoRequest request;
+    test::EchoResponse response;
 };
 
 TEST_F(StreamingRpcTest, sanity) {
@@ -96,7 +98,7 @@ TEST_F(StreamingRpcTest, sanity) {
 }
 
 struct HandlerControl {
-    HandlerControl() 
+    HandlerControl()
         : block(false)
     {}
     bool block;
@@ -110,11 +112,11 @@ public:
         , _stopped(false)
         , _idle_times(0)
         , _cntl(cntl)
-    {
-    }
+    {}
+
     int on_received_messages(brpc::StreamId /*id*/,
                              butil::IOBuf *const messages[],
-                             size_t size) {
+                             size_t size) override {
         if (_cntl && _cntl->block) {
             while (_cntl->block) {
                 usleep(100);
@@ -129,15 +131,22 @@ public:
         return 0;
     }
 
-    void on_idle_timeout(brpc::StreamId /*id*/) {
+    void on_idle_timeout(brpc::StreamId /*id*/) override {
         ++_idle_times;
     }
 
-    void on_closed(brpc::StreamId /*id*/) {
+    void on_closed(brpc::StreamId /*id*/) override {
         ASSERT_FALSE(_stopped);
         _stopped = true;
     }
 
+    void on_failed(brpc::StreamId id, int error_code,
+                   const std::string& /*error_text*/) override {
+        ASSERT_FALSE(_failed);
+        ASSERT_NE(0, error_code);
+        _failed = true;
+    }
+
     bool failed() const { return _failed; }
     bool stopped() const { return _stopped; }
     int idle_times() const { return _idle_times; }
@@ -196,8 +205,8 @@ void on_writable(brpc::StreamId, void* arg, int error_code) {
 
 TEST_F(StreamingRpcTest, block) {
     HandlerControl hc;
-    OrderedInputHandler handler(&hc);
     hc.block = true;
+    OrderedInputHandler handler(&hc);
     brpc::StreamOptions opt;
     opt.handler = &handler;
     const int N = 10000;
@@ -216,7 +225,7 @@ TEST_F(StreamingRpcTest, block) {
     ASSERT_EQ(0, StreamCreate(&request_stream, cntl, &request_stream_options));
     test::EchoService_Stub stub(&channel);
     stub.Echo(&cntl, &request, &response, NULL);
-    ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText() << " request_stream=" 
+    ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText() << " request_stream="
                                 << request_stream;
     for (int i = 0; i < N; ++i) {
         int network = htonl(i);
@@ -295,8 +304,8 @@ TEST_F(StreamingRpcTest, block) {
 
 TEST_F(StreamingRpcTest, auto_close_if_host_socket_closed) {
     HandlerControl hc;
-    OrderedInputHandler handler(&hc);
     hc.block = true;
+    OrderedInputHandler handler(&hc);
     brpc::StreamOptions opt;
     opt.handler = &handler;
     const int N = 10000;
@@ -332,15 +341,63 @@ TEST_F(StreamingRpcTest, auto_close_if_host_socket_closed) {
     while (!handler.stopped()) {
         usleep(100);
     }
-    ASSERT_FALSE(handler.failed());
+    ASSERT_TRUE(handler.failed());
     ASSERT_EQ(0, handler.idle_times());
     ASSERT_EQ(0, handler._expected_next_value);
 }
 
+TEST_F(StreamingRpcTest, failed_when_rst) {
+    OrderedInputHandler handler;
+    brpc::StreamOptions opt;
+    opt.handler = &handler;
+    opt.messages_in_batch = 100;
+    brpc::Server server;
+    MyServiceWithStream service(opt);
+    ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE));
+    ASSERT_EQ(0, server.Start(9007, NULL));
+    brpc::Channel channel;
+    ASSERT_EQ(0, channel.Init("127.0.0.1:9007", NULL));
+    brpc::Controller cntl;
+    brpc::StreamId request_stream;
+    brpc::StreamOptions request_stream_options;
+    request_stream_options.max_buf_size = 0;
+    ASSERT_EQ(0, StreamCreate(&request_stream, cntl, &request_stream_options));
+    brpc::ScopedStream stream_guard(request_stream);
+    test::EchoService_Stub stub(&channel);
+    stub.Echo(&cntl, &request, &response, NULL);
+    ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText() << " request_stream=" << request_stream;
+    const int N = 10000;
+    for (int i = 0; i < N; ++i) {
+        int network = htonl(i);
+        butil::IOBuf out;
+        out.append(&network, sizeof(network));
+        ASSERT_EQ(0, brpc::StreamWrite(request_stream, out)) << "i=" << i;
+    }
+
+    usleep(1000 * 10);
+    {
+        brpc::SocketUniquePtr ptr;
+        ASSERT_EQ(0, brpc::Socket::Address(request_stream, &ptr));
+        brpc::Stream* s = (brpc::Stream*)ptr->conn();
+        ASSERT_TRUE(s->_host_socket != NULL);
+        brpc::policy::SendStreamRst(s->_host_socket,
+                                    s->_remote_settings.stream_id());
+    }
+    // ASSERT_EQ(0, brpc::StreamClose(request_stream));
+    server.Stop(0);
+    server.Join();
+    while (!handler.stopped() && !handler.failed()) {
+        usleep(100);
+    }
+    ASSERT_TRUE(handler.failed());
+    ASSERT_EQ(0, handler.idle_times());
+    ASSERT_EQ(N, handler._expected_next_value);
+}
+
 TEST_F(StreamingRpcTest, idle_timeout) {
     HandlerControl hc;
-    OrderedInputHandler handler(&hc);
     hc.block = true;
+    OrderedInputHandler handler(&hc);
     brpc::StreamOptions opt;
     opt.handler = &handler;
     opt.idle_timeout_ms = 2;
@@ -383,7 +440,7 @@ public:
     }
     int on_received_messages(brpc::StreamId id,
                              butil::IOBuf *const messages[],
-                             size_t size) {
+                             size_t size) override {
         if (size != 1) {
             _failed = true;
             return 0;
@@ -406,15 +463,23 @@ public:
         return 0;
     }
 
-    void on_idle_timeout(brpc::StreamId /*id*/) {
+    void on_idle_timeout(brpc::StreamId /*id*/) override {
         ++_idle_times;
     }
 
-    void on_closed(brpc::StreamId /*id*/) {
+    void on_closed(brpc::StreamId /*id*/) override {
         ASSERT_FALSE(_stopped);
         _stopped = true;
     }
 
+
+    void on_failed(brpc::StreamId id, int error_code,
+                   const std::string& /*error_text*/) override {
+        ASSERT_FALSE(_failed);
+        ASSERT_NE(0, error_code);
+        _failed = true;
+    }
+
     bool failed() const { return _failed; }
     bool stopped() const { return _stopped; }
     int idle_times() const { return _idle_times; }
@@ -493,9 +558,9 @@ TEST_F(StreamingRpcTest, server_send_data_before_run_done) {
     ASSERT_EQ(0, channel.Init("127.0.0.1:9007", NULL));
     OrderedInputHandler handler;
     brpc::StreamOptions request_stream_options;
+    request_stream_options.handler = &handler;
     brpc::StreamId request_stream;
     brpc::Controller cntl;
-    request_stream_options.handler = &handler;
     ASSERT_EQ(0, StreamCreate(&request_stream, cntl, &request_stream_options));
     brpc::ScopedStream stream_guard(request_stream);
     test::EchoService_Stub stub(&channel);


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@brpc.apache.org
For additional commands, e-mail: dev-h...@brpc.apache.org

Reply via email to