This is an automated email from the ASF dual-hosted git repository.

wwbmmm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git


The following commit(s) were added to refs/heads/master by this push:
     new cbbb8236 support bthread primitive cross different worker pools (#2551)
cbbb8236 is described below

commit cbbb8236a91dee7e8cadcf9a7c5f807b353e703d
Author: Yang,Liming <[email protected]>
AuthorDate: Mon Apr 1 10:34:04 2024 +0800

    Support bthread primitives across different worker pools (#2551)
---
 src/bthread/bthread.cpp                   |   4 +-
 src/bthread/bthread.h                     |   2 +-
 src/bthread/butex.cpp                     |  53 +++++----
 src/bthread/task_control.h                |   2 +-
 src/bthread/task_group.cpp                |   4 +-
 src/bthread/task_group.h                  |   2 +-
 test/BUILD.bazel                          |   1 +
 test/bthread_butex_multi_tag_unittest.cpp | 171 ++++++++++++++++++++++++++++++
 8 files changed, 212 insertions(+), 27 deletions(-)

diff --git a/src/bthread/bthread.cpp b/src/bthread/bthread.cpp
index 07233e60..bd8c3efd 100644
--- a/src/bthread/bthread.cpp
+++ b/src/bthread/bthread.cpp
@@ -269,8 +269,8 @@ void bthread_flush() {
     }
 }
 
-int bthread_interrupt(bthread_t tid) {
-    return bthread::TaskGroup::interrupt(tid, bthread::get_task_control());
+int bthread_interrupt(bthread_t tid, bthread_tag_t tag) {
+    return bthread::TaskGroup::interrupt(tid, bthread::get_task_control(), tag);
 }
 
 int bthread_stop(bthread_t tid) {
diff --git a/src/bthread/bthread.h b/src/bthread/bthread.h
index a4c05867..68734e05 100644
--- a/src/bthread/bthread.h
+++ b/src/bthread/bthread.h
@@ -73,7 +73,9 @@ extern int bthread_start_background(bthread_t* __restrict tid,
 // bthread_interrupt() guarantees that Thread2 is woken up reliably no matter
 // how the 2 threads are interleaved.
 // Returns 0 on success, errno otherwise.
-extern int bthread_interrupt(bthread_t tid);
+// |tag| selects the worker pool whose TaskGroup receives |tid| if it has to
+// be requeued; pass the tag that |tid| was started with.
+extern int bthread_interrupt(bthread_t tid, bthread_tag_t tag = BTHREAD_TAG_DEFAULT);
 
 // Make bthread_stopped() on the bthread return true and interrupt the bthread.
 // Note that current bthread_stop() solely sets the built-in "stop flag" and
diff --git a/src/bthread/butex.cpp b/src/bthread/butex.cpp
index 5ac44e1b..1dbd8930 100644
--- a/src/bthread/butex.cpp
+++ b/src/bthread/butex.cpp
@@ -22,6 +22,7 @@
 #include "butil/atomicops.h"                // butil::atomic
 #include "butil/scoped_lock.h"              // BAIDU_SCOPED_LOCK
 #include "butil/macros.h"
+#include "butil/containers/flat_map.h"
 #include "butil/containers/linked_list.h"   // LinkNode
 #ifdef SHOW_BTHREAD_BUTEX_WAITER_COUNT_IN_VARS
 #include "butil/memory/singleton_on_pthread_once.h"
@@ -101,6 +102,7 @@ struct ButexBthreadWaiter : public ButexWaiter {
     Butex* initial_butex;
     TaskControl* control;
     const timespec* abstime;
+    bthread_tag_t tag;
 };
 
 // pthread_task or main_task allocates this structure on stack and queue it
@@ -272,17 +274,27 @@ void butex_destroy(void* butex) {
     butil::return_object(b);
 }
 
-inline TaskGroup* get_task_group(TaskControl* c, bool nosignal = false) {
-    TaskGroup* g = tls_task_group;
+// Returns true if |g| is non-NULL and belongs to the worker pool |tag|.
+inline bool is_same_tag(TaskGroup* g, bthread_tag_t tag) {
+    return g && g->tag() == tag;
+}
+
+inline TaskGroup* get_task_group(TaskControl* c, bthread_tag_t tag, bool nosignal = false) {
+    auto g = is_same_tag(tls_task_group, tag) ? tls_task_group : NULL;
     if (nosignal) {
-        if (NULL == tls_task_group_nosignal) {
-            g = g ? g : c->choose_one_group();
+        if (!is_same_tag(tls_task_group_nosignal, tag)) {
+            if (tls_task_group_nosignal) {
+                // The cached group belongs to another tag: flush the tasks
+                // it holds without signal before replacing it.
+                tls_task_group_nosignal->flush_nosignal_tasks_general();
+            }
+            g = g ? g : c->choose_one_group(tag);
             tls_task_group_nosignal = g;
         } else {
             g = tls_task_group_nosignal;
         }
     } else {
-        g = g ? g : c->choose_one_group();
+        g = g ? g : c->choose_one_group(tag);
     }
     return g;
 }
@@ -313,7 +325,7 @@ int butex_wake(void* arg, bool nosignal) {
     }
     ButexBthreadWaiter* bbw = static_cast<ButexBthreadWaiter*>(front);
     unsleep_if_necessary(bbw, get_global_timer_thread());
-    TaskGroup* g = get_task_group(bbw->control, nosignal);
+    TaskGroup* g = get_task_group(bbw->control, bbw->tag, nosignal);
     if (g == tls_task_group) {
         run_in_local_task_group(g, bbw->tid, nosignal);
     } else {
@@ -352,26 +364,36 @@ int butex_wake_all(void* arg, bool nosignal) {
     if (bthread_waiters.empty()) {
         return nwakeup;
     }
+    // TaskGroups that received tasks in this call, keyed by tag.
+    butil::FlatMap<bthread_tag_t, TaskGroup*> nwakeups;
+    nwakeups.init(FLAGS_task_group_ntags);
     // We will exchange with first waiter in the end.
     ButexBthreadWaiter* next = static_cast<ButexBthreadWaiter*>(
         bthread_waiters.head()->value());
     next->RemoveFromList();
     unsleep_if_necessary(next, get_global_timer_thread());
     ++nwakeup;
-    TaskGroup* g = get_task_group(next->control, nosignal);
-    const int saved_nwakeup = nwakeup;
     while (!bthread_waiters.empty()) {
         // pop reversely
         ButexBthreadWaiter* w = static_cast<ButexBthreadWaiter*>(
             bthread_waiters.tail()->value());
         w->RemoveFromList();
         unsleep_if_necessary(w, get_global_timer_thread());
+        // Without |nosignal|, reuse one group per tag within this call so
+        // that every group holding these nosignal pushes is flushed below.
+        TaskGroup** cached = nosignal ? NULL : nwakeups.seek(w->tag);
+        TaskGroup* g = cached ? *cached : get_task_group(w->control, w->tag, nosignal);
         g->ready_to_run_general(w->tid, true);
+        nwakeups[g->tag()] = g;
         ++nwakeup;
     }
-    if (!nosignal && saved_nwakeup != nwakeup) {
-        g->flush_nosignal_tasks_general();
+    if (!nosignal) {
+        for (auto it = nwakeups.begin(); it != nwakeups.end(); ++it) {
+            auto g = it->second;
+            g->flush_nosignal_tasks_general();
+        }
     }
+    auto g = get_task_group(next->control, next->tag, nosignal);
     if (g == tls_task_group) {
         run_in_local_task_group(g, next->tid, nosignal);
     } else {
@@ -422,21 +444,24 @@ int butex_wake_except(void* arg, bthread_t excluded_bthread) {
     if (bthread_waiters.empty()) {
         return nwakeup;
     }
-    ButexBthreadWaiter* front = static_cast<ButexBthreadWaiter*>(
-                bthread_waiters.head()->value());
-
-    TaskGroup* g = get_task_group(front->control);
-    const int saved_nwakeup = nwakeup;
+    // TaskGroups that received tasks in this call, keyed by tag.
+    butil::FlatMap<bthread_tag_t, TaskGroup*> nwakeups;
+    nwakeups.init(FLAGS_task_group_ntags);
     do {
         // pop reversely
-        ButexBthreadWaiter* w = static_cast<ButexBthreadWaiter*>(
-            bthread_waiters.tail()->value());
+        ButexBthreadWaiter* w = static_cast<ButexBthreadWaiter*>(bthread_waiters.tail()->value());
         w->RemoveFromList();
         unsleep_if_necessary(w, get_global_timer_thread());
+        // Reuse one group per tag within this call so that every group
+        // holding these nosignal pushes is flushed below.
+        TaskGroup** cached = nwakeups.seek(w->tag);
+        TaskGroup* g = cached ? *cached : get_task_group(w->control, w->tag);
         g->ready_to_run_general(w->tid, true);
+        nwakeups[g->tag()] = g;
         ++nwakeup;
     } while (!bthread_waiters.empty());
-    if (saved_nwakeup != nwakeup) {
+    for (auto it = nwakeups.begin(); it != nwakeups.end(); ++it) {
+        auto g = it->second;
         g->flush_nosignal_tasks_general();
     }
     return nwakeup;
@@ -473,11 +498,11 @@ int butex_requeue(void* arg, void* arg2) {
     }
     ButexBthreadWaiter* bbw = static_cast<ButexBthreadWaiter*>(front);
     unsleep_if_necessary(bbw, get_global_timer_thread());
-    TaskGroup* g = tls_task_group;
+    auto g = is_same_tag(tls_task_group, bbw->tag) ? tls_task_group : NULL;
     if (g) {
         TaskGroup::exchange(&g, front->tid);
     } else {
-        bbw->control->choose_one_group()->ready_to_run_remote(front->tid);
+        bbw->control->choose_one_group(bbw->tag)->ready_to_run_remote(front->tid);
     }
     return 1;
 }
@@ -515,7 +540,7 @@ inline bool erase_from_butex(ButexWaiter* bw, bool wakeup, WaiterState state) {
     if (erased && wakeup) {
         if (bw->tid) {
             ButexBthreadWaiter* bbw = static_cast<ButexBthreadWaiter*>(bw);
-            get_task_group(bbw->control)->ready_to_run_general(bw->tid);
+            get_task_group(bbw->control, bbw->tag)->ready_to_run_general(bw->tid);
         } else {
             ButexPthreadWaiter* pw = static_cast<ButexPthreadWaiter*>(bw);
             wakeup_pthread(pw);
@@ -658,6 +683,7 @@ int butex_wait(void* arg, int expected_value, const timespec* abstime) {
     bbw.initial_butex = b;
     bbw.control = g->control();
     bbw.abstime = abstime;
+    bbw.tag = g->tag();
 
     if (abstime != NULL) {
         // Schedule timer before queueing. If the timer is triggered before
diff --git a/src/bthread/task_control.h b/src/bthread/task_control.h
index a19636ac..12598079 100644
--- a/src/bthread/task_control.h
+++ b/src/bthread/task_control.h
@@ -84,7 +84,7 @@ public:
 
-    // Choose one TaskGroup (randomly right now).
+    // Choose one TaskGroup of worker pool |tag| (randomly right now).
     // If this method is called after init(), it never returns NULL.
-    TaskGroup* choose_one_group(bthread_tag_t tag = BTHREAD_TAG_DEFAULT);
+    TaskGroup* choose_one_group(bthread_tag_t tag);
 
 private:
     typedef std::array<TaskGroup*, BTHREAD_MAX_CONCURRENCY> TaggedGroups;
diff --git a/src/bthread/task_group.cpp b/src/bthread/task_group.cpp
index a675126d..e3d4b60a 100644
--- a/src/bthread/task_group.cpp
+++ b/src/bthread/task_group.cpp
@@ -878,7 +878,9 @@ static int set_butex_waiter(bthread_t tid, ButexWaiter* w) {
 // by race conditions.
 // TODO: bthreads created by BTHREAD_ATTR_PTHREAD blocking on bthread_usleep()
 // can't be interrupted.
-int TaskGroup::interrupt(bthread_t tid, TaskControl* c) {
+// |tag| selects the worker pool whose TaskGroup receives |tid| when it is
+// requeued through ready_to_run_remote() below.
+int TaskGroup::interrupt(bthread_t tid, TaskControl* c, bthread_tag_t tag) {
     // Consume current_waiter in the TaskMeta, wake it up then set it back.
     ButexWaiter* w = NULL;
     uint64_t sleep_id = 0;
@@ -906,7 +908,7 @@ int TaskGroup::interrupt(bthread_t tid, TaskControl* c) {
                 if (!c) {
                     return EINVAL;
                 }
-                c->choose_one_group()->ready_to_run_remote(tid);
+                c->choose_one_group(tag)->ready_to_run_remote(tid);
             }
         }
     }
diff --git a/src/bthread/task_group.h b/src/bthread/task_group.h
index d8598678..b71994a0 100644
--- a/src/bthread/task_group.h
+++ b/src/bthread/task_group.h
@@ -173,7 +173,9 @@ public:
 
     // Wake up blocking ops in the thread.
+    // |tag| selects the worker pool whose TaskGroup receives |tid| when it
+    // has to be requeued.
     // Returns 0 on success, errno otherwise.
-    static int interrupt(bthread_t tid, TaskControl* c);
+    static int interrupt(bthread_t tid, TaskControl* c, bthread_tag_t tag);
 
     // Get the meta associate with the task.
     static TaskMeta* address_meta(bthread_t tid);
diff --git a/test/BUILD.bazel b/test/BUILD.bazel
index 3bf7cd45..345043fc 100644
--- a/test/BUILD.bazel
+++ b/test/BUILD.bazel
@@ -227,6 +227,8 @@ cc_test(
             "bthread_setconcurrency_unittest.cpp",
             # glog CHECK die with a fatal error
             "bthread_key_unittest.cpp",
+            # defines its own main() to set FLAGS_task_group_ntags
+            "bthread_butex_multi_tag_unittest.cpp",
         ],
     ),
     copts = COPTS,
diff --git a/test/bthread_butex_multi_tag_unittest.cpp b/test/bthread_butex_multi_tag_unittest.cpp
new file mode 100644
index 00000000..98057d4a
--- /dev/null
+++ b/test/bthread_butex_multi_tag_unittest.cpp
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <mutex>
+#include <vector>
+#include "bthread/bthread.h"
+#include "bthread/condition_variable.h"
+#include "bthread/countdown_event.h"
+#include "bthread/mutex.h"
+
+DECLARE_int32(task_group_ntags);
+
+int main(int argc, char* argv[]) {
+    FLAGS_task_group_ntags = 3;
+    testing::InitGoogleTest(&argc, argv);
+    GFLAGS_NS::ParseCommandLineFlags(&argc, &argv, true);
+    return RUN_ALL_TESTS();
+}
+
+namespace {
+
+// Tags recorded by butex_wake_func before and after blocking on the mutex.
+// Must start empty: the assertions read the two push_back()ed values.
+std::vector<bthread_tag_t> butex_wake_return;
+
+void* butex_wake_func(void* arg) {
+    auto mutex = static_cast<bthread::Mutex*>(arg);
+    butex_wake_return.push_back(bthread_self_tag());
+    mutex->lock();
+    butex_wake_return.push_back(bthread_self_tag());
+    mutex->unlock();
+    return nullptr;
+}
+
+TEST(BthreadButexMultiTest, butex_wake) {
+    bthread::Mutex mutex;
+    mutex.lock();
+    bthread_t tid1;
+    bthread_attr_t attr = BTHREAD_ATTR_NORMAL;
+    attr.tag = 1;
+    bthread_start_urgent(&tid1, &attr, butex_wake_func, &mutex);
+    mutex.unlock();
+    bthread_join(tid1, nullptr);
+    ASSERT_EQ(2u, butex_wake_return.size());
+    ASSERT_EQ(attr.tag, butex_wake_return[0]);
+    ASSERT_EQ(butex_wake_return[0], butex_wake_return[1]);
+}
+
+// Tags recorded by butex_wake_all_func1/2 before and after waiting; must start empty.
+std::vector<bthread_tag_t> butex_wake_all_return1;
+std::vector<bthread_tag_t> butex_wake_all_return2;
+
+struct ButexWakeAllArgs {
+    bthread::CountdownEvent* ev;
+    bthread::CountdownEvent* ack;
+};
+
+void* butex_wake_all_func1(void* arg) {
+    auto p = static_cast<ButexWakeAllArgs*>(arg);
+    auto ev = p->ev;
+    auto ack = p->ack;
+    butex_wake_all_return1.push_back(bthread_self_tag());
+    ack->signal();
+    ev->wait();
+    butex_wake_all_return1.push_back(bthread_self_tag());
+    return nullptr;
+}
+
+void* butex_wake_all_func2(void* arg) {
+    auto p = static_cast<ButexWakeAllArgs*>(arg);
+    auto ev = p->ev;
+    auto ack = p->ack;
+    butex_wake_all_return2.push_back(bthread_self_tag());
+    ack->signal();
+    ev->wait();
+    butex_wake_all_return2.push_back(bthread_self_tag());
+    return nullptr;
+}
+
+TEST(BthreadButexMultiTest, butex_wake_all) {
+    bthread::CountdownEvent ev(2);
+    bthread::CountdownEvent ack(2);
+    ButexWakeAllArgs args{&ev, &ack};
+    bthread_t tid1, tid2;
+    bthread_attr_t attr1 = BTHREAD_ATTR_NORMAL;
+    attr1.tag = 1;
+    bthread_start_background(&tid1, &attr1, butex_wake_all_func1, &args);
+    bthread_attr_t attr2 = BTHREAD_ATTR_NORMAL;
+    attr2.tag = 2;
+    bthread_start_background(&tid2, &attr2, butex_wake_all_func2, &args);
+    ack.wait();
+    ev.signal(2);
+    bthread_join(tid1, nullptr);
+    bthread_join(tid2, nullptr);
+    ASSERT_EQ(2u, butex_wake_all_return1.size());
+    ASSERT_EQ(2u, butex_wake_all_return2.size());
+    ASSERT_EQ(attr1.tag, butex_wake_all_return1[0]);
+    ASSERT_EQ(attr2.tag, butex_wake_all_return2[0]);
+    ASSERT_EQ(butex_wake_all_return1[0], butex_wake_all_return1[1]);
+    ASSERT_EQ(butex_wake_all_return2[0], butex_wake_all_return2[1]);
+}
+
+// Tags recorded by butex_requeue_func1/2 before and after waiting; must start empty.
+std::vector<bthread_tag_t> butex_requeue_return1;
+std::vector<bthread_tag_t> butex_requeue_return2;
+
+struct ButexRequeueArgs {
+    bthread::Mutex* mutex;
+    bthread::ConditionVariable* cond;
+    bthread::CountdownEvent* ack;
+};
+
+void* butex_requeue_func1(void* arg) {
+    auto p = static_cast<ButexRequeueArgs*>(arg);
+    auto mutex = p->mutex;
+    auto cond = p->cond;
+    auto ack = p->ack;
+    butex_requeue_return1.push_back(bthread_self_tag());
+    std::unique_lock<bthread::Mutex> lk(*mutex);
+    ack->signal();
+    cond->wait(lk);
+    butex_requeue_return1.push_back(bthread_self_tag());
+    return nullptr;
+}
+
+void* butex_requeue_func2(void* arg) {
+    auto p = static_cast<ButexRequeueArgs*>(arg);
+    auto mutex = p->mutex;
+    auto cond = p->cond;
+    auto ack = p->ack;
+    butex_requeue_return2.push_back(bthread_self_tag());
+    std::unique_lock<bthread::Mutex> lk(*mutex);
+    ack->signal();
+    cond->wait(lk);
+    butex_requeue_return2.push_back(bthread_self_tag());
+    return nullptr;
+}
+
+TEST(BthreadButexMultiTest, butex_requeue) {
+    bthread::Mutex mutex;
+    bthread::ConditionVariable cond;
+    bthread::CountdownEvent ack(2);
+    ButexRequeueArgs args{&mutex, &cond, &ack};
+
+    bthread_t tid1, tid2;
+    bthread_attr_t attr1 = BTHREAD_ATTR_NORMAL;
+    attr1.tag = 1;
+    bthread_start_background(&tid1, &attr1, butex_requeue_func1, &args);
+    bthread_attr_t attr2 = BTHREAD_ATTR_NORMAL;
+    attr2.tag = 2;
+    bthread_start_background(&tid2, &attr2, butex_requeue_func2, &args);
+    ack.wait();
+    {
+        std::unique_lock<bthread::Mutex> lk(mutex);
+        cond.notify_all();
+    }
+    {
+        std::unique_lock<bthread::Mutex> lk(mutex);
+        cond.notify_all();
+    }
+    bthread_join(tid1, nullptr);
+    bthread_join(tid2, nullptr);
+    ASSERT_EQ(2u, butex_requeue_return1.size());
+    ASSERT_EQ(2u, butex_requeue_return2.size());
+    ASSERT_EQ(attr1.tag, butex_requeue_return1[0]);
+    ASSERT_EQ(attr2.tag, butex_requeue_return2[0]);
+    ASSERT_EQ(butex_requeue_return1[0], butex_requeue_return1[1]);
+    ASSERT_EQ(butex_requeue_return2[0], butex_requeue_return2[1]);
+}
+
+}  // namespace


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to