This is an automated email from the ASF dual-hosted git repository.

dinglei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new d857a13  feat: Fix concurrent subscriptions (#1222)
d857a13 is described below

commit d857a13593072b8ea54b685bbb5e865be36f44c5
Author: wenxuwan <[email protected]>
AuthorDate: Fri Sep 19 09:42:10 2025 +0800

    feat: Fix concurrent subscriptions (#1222)
    
    * separate interface and implementation
    
    * fix panic when close tracedispatcher
    
    * Restore rlog/log.go
    
    * Delete default.go
    
    * fix consumer panic
    
    * change `any` to `interface{}`
    
    * Optimize UnCompress
    
    ---------
    
    Co-authored-by: 文徐 <[email protected]>
---
 consumer/push_consumer.go                |  48 +++++++++------
 internal/utils/compression.go            |   9 ++-
 internal/utils/compression_bench_test.go | 101 +++++++++++++++++++++++++++++++
 internal/utils/set.go                    |  14 +++++
 4 files changed, 150 insertions(+), 22 deletions(-)

diff --git a/consumer/push_consumer.go b/consumer/push_consumer.go
index f6864d5..5d87b51 100644
--- a/consumer/push_consumer.go
+++ b/consumer/push_consumer.go
@@ -67,12 +67,12 @@ type pushConsumer struct {
        queueMaxSpanFlowControlTimes int
        consumeFunc                  utils.Set
        submitToConsume              func(*processQueue, 
*primitive.MessageQueue)
-       subscribedTopic              map[string]string
+       subscribedTopic              sync.Map
        interceptor                  primitive.Interceptor
        queueLock                    *QueueLock
        done                         chan struct{}
        closeOnce                    sync.Once
-       crCh                         map[string]chan struct{}
+       crCh                         sync.Map
 }
 
 func NewPushConsumer(opts ...Option) (*pushConsumer, error) {
@@ -113,11 +113,9 @@ func NewPushConsumer(opts ...Option) (*pushConsumer, 
error) {
 
        p := &pushConsumer{
                defaultConsumer: dc,
-               subscribedTopic: make(map[string]string, 0),
                queueLock:       newQueueLock(),
                done:            make(chan struct{}, 1),
                consumeFunc:     utils.NewSet(),
-               crCh:            make(map[string]chan struct{}),
        }
        dc.mqChanged = p.messageQueueChanged
        if p.consumeOrderly {
@@ -165,7 +163,7 @@ func (pc *pushConsumer) Start() error {
                }
 
                retryTopic := internal.GetRetryTopic(pc.consumerGroup)
-               pc.crCh[retryTopic] = make(chan struct{}, 
pc.defaultConsumer.option.ConsumeGoroutineNums)
+               pc.crCh.Store(retryTopic, make(chan struct{}, 
pc.defaultConsumer.option.ConsumeGoroutineNums))
 
                go func() {
                        // todo start clean msg expired
@@ -236,13 +234,20 @@ func (pc *pushConsumer) Start() error {
        }
 
        pc.client.UpdateTopicRouteInfo()
-       for k := range pc.subscribedTopic {
+       pc.subscribedTopic.Range(func(k, v interface{}) bool {
                _, exist := pc.topicSubscribeInfoTable.Load(k)
                if !exist {
                        pc.Shutdown()
-                       return fmt.Errorf("the topic=%s route info not found, 
it may not exist", k)
+                       err = fmt.Errorf("the topic=%s route info not found, it 
may not exist", k)
+                       return false
                }
+               return true
+       })
+
+       if err != nil {
+               return err
        }
+
        pc.client.CheckClientInBroker()
        pc.client.SendHeartbeatToAllBrokerWithLock()
        go pc.client.RebalanceImmediately()
@@ -298,12 +303,10 @@ func (pc *pushConsumer) Subscribe(topic string, selector 
MessageSelector,
        if pc.option.Namespace != "" {
                topic = pc.option.Namespace + "%" + topic
        }
-       if _, ok := pc.crCh[topic]; !ok {
-               pc.crCh[topic] = make(chan struct{}, 
pc.defaultConsumer.option.ConsumeGoroutineNums)
-       }
+       pc.crCh.LoadOrStore(topic, make(chan struct{}, 
pc.defaultConsumer.option.ConsumeGoroutineNums))
        data := buildSubscriptionData(topic, selector)
        pc.subscriptionDataTable.Store(topic, data)
-       pc.subscribedTopic[topic] = ""
+       pc.subscribedTopic.LoadOrStore(topic, "")
 
        pc.consumeFunc.Add(&PushConsumerCallback{
                f:     f,
@@ -550,8 +553,12 @@ func (pc *pushConsumer) validate() error {
                // TODO FQA
                return fmt.Errorf("consumerGroup can't equal [%s], please 
specify another one", internal.DefaultConsumerGroup)
        }
-
-       if len(pc.subscribedTopic) == 0 {
+       noSubscribedTopic := true
+       pc.subscribedTopic.Range(func(key, value interface{}) bool {
+               noSubscribedTopic = false
+               return false
+       })
+       if noSubscribedTopic {
                rlog.Warning("not subscribe any topic yet", 
map[string]interface{}{
                        rlog.LogKeyConsumerGroup: pc.consumerGroup,
                })
@@ -1089,9 +1096,7 @@ func (pc *pushConsumer) consumeMessageConcurrently(pq 
*processQueue, mq *primiti
 
        limiter := pc.option.Limiter
        limiterOn := limiter != nil
-       if _, ok := pc.crCh[mq.Topic]; !ok {
-               pc.crCh[mq.Topic] = make(chan struct{}, 
pc.defaultConsumer.option.ConsumeGoroutineNums)
-       }
+       pc.crCh.LoadOrStore(mq.Topic, make(chan struct{}, 
pc.defaultConsumer.option.ConsumeGoroutineNums))
 
        for count := 0; count < len(msgs); count++ {
                var subMsgs []*primitive.MessageExt
@@ -1107,8 +1112,10 @@ func (pc *pushConsumer) consumeMessageConcurrently(pq 
*processQueue, mq *primiti
                if limiterOn {
                        limiter(utils.WithoutNamespace(mq.Topic))
                }
-               pc.crCh[mq.Topic] <- struct{}{}
-
+               ch, _ := pc.crCh.Load(mq.Topic)
+               if channel, ok := ch.(chan struct{}); ok {
+                       channel <- struct{}{}
+               }
                go primitive.WithRecover(func() {
                        defer func() {
                                if err := recover(); err != nil {
@@ -1121,7 +1128,10 @@ func (pc *pushConsumer) consumeMessageConcurrently(pq 
*processQueue, mq *primiti
                                                rlog.LogKeyConsumerGroup: 
pc.consumerGroup,
                                        })
                                }
-                               <-pc.crCh[mq.Topic]
+                               ch, _ := pc.crCh.Load(mq.Topic)
+                               if channel, ok := ch.(chan struct{}); ok {
+                                       <-channel
+                               }
                        }()
                RETRY:
                        if pq.IsDroppd() {
diff --git a/internal/utils/compression.go b/internal/utils/compression.go
index 11f1791..914c37e 100644
--- a/internal/utils/compression.go
+++ b/internal/utils/compression.go
@@ -21,7 +21,7 @@ import (
        "bytes"
        "compress/zlib"
        "github.com/apache/rocketmq-client-go/v2/errors"
-       "io/ioutil"
+       "io"
        "sync"
 )
 
@@ -79,9 +79,12 @@ func UnCompress(data []byte) []byte {
                return data
        }
        defer r.Close()
-       retData, err := ioutil.ReadAll(r)
+
+       // Use a buffer with reasonable initial size to avoid frequent 
reallocations
+       buf := bytes.NewBuffer(make([]byte, 0, len(data)*2))
+       _, err = io.Copy(buf, r)
        if err != nil {
                return data
        }
-       return retData
+       return buf.Bytes()
 }
diff --git a/internal/utils/compression_bench_test.go 
b/internal/utils/compression_bench_test.go
new file mode 100644
index 0000000..af084bb
--- /dev/null
+++ b/internal/utils/compression_bench_test.go
@@ -0,0 +1,101 @@
+package utils
+
+import (
+       "bytes"
+       "compress/zlib"
+       "io/ioutil"
+       "math/rand"
+       "strconv"
+       "testing"
+)
+
+func generateTestData(size int) []byte {
+       data := make([]byte, size)
+       rand.Read(data)
+       return data
+}
+
+func compressTestData(data []byte) []byte {
+       var buf bytes.Buffer
+       writer, _ := zlib.NewWriterLevel(&buf, zlib.BestCompression)
+       writer.Write(data)
+       writer.Close()
+       return buf.Bytes()
+}
+
+func UnCompressOriginal(data []byte) []byte {
+       rdata := bytes.NewReader(data)
+       r, err := zlib.NewReader(rdata)
+       if err != nil {
+               return data
+       }
+       defer r.Close()
+       retData, err := ioutil.ReadAll(r)
+       if err != nil {
+               return data
+       }
+       return retData
+}
+
+var testDataSizes = []int{1024, 64 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 
1024, 4 * 1024 * 1024}
+
+func BenchmarkUnCompress(b *testing.B) {
+       for _, size := range testDataSizes {
+               data := generateTestData(size)
+               compressed := compressTestData(data)
+
+               b.Run("New_"+formatSize(size), func(b *testing.B) {
+                       b.ResetTimer()
+                       b.ReportAllocs()
+                       for i := 0; i < b.N; i++ {
+                               result := UnCompress(compressed)
+                               _ = result
+                       }
+               })
+
+               b.Run("Original_"+formatSize(size), func(b *testing.B) {
+                       b.ResetTimer()
+                       b.ReportAllocs()
+                       for i := 0; i < b.N; i++ {
+                               result := UnCompressOriginal(compressed)
+                               _ = result
+                       }
+               })
+       }
+}
+
+func BenchmarkMemoryUsage(b *testing.B) {
+       // 测试大内存使用情况
+       largeData := generateTestData(4 * 1024 * 1024) // 4MB
+       compressed := compressTestData(largeData)
+
+       b.Run("New_Memory", func(b *testing.B) {
+               b.ResetTimer()
+               b.ReportAllocs()
+               for i := 0; i < b.N; i++ {
+                       result := UnCompress(compressed)
+                       _ = result
+               }
+       })
+
+       b.Run("Original_Memory", func(b *testing.B) {
+               b.ResetTimer()
+               b.ReportAllocs()
+               for i := 0; i < b.N; i++ {
+                       result := UnCompressOriginal(compressed)
+                       _ = result
+               }
+       })
+}
+
+func formatSize(bytes int) string {
+       if bytes < 1024 {
+               return strconv.Itoa(bytes) + "B"
+       } else if bytes < 1024*1024 {
+               return strconv.Itoa(bytes/1024) + "KB"
+       } else if bytes < 1024*1024*1024 {
+               return strconv.Itoa(bytes/(1024*1024)) + "MB"
+       } else {
+               return strconv.Itoa(bytes/(1024*1024*1024)) + "GB"
+       }
+}
diff --git a/internal/utils/set.go b/internal/utils/set.go
index ed9857b..0ab3ac5 100644
--- a/internal/utils/set.go
+++ b/internal/utils/set.go
@@ -21,6 +21,7 @@ import (
        "bytes"
        "encoding/json"
        "sort"
+       "sync"
 )
 
 type UniqueItem interface {
@@ -34,6 +35,7 @@ func (str StringUnique) UniqueID() string {
 }
 
 type Set struct {
+       mux   sync.RWMutex
        items map[string]UniqueItem
 }
 
@@ -44,29 +46,41 @@ func NewSet() Set {
 }
 
 func (s *Set) Items() map[string]UniqueItem {
+       s.mux.RLock()
+       defer s.mux.RUnlock()
        return s.items
 }
 
 func (s *Set) Add(v UniqueItem) {
+       s.mux.Lock()
+       defer s.mux.Unlock()
        s.items[v.UniqueID()] = v
 }
 
 func (s *Set) AddKV(k, v string) {
+       s.mux.Lock()
+       defer s.mux.Unlock()
        s.items[k] = StringUnique(v)
 }
 
 func (s *Set) Contains(k string) (UniqueItem, bool) {
+       s.mux.RLock()
+       defer s.mux.RUnlock()
        v, ok := s.items[k]
        return v, ok
 }
 
 func (s *Set) Len() int {
+       s.mux.RLock()
+       defer s.mux.RUnlock()
        return len(s.items)
 }
 
 var _ json.Marshaler = &Set{}
 
 func (s *Set) MarshalJSON() ([]byte, error) {
+       s.mux.RLock()
+       defer s.mux.RUnlock()
        if len(s.items) == 0 {
                return []byte("[]"), nil
        }

Reply via email to