This is an automated email from the ASF dual-hosted git repository.
guangmingchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git
The following commit(s) were added to refs/heads/master by this push:
new ba9e8388 Fix thread safety of Wrapper (#2952)
ba9e8388 is described below
commit ba9e8388b670f7df0eae7a9bb51da5fac08435e2
Author: Bright Chen <[email protected]>
AuthorDate: Mon Apr 21 09:52:24 2025 +0800
Fix thread safety of Wrapper (#2952)
---
src/brpc/policy/http_rpc_protocol.cpp | 6 +-
src/brpc/server.cpp | 2 +-
src/butil/containers/doubly_buffered_data.h | 152 ++++++++++++++--------------
src/json2pb/pb_to_json.cpp | 4 +-
4 files changed, 78 insertions(+), 86 deletions(-)
diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp
index e0ff1e31..007bce39 100644
--- a/src/brpc/policy/http_rpc_protocol.cpp
+++ b/src/brpc/policy/http_rpc_protocol.cpp
@@ -331,11 +331,7 @@ static bool ProtoMessageToProtoJson(const google::protobuf::Message& message,
butil::IOBufAsZeroCopyOutputStream* wrapper,
Controller* cntl, int error_code) {
json2pb::Pb2ProtoJsonOptions options;
-#if GOOGLE_PROTOBUF_VERSION >= 5026002
- options.always_print_fields_with_no_presence = cntl->has_always_print_primitive_fields();
-#else
- options.always_print_primitive_fields = cntl->has_always_print_primitive_fields();
-#endif
+ AlwaysPrintPrimitiveFields(options) = cntl->has_always_print_primitive_fields();
options.always_print_enums_as_ints = FLAGS_pb_enum_as_number;
std::string error;
bool ok = json2pb::ProtoMessageToProtoJson(message, wrapper, options, &error);
diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp
index b6c3d3ee..aa55c858 100644
--- a/src/brpc/server.cpp
+++ b/src/brpc/server.cpp
@@ -1972,7 +1972,7 @@ bool IsDummyServerRunning() {
}
const Server::MethodProperty*
-Server::FindMethodPropertyByFullName(const butil::StringPiece&fullname) const {
+Server::FindMethodPropertyByFullName(const butil::StringPiece& fullname) const {
return _method_map.seek(fullname);
}
diff --git a/src/butil/containers/doubly_buffered_data.h b/src/butil/containers/doubly_buffered_data.h
index 5aacece3..ff96a903 100644
--- a/src/butil/containers/doubly_buffered_data.h
+++ b/src/butil/containers/doubly_buffered_data.h
@@ -21,7 +21,8 @@
#define BUTIL_DOUBLY_BUFFERED_DATA_H
#include <deque>
-#include <vector> // std::vector
+#include <vector>
+#include <memory>
#include <pthread.h>
#include "butil/scoped_lock.h"
#include "butil/thread_local.h"
@@ -87,6 +88,8 @@ class DoublyBufferedData {
class Wrapper;
class WrapperTLSGroup;
typedef int WrapperTLSId;
+ typedef std::shared_ptr<Wrapper> WrapperSharedPtr;
+ typedef std::weak_ptr<Wrapper> WrapperWeakPtr;
public:
class ScopedPtr {
friend class DoublyBufferedData;
@@ -111,7 +114,7 @@ public:
const T* _data;
// Index of foreground instance used by ScopedPtr.
int _index;
- Wrapper* _w;
+ WrapperSharedPtr _w;
};
DoublyBufferedData();
@@ -152,8 +155,7 @@ private:
return _data + index;
}
- Wrapper* AddWrapper(Wrapper*);
- void RemoveWrapper(Wrapper*);
+ WrapperSharedPtr GetWrapper();
// Foreground and background void.
T _data[2];
@@ -165,7 +167,7 @@ private:
WrapperTLSId _wrapper_key;
// All thread-local instances.
- std::vector<Wrapper*> _wrappers;
+ std::vector<WrapperWeakPtr> _wrappers;
// Sequence access to _wrappers.
pthread_mutex_t _wrappers_mutex{};
@@ -195,18 +197,22 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup {
public:
const static size_t RAW_BLOCK_SIZE = 4096;
const static size_t ELEMENTS_PER_BLOCK =
- RAW_BLOCK_SIZE / sizeof(Wrapper) > 0 ? RAW_BLOCK_SIZE / sizeof(Wrapper) : 1;
+ RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) > 0 ?
+ RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) : 1;
struct BAIDU_CACHELINE_ALIGNMENT ThreadBlock {
- inline DoublyBufferedData::Wrapper* at(size_t offset) {
- return _data + offset;
+ WrapperSharedPtr at(size_t offset) {
+ if (NULL == _data[offset]) {
+ _data[offset] = std::make_shared<Wrapper>();
+ }
+ return _data[offset];
};
private:
- DoublyBufferedData::Wrapper _data[ELEMENTS_PER_BLOCK];
+ WrapperSharedPtr _data[ELEMENTS_PER_BLOCK];
};
- inline static WrapperTLSId key_create() {
+ static WrapperTLSId key_create() {
BAIDU_SCOPED_LOCK(_s_mutex);
WrapperTLSId id = 0;
if (!_get_free_ids().empty()) {
@@ -218,7 +224,7 @@ public:
return id;
}
- inline static int key_delete(WrapperTLSId id) {
+ static int key_delete(WrapperTLSId id) {
BAIDU_SCOPED_LOCK(_s_mutex);
if (id < 0 || id >= _s_id) {
errno = EINVAL;
@@ -228,17 +234,13 @@ public:
return 0;
}
- inline static DoublyBufferedData::Wrapper* get_or_create_tls_data(WrapperTLSId id) {
+ static WrapperSharedPtr get_or_create_tls_data(WrapperTLSId id) {
if (BAIDU_UNLIKELY(id < 0)) {
CHECK(false) << "Invalid id=" << id;
return NULL;
}
if (_s_tls_blocks == NULL) {
- _s_tls_blocks = new (std::nothrow) std::vector<ThreadBlock*>;
- if (BAIDU_UNLIKELY(_s_tls_blocks == NULL)) {
- LOG(FATAL) << "Fail to create vector, " << berror();
- return NULL;
- }
+ _s_tls_blocks = new std::vector<ThreadBlock*>;
butil::thread_atexit(_destroy_tls_blocks);
}
const size_t block_id = (size_t)id / ELEMENTS_PER_BLOCK;
@@ -248,12 +250,8 @@ public:
}
ThreadBlock* tb = (*_s_tls_blocks)[block_id];
if (tb == NULL) {
- ThreadBlock* new_block = new (std::nothrow) ThreadBlock;
- if (BAIDU_UNLIKELY(new_block == NULL)) {
- return NULL;
- }
- tb = new_block;
- (*_s_tls_blocks)[block_id] = new_block;
+ tb = new ThreadBlock;
+ (*_s_tls_blocks)[block_id] = tb;
}
return tb->at(id - block_id * ELEMENTS_PER_BLOCK);
}
@@ -316,10 +314,6 @@ public:
}
~Wrapper() {
- if (_control != NULL) {
- _control->RemoveWrapper(this);
- }
-
if (AllowBthreadSuspended) {
WaitReadDone(0);
WaitReadDone(1);
@@ -406,9 +400,9 @@ private:
// Called when thread initializes thread-local wrapper.
template <typename T, typename TLS, bool AllowBthreadSuspended>
-typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper*
-DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper(
- typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) {
+typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperSharedPtr
+DoublyBufferedData<T, TLS, AllowBthreadSuspended>::GetWrapper() {
+ WrapperSharedPtr w = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key);
if (NULL == w) {
return NULL;
}
@@ -423,29 +417,19 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper(
w->_control = this;
BAIDU_SCOPED_LOCK(_wrappers_mutex);
_wrappers.push_back(w);
+ // The chance to remove expired weak_ptr.
+ _wrappers.erase(
+ std::remove_if(_wrappers.begin(), _wrappers.end(),
+ [](const WrapperWeakPtr& w) {
+ return w.expired();
+ }),
+ _wrappers.end());
} catch (std::exception& e) {
return NULL;
}
return w;
}
-// Called when thread quits.
-template <typename T, typename TLS, bool AllowBthreadSuspended>
-void DoublyBufferedData<T, TLS, AllowBthreadSuspended>::RemoveWrapper(
- typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) {
- if (NULL == w) {
- return;
- }
- BAIDU_SCOPED_LOCK(_wrappers_mutex);
- for (size_t i = 0; i < _wrappers.size(); ++i) {
- if (_wrappers[i] == w) {
- _wrappers[i] = _wrappers.back();
- _wrappers.pop_back();
- return;
- }
- }
-}
-
template <typename T, typename TLS, bool AllowBthreadSuspended>
DoublyBufferedData<T, TLS, AllowBthreadSuspended>::DoublyBufferedData()
: _index(0)
@@ -474,7 +458,10 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() {
{
BAIDU_SCOPED_LOCK(_wrappers_mutex);
for (size_t i = 0; i < _wrappers.size(); ++i) {
- _wrappers[i]->_control = NULL; // hack: disable removal.
+ WrapperSharedPtr w = _wrappers[i].lock();
+ if (NULL != w) {
+ w->_control = NULL; // hack: disable removal.
+ }
}
_wrappers.clear();
}
@@ -487,29 +474,28 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() {
template <typename T, typename TLS, bool AllowBthreadSuspended>
int DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Read(
typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::ScopedPtr* ptr) {
- Wrapper* p = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key);
- Wrapper* w = AddWrapper(p);
- if (BAIDU_LIKELY(w != NULL)) {
- if (AllowBthreadSuspended) {
- // Use reference count instead of mutex to indicate read of
- // foreground instance, so during the read process, there is
- // no need to lock mutex and bthread is allowed to be suspended.
- w->BeginRead();
- int index = -1;
- ptr->_data = UnsafeRead(index);
- ptr->_index = index;
- w->AddRef(index);
- ptr->_w = w;
- w->BeginReadRelease();
- } else {
- w->BeginRead();
- ptr->_data = UnsafeRead();
- ptr->_w = w;
- }
+ WrapperSharedPtr w = GetWrapper();
+ if (BAIDU_UNLIKELY(w == NULL)) {
+ return -1;
+ }
- return 0;
+ if (AllowBthreadSuspended) {
+ // Use reference count instead of mutex to indicate read of
+ // foreground instance, so during the read process, there is
+ // no need to lock mutex and bthread is allowed to be suspended.
+ w->BeginRead();
+ int index = -1;
+ ptr->_data = UnsafeRead(index);
+ ptr->_index = index;
+ w->AddRef(index);
+ ptr->_w = w;
+ w->BeginReadRelease();
+ } else {
+ w->BeginRead();
+ ptr->_data = UnsafeRead();
+ ptr->_w = w;
}
- return -1;
+ return 0;
}
template <typename T, typename TLS, bool AllowBthreadSuspended>
@@ -530,7 +516,7 @@ template <typename Fn, typename... Args>
size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&&... args) {
// _modify_mutex sequences modifications. Using a separate mutex rather
// than _wrappers_mutex is to avoid blocking threads calling
- // AddWrapper() or RemoveWrapper() too long. Most of the time, modifications
+ // GetWrapper() too long. Most of the time, modifications
// are done by one thread, contention should be negligible.
BAIDU_SCOPED_LOCK(_modify_mutex);
int bg_index = !_index.load(butil::memory_order_relaxed);
@@ -552,14 +538,24 @@ size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&&
// read, they should see updated _index.
{
BAIDU_SCOPED_LOCK(_wrappers_mutex);
- for (size_t i = 0; i < _wrappers.size(); ++i) {
- // Wait read of old foreground instance done.
- if (AllowBthreadSuspended) {
- _wrappers[i]->WaitReadDone(bg_index);
- } else {
- _wrappers[i]->WaitReadDone();
- }
- }
+ // The chance to remove expired weak_ptr.
+ _wrappers.erase(
+ std::remove_if(_wrappers.begin(), _wrappers.end(),
+ [bg_index](const WrapperWeakPtr& weak) {
+ WrapperSharedPtr w = weak.lock();
+ bool expired = NULL == w;
+ if (!expired) {
+ // Notify all threads waiting for read done.
+ if (AllowBthreadSuspended) {
+ w->WaitReadDone(bg_index);
+ } else {
+ w->WaitReadDone();
+ }
+ }
+ // Remove expired weak_ptr.
+ return expired;
+ }),
+ _wrappers.end());
}
const size_t ret2 = fn(_data[bg_index], std::forward<Args>(args)...);
diff --git a/src/json2pb/pb_to_json.cpp b/src/json2pb/pb_to_json.cpp
index c23ccdf7..e37cc87d 100644
--- a/src/json2pb/pb_to_json.cpp
+++ b/src/json2pb/pb_to_json.cpp
@@ -336,14 +336,14 @@ bool ProtoMessageToJson(const google::protobuf::Message& message,
}
bool ProtoMessageToJson(const google::protobuf::Message& message,
- google::protobuf::io::ZeroCopyOutputStream *stream,
+ google::protobuf::io::ZeroCopyOutputStream* stream,
const Pb2JsonOptions& options, std::string* error) {
json2pb::ZeroCopyStreamWriter wrapper(stream);
return json2pb::ProtoMessageToJsonStream(message, options, wrapper, error);
}
bool ProtoMessageToJson(const google::protobuf::Message& message,
- google::protobuf::io::ZeroCopyOutputStream *stream,
+ google::protobuf::io::ZeroCopyOutputStream* stream,
std::string* error) {
return ProtoMessageToJson(message, stream, Pb2JsonOptions(), error);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]