This is an automated email from the ASF dual-hosted git repository.

BewareMyPower pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new eade6939 fix: deadlock when increasing partitioned consumers (#1500)
eade6939 is described below

commit eade69391f2119e966cc012aad803c13c674284d
Author: Yunze Xu <[email protected]>
AuthorDate: Thu May 21 16:29:26 2026 +0800

    fix: deadlock when increasing partitioned consumers (#1500)
---
 pulsar/consumer_impl.go         | 120 ++++++------
 pulsar/consumer_partition.go    |  15 +-
 pulsar/consumer_test.go         | 394 +++++++++++++++++++++++++++++++++++++++-
 pulsar/consumer_zero_queue.go   |   2 +-
 pulsar/message_chunking_test.go |   4 +-
 pulsar/reader_impl.go           |   5 +-
 pulsar/reader_test.go           |   6 +-
 7 files changed, 457 insertions(+), 89 deletions(-)

diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index d76108a2..5848b93c 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -25,6 +25,7 @@ import (
        "math/rand"
        "strconv"
        "sync"
+       "sync/atomic"
        "time"
 
        "github.com/apache/pulsar-client-go/pulsar/crypto"
@@ -49,18 +50,11 @@ type acker interface {
 }
 
 type consumer struct {
-       sync.Mutex
        topic   string
        client  *client
        options ConsumerOptions
 
-       // When accessing `consumers`, the lock must be acquired in case 
partitions are being added
-       // in the background by `internalTopicSubscribeToPartitions`. 
Currently, when a new sub-consumer
-       // is created, the current consumer can immediately receive messages 
from the new partition. However,
-       // before the new sub-consumers are visible in `consumers`, the Ack 
related methods cannot find the
-       // sub-consumer for the message's message ID, so we cannot simply 
change `consumers` to `atomic.Value`
-       // and perform copy-on-write when partitions are added.
-       consumers                 []*partitionConsumer
+       consumers                 atomic.Value
        consumerName              string
        disableForceTopicCreation bool
 
@@ -356,14 +350,10 @@ func (c *consumer) internalTopicSubscribeToPartitions() 
error {
                return err
        }
 
-       oldNumPartitions := 0
        newNumPartitions := len(partitions)
 
-       c.Lock()
-       defer c.Unlock()
-
-       oldConsumers := c.consumers
-       oldNumPartitions = len(oldConsumers)
+       oldConsumers := c.partitionConsumers()
+       oldNumPartitions := len(oldConsumers)
 
        if oldConsumers != nil {
                if oldNumPartitions == newNumPartitions {
@@ -376,14 +366,14 @@ func (c *consumer) internalTopicSubscribeToPartitions() 
error {
                        Info("Changed number of partitions in topic")
        }
 
-       c.consumers = make([]*partitionConsumer, newNumPartitions)
+       newConsumers := make([]*partitionConsumer, newNumPartitions)
 
        // When for some reason (eg: forced deletion of sub partition) causes 
oldNumPartitions> newNumPartitions,
        // we need to rebuild the cache of new consumers, otherwise the array 
will be out of bounds.
        if oldConsumers != nil && oldNumPartitions < newNumPartitions {
                // Copy over the existing consumer instances
                for i := 0; i < oldNumPartitions; i++ {
-                       c.consumers[i] = oldConsumers[i]
+                       newConsumers[i] = oldConsumers[i]
                }
        }
 
@@ -408,16 +398,16 @@ func (c *consumer) internalTopicSubscribeToPartitions() 
error {
        for partitionIdx := startPartition; partitionIdx < newNumPartitions; 
partitionIdx++ {
                partitionTopic := partitions[partitionIdx]
 
-               go func() {
+               go func(partitionIdx int, partitionTopic string) {
                        defer wg.Done()
                        opts := newPartitionConsumerOpts(partitionTopic, 
c.consumerName, partitionIdx, c.options)
-                       cons, err := newPartitionConsumer(c, c.client, opts, 
c.messageCh, c.dlq, c.metrics)
+                       cons, err := newPartitionConsumer(c, c.client, opts, 
c.messageCh, c.dlq, c.metrics, false)
                        ch <- ConsumerError{
                                err:       err,
                                partition: partitionIdx,
                                consumer:  cons,
                        }
-               }()
+               }(partitionIdx, partitionTopic)
        }
 
        go func() {
@@ -429,14 +419,14 @@ func (c *consumer) internalTopicSubscribeToPartitions() 
error {
                if ce.err != nil {
                        err = ce.err
                } else {
-                       c.consumers[ce.partition] = ce.consumer
+                       newConsumers[ce.partition] = ce.consumer
                }
        }
 
        if err != nil {
                // Since there were some failures,
                // cleanup all the partitions that succeeded in creating the 
consumer
-               for _, c := range c.consumers {
+               for _, c := range newConsumers {
                        if c != nil {
                                c.Close()
                        }
@@ -444,6 +434,10 @@ func (c *consumer) internalTopicSubscribeToPartitions() 
error {
                return err
        }
 
+       c.consumers.Store(append([]*partitionConsumer(nil), newConsumers...))
+       for partitionIdx := startPartition; partitionIdx < newNumPartitions; 
partitionIdx++ {
+               newConsumers[partitionIdx].startDispatcher()
+       }
        if newNumPartitions < oldNumPartitions {
                c.metrics.ConsumersPartitions.Set(float64(newNumPartitions))
        } else {
@@ -510,11 +504,9 @@ func (c *consumer) UnsubscribeForce() error {
 }
 
 func (c *consumer) unsubscribe(force bool) error {
-       c.Lock()
-       defer c.Unlock()
-
+       consumers := c.partitionConsumers()
        var errMsg string
-       for _, consumer := range c.consumers {
+       for _, consumer := range consumers {
                if err := consumer.unsubscribe(force); err != nil {
                        errMsg += fmt.Sprintf("topic %s, subscription %s: %s", 
consumer.topic, c.Subscription(), err)
                }
@@ -526,8 +518,9 @@ func (c *consumer) unsubscribe(force bool) error {
 }
 
 func (c *consumer) GetLastMessageIDs() ([]TopicMessageID, error) {
+       consumers := c.partitionConsumers()
        ids := make([]TopicMessageID, 0)
-       for _, pc := range c.consumers {
+       for _, pc := range consumers {
                id, err := pc.getLastMessageID()
                tm := &topicMessageID{topic: pc.topic, track: id}
                if err != nil {
@@ -556,7 +549,7 @@ func (c *consumer) Receive(ctx context.Context) (message 
Message, err error) {
 
 func (c *consumer) AckWithTxn(msg Message, txn Transaction) error {
        msgID := msg.ID()
-       consumer, err := c.findPartitionConsumer(msgID)
+       consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
        if err != nil {
                return err
        }
@@ -575,7 +568,7 @@ func (c *consumer) Ack(msg Message) error {
 
 // AckID the consumption of a single message, identified by its MessageID
 func (c *consumer) AckID(msgID MessageID) error {
-       consumer, err := c.findPartitionConsumer(msgID)
+       consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
        if err != nil {
                return err
        }
@@ -587,7 +580,7 @@ func (c *consumer) AckID(msgID MessageID) error {
 
 func (c *consumer) AckIDList(msgIDs []MessageID) error {
        return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) 
(acker, error) {
-               return c.findPartitionConsumer(msgID)
+               return findPartitionConsumer(c.partitionConsumers(), msgID)
        })
 }
 
@@ -600,7 +593,7 @@ func (c *consumer) AckCumulative(msg Message) error {
 // AckIDCumulative the reception of all the messages in the stream up to (and 
including)
 // the provided message, identified by its MessageID
 func (c *consumer) AckIDCumulative(msgID MessageID) error {
-       consumer, err := c.findPartitionConsumer(msgID)
+       consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
        if err != nil {
                return err
        }
@@ -697,7 +690,7 @@ func (c *consumer) Nack(msg Message) {
                        mid.NackByMsg(msg)
                        return
                }
-               if consumer, err := c.findPartitionConsumer(mid); err == nil {
+               if consumer, err := 
findPartitionConsumer(c.partitionConsumers(), mid); err == nil {
                        consumer.NackMsg(msg)
                }
                return
@@ -707,7 +700,7 @@ func (c *consumer) Nack(msg Message) {
 }
 
 func (c *consumer) NackID(msgID MessageID) {
-       if consumer, err := c.findPartitionConsumer(msgID); err == nil {
+       if consumer, err := findPartitionConsumer(c.partitionConsumers(), 
msgID); err == nil {
                consumer.NackID(msgID)
        }
 }
@@ -724,16 +717,14 @@ func (c *consumer) closeWithCause(err error) {
        c.closeOnce.Do(func() {
                c.stopDiscovery()
 
-               c.Lock()
-               defer c.Unlock()
-
                var wg sync.WaitGroup
-               for i := range c.consumers {
+               consumers := c.partitionConsumers()
+               for i := range consumers {
                        wg.Add(1)
                        go func(pc *partitionConsumer) {
                                defer wg.Done()
                                pc.Close()
-                       }(c.consumers[i])
+                       }(consumers[i])
                }
                wg.Wait()
                close(c.closeCh)
@@ -741,20 +732,19 @@ func (c *consumer) closeWithCause(err error) {
                c.dlq.close()
                c.rlq.close()
                c.metrics.ConsumersClosed.Inc()
-               c.metrics.ConsumersPartitions.Sub(float64(len(c.consumers)))
+               c.metrics.ConsumersPartitions.Sub(float64(len(consumers)))
                c.options.Interceptors.OnConsumerClose(c, err)
        })
 }
 
 func (c *consumer) Seek(msgID MessageID) error {
-       c.Lock()
-       defer c.Unlock()
+       consumers := c.partitionConsumers()
 
-       if len(c.consumers) > 1 {
+       if len(consumers) > 1 {
                return newError(SeekFailed, "for partition topic, seek command 
should perform on the individual partitions")
        }
 
-       consumer, err := c.unsafeFindPartitionConsumer(msgID)
+       consumer, err := findPartitionConsumer(consumers, msgID)
        if err != nil {
                return err
        }
@@ -768,11 +758,10 @@ func (c *consumer) Seek(msgID MessageID) error {
 }
 
 func (c *consumer) SeekByTime(time time.Time) error {
-       c.Lock()
-       defer c.Unlock()
        var errs error
+       consumers := c.partitionConsumers()
 
-       for _, cons := range c.consumers {
+       for _, cons := range consumers {
                cons.pauseDispatchMessage()
        }
        // clear messageCh
@@ -781,7 +770,7 @@ func (c *consumer) SeekByTime(time time.Time) error {
        }
 
        // run SeekByTime on every partition of topic
-       for _, cons := range c.consumers {
+       for _, cons := range consumers {
                if err := cons.SeekByTime(time); err != nil {
                        msg := fmt.Sprintf("unable to SeekByTime for topic=%s 
subscription=%s", c.topic, c.Subscription())
                        errs = pkgerrors.Wrap(newError(SeekFailed, 
err.Error()), msg)
@@ -791,35 +780,30 @@ func (c *consumer) SeekByTime(time time.Time) error {
        return errs
 }
 
-func (c *consumer) findPartitionConsumer(msgID MessageID) (*partitionConsumer, 
error) {
-       c.Lock()
-       defer c.Unlock()
-       return c.unsafeFindPartitionConsumer(msgID)
-}
-
-// NOTE: This method must be called when c.Lock is held
-func (c *consumer) unsafeFindPartitionConsumer(msgID MessageID) 
(*partitionConsumer, error) {
+func findPartitionConsumer(consumers []*partitionConsumer, msgID MessageID) 
(*partitionConsumer, error) {
        partition := int(msgID.PartitionIdx())
-       if partition < 0 || partition >= len(c.consumers) {
-               c.log.Errorf("invalid partition index %d expected a partition 
between [0-%d]",
-                       partition, len(c.consumers))
+       if partition < 0 || partition >= len(consumers) {
                return nil, fmt.Errorf("invalid partition index %d expected a 
partition between [0-%d]",
-                       partition, len(c.consumers))
+                       partition, len(consumers)-1)
+       }
+       return consumers[partition], nil
+}
+
+func (c *consumer) partitionConsumers() []*partitionConsumer {
+       v := c.consumers.Load()
+       if v == nil {
+               return nil
        }
-       return c.consumers[partition], nil
+       // The slice stored in c.consumers is published via copy-on-write.
+       // Callers must treat the returned slice as immutable.
+       return v.([]*partitionConsumer)
 }
 
 func (c *consumer) hasNext() bool {
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel() // Make sure all paths cancel the context to avoid 
context leak
 
-       // We have to make a snapshot consumers, because we have to iterate 
over all consumers in
-       // other goroutines. But when this method returns, there might be still 
other consumers
-       // not completing the `hasNext` call, so we cannot just call defer 
`c.Unlock()` after acquiring the lock.
-       c.Lock()
-       consumers := make([]*partitionConsumer, len(c.consumers))
-       copy(consumers, c.consumers)
-       c.Unlock()
+       consumers := c.partitionConsumers()
 
        var wg sync.WaitGroup
        wg.Add(len(consumers))
@@ -853,7 +837,7 @@ func (c *consumer) hasNext() bool {
 }
 
 func (c *consumer) setLastDequeuedMsg(msgID MessageID) error {
-       consumer, err := c.findPartitionConsumer(msgID)
+       consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
        if err != nil {
                return err
        }
@@ -920,7 +904,7 @@ func toProtoInitialPosition(p SubscriptionInitialPosition) 
pb.CommandSubscribe_I
 }
 
 func (c *consumer) messageID(msgID MessageID) *trackingMessageID {
-       if _, err := c.findPartitionConsumer(msgID); err != nil {
+       if _, err := findPartitionConsumer(c.partitionConsumers(), msgID); err 
!= nil {
                return nil
        }
        return toTrackingMessageID(msgID)
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index d7aba524..08e1b36f 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -401,8 +401,8 @@ func (s *schemaInfoCache) add(schemaVersionHash string, 
schema Schema) {
 }
 
 func newPartitionConsumer(parent Consumer, client *client, options 
*partitionConsumerOpts,
-       messageCh chan ConsumerMessage, dlq *dlqRouter,
-       metrics *internal.LeveledMetrics) (*partitionConsumer, error) {
+       messageCh chan ConsumerMessage, dlq *dlqRouter, metrics 
*internal.LeveledMetrics,
+       startDispatcher bool) (*partitionConsumer, error) {
        var boFunc func() backoff.Policy
        if options.backOffPolicyFunc != nil {
                boFunc = options.backOffPolicyFunc
@@ -425,7 +425,7 @@ func newPartitionConsumer(parent Consumer, client *client, 
options *partitionCon
                queueCh:                    make(chan []*message, 
options.receiverQueueSize),
                startMessageID:             atomicMessageID{msgID: 
options.startMessageID},
                seekMessageID:              atomicMessageID{msgID: nil},
-               connectedCh:                make(chan struct{}),
+               connectedCh:                make(chan struct{}, 1),
                messageCh:                  messageCh,
                connectClosedCh:            make(chan *connectionClosed, 1),
                closeCh:                    make(chan struct{}),
@@ -512,13 +512,18 @@ func newPartitionConsumer(parent Consumer, client 
*client, options *partitionCon
                }
        }
 
-       go pc.dispatcher()
-
        go pc.runEventsLoop()
+       if startDispatcher {
+               pc.startDispatcher()
+       }
 
        return pc, nil
 }
 
+func (pc *partitionConsumer) startDispatcher() {
+       go pc.dispatcher()
+}
+
 func (pc *partitionConsumer) unsubscribe(force bool) error {
        if state := pc.getConsumerState(); state == consumerClosed || state == 
consumerClosing {
                pc.log.WithField("state", state).Error("Failed to unsubscribe 
closing or closed consumer")
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index e1964287..c3fc82d4 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -22,7 +22,9 @@ import (
        "errors"
        "fmt"
        "log"
+       "log/slog"
        "net/http"
+       "net/url"
        "os"
        "regexp"
        "strconv"
@@ -4723,7 +4725,7 @@ func TestConsumerWithBackoffPolicy(t *testing.T) {
        assert.Nil(t, err)
        defer _consumer.Close()
 
-       partitionConsumerImp := _consumer.(*consumer).consumers[0]
+       partitionConsumerImp := _consumer.(*consumer).partitionConsumers()[0]
        // 1 s
        startTime := time.Now()
        partitionConsumerImp.reconnectToBroker(nil)
@@ -4946,7 +4948,7 @@ func TestConsumerWithAutoScaledQueueReceive(t *testing.T) 
{
                EnableAutoScaledReceiverQueueSize: true,
        })
        assert.Nil(t, err)
-       pc := c.(*consumer).consumers[0]
+       pc := c.(*consumer).partitionConsumers()[0]
        assert.Equal(t, int32(1), pc.currentQueueSize.Load())
        defer c.Close()
 
@@ -5161,7 +5163,7 @@ func TestConsumerMemoryLimit(t *testing.T) {
        })
        assert.Nil(t, err)
        defer c1.Close()
-       pc1 := c1.(*consumer).consumers[0]
+       pc1 := c1.(*consumer).partitionConsumers()[0]
 
        // Fill up the messageCh of c1
        for i := 0; i < 10; i++ {
@@ -5201,7 +5203,7 @@ func TestConsumerMemoryLimit(t *testing.T) {
        })
        assert.Nil(t, err)
        defer c2.Close()
-       pc2 := c2.(*consumer).consumers[0]
+       pc2 := c2.(*consumer).partitionConsumers()[0]
 
        // Try to induce c2 receiver queue size expansion
        for i := 0; i < 10; i++ {
@@ -5273,7 +5275,7 @@ func TestMultiConsumerMemoryLimit(t *testing.T) {
        })
        assert.Nil(t, err)
        defer c1.Close()
-       pc1 := c1.(*consumer).consumers[0]
+       pc1 := c1.(*consumer).partitionConsumers()[0]
 
        // Use mem-limited client 2 to create consumer c1
        c2, err := cli2.Subscribe(ConsumerOptions{
@@ -5284,7 +5286,7 @@ func TestMultiConsumerMemoryLimit(t *testing.T) {
        })
        assert.Nil(t, err)
        defer c2.Close()
-       pc2 := c2.(*consumer).consumers[0]
+       pc2 := c2.(*consumer).partitionConsumers()[0]
 
        // Fill up the messageCh of c1 nad c2
        for i := 0; i < 10; i++ {
@@ -5918,7 +5920,7 @@ func TestSelectConnectionForSameConsumer(t *testing.T) {
        assert.NoError(t, err)
        defer _consumer.Close()
 
-       partitionConsumerImpl := _consumer.(*consumer).consumers[0]
+       partitionConsumerImpl := _consumer.(*consumer).partitionConsumers()[0]
        conn := partitionConsumerImpl._getConn()
 
        for i := 0; i < 5; i++ {
@@ -5928,6 +5930,382 @@ func TestSelectConnectionForSameConsumer(t *testing.T) {
        }
 }
 
+func 
TestInternalTopicSubscribeToPartitionsDoesNotBlockExistingPartitionLookup(t 
*testing.T) {
+       lookupURL, err := url.Parse("pulsar://localhost:6650")
+       require.NoError(t, err)
+
+       allowSubscribe := make(chan struct{})
+       subscribeStarted := make(chan struct{})
+       var releaseSubscribe sync.Once
+
+       logger := slog.New(slog.NewJSONHandler(os.Stdout, 
&slog.HandlerOptions{Level: slog.LevelInfo}))
+       log := plog.NewLoggerWithSlog(logger)
+
+       rpcClient := &blockingSubscribeRPCClient{
+               lookupResult:     &internal.LookupResult{LogicalAddr: 
lookupURL, PhysicalAddr: lookupURL},
+               subscribeStarted: subscribeStarted,
+               allowSubscribe:   allowSubscribe,
+               subscribeErr:     errors.New("stop subscribe after lookup 
check"),
+               nextConsumerID:   1,
+       }
+
+       c := 
newInternalTopicPartitionTestConsumer(internalTopicPartitionTestConsumerOptions{
+               conn:             dummyConnection{},
+               rpcClient:        rpcClient,
+               partitions:       2,
+               log:              log,
+               consumerOptions:  ConsumerOptions{SubscriptionName: "test-sub", 
NackPrecisionBit: ptr(defaultNackPrecisionBit)},
+               initialConsumers: []*partitionConsumer{{topic: 
"persistent://public/default/test-topic-partition-0"}},
+       })
+
+       go func() {
+               c.internalTopicSubscribeToPartitions()
+       }()
+
+       select {
+       case <-subscribeStarted:
+       case <-time.After(3 * time.Second):
+               t.Fatal("timed out waiting for partition discovery to start 
subscribing the new partition")
+       }
+
+       lookupErrCh := make(chan error, 1)
+       go func() {
+               _, err := findPartitionConsumer(c.partitionConsumers(), 
&messageID{partitionIdx: 0})
+               lookupErrCh <- err
+       }()
+
+       select {
+       case err := <-lookupErrCh:
+               require.NoError(t, err)
+       case <-time.After(3 * time.Second):
+               releaseSubscribe.Do(func() { close(allowSubscribe) })
+               select {
+               case <-lookupErrCh:
+               case <-time.After(time.Second):
+                       t.Fatal("existing partition lookup stayed blocked even 
after partition discovery stopped")
+               }
+               t.Fatal("existing partition lookup blocked while a new 
partition was being added")
+       }
+
+       releaseSubscribe.Do(func() { close(allowSubscribe) })
+}
+
+func 
TestInternalTopicSubscribeToPartitionsPublishesConsumersBeforeDispatchingMessages(t
 *testing.T) {
+       lookupURL, err := url.Parse("pulsar://localhost:6650")
+       require.NoError(t, err)
+
+       partitionOneSubscribed := make(chan struct{})
+       partitionOneFlowed := make(chan struct{})
+       partitionTwoBlocked := make(chan struct{})
+       allowPartitionTwo := make(chan struct{})
+       cnx := newPartitionExpansionRaceConnection()
+       rpcClient := &partitionExpansionRaceRPCClient{
+               lookupResult:           &internal.LookupResult{LogicalAddr: 
lookupURL, PhysicalAddr: lookupURL},
+               cnx:                    cnx,
+               partitionOneSubscribed: partitionOneSubscribed,
+               partitionOneFlowed:     partitionOneFlowed,
+               partitionTwoBlocked:    partitionTwoBlocked,
+               allowPartitionTwo:      allowPartitionTwo,
+       }
+
+       c := 
newInternalTopicPartitionTestConsumer(internalTopicPartitionTestConsumerOptions{
+               conn:       cnx,
+               rpcClient:  rpcClient,
+               partitions: 3,
+               log:        plog.DefaultNopLogger(),
+               consumerOptions: ConsumerOptions{
+                       SubscriptionName:  "test-sub",
+                       ReceiverQueueSize: 1,
+                       NackPrecisionBit:  ptr(defaultNackPrecisionBit),
+                       AckWithResponse:   true,
+               },
+               initialConsumers: []*partitionConsumer{{topic: 
"persistent://public/default/test-topic-partition-0"}},
+               dlq:              &dlqRouter{},
+       })
+
+       errCh := make(chan error, 1)
+       go func() {
+               errCh <- c.internalTopicSubscribeToPartitions()
+       }()
+
+       select {
+       case <-partitionOneSubscribed:
+       case <-time.After(3 * time.Second):
+               t.Fatal("timed out waiting for partition 1 to subscribe")
+       }
+
+       select {
+       case <-partitionTwoBlocked:
+       case <-time.After(3 * time.Second):
+               t.Fatal("timed out waiting for partition 2 subscribe to block")
+       }
+
+       require.Len(t, c.partitionConsumers(), 1)
+       select {
+       case <-partitionOneFlowed:
+               t.Fatal("new partition dispatcher requested permits before 
c.consumers contained the new partition")
+       case <-time.After(200 * time.Millisecond):
+       }
+
+       close(allowPartitionTwo)
+
+       select {
+       case err := <-errCh:
+               require.NoError(t, err)
+       case <-time.After(3 * time.Second):
+               t.Fatal("timed out waiting for partition discovery to finish")
+       }
+       require.Len(t, c.partitionConsumers(), 3)
+
+       select {
+       case <-partitionOneFlowed:
+       case <-time.After(3 * time.Second):
+               t.Fatal("timed out waiting for partition 1 dispatcher to 
request permits")
+       }
+
+       handler := cnx.handler(rpcClient.partitionOneConsumerID.Load())
+       require.NotNil(t, handler)
+       err = handler.MessageReceived(&pb.CommandMessage{
+               MessageId: &pb.MessageIdData{
+                       LedgerId: proto.Uint64(1),
+                       EntryId:  proto.Uint64(1),
+               },
+       }, internal.NewBufferWrapper(rawCompatSingleMessage))
+       require.NoError(t, err)
+
+       var cm ConsumerMessage
+       select {
+       case cm = <-c.messageCh:
+       case <-time.After(3 * time.Second):
+               t.Fatal("timed out waiting for the queued partition 1 message 
to dispatch")
+       }
+       require.Equal(t, int32(1), cm.Message.ID().PartitionIdx())
+       require.NoError(t, c.AckID(cm.Message.ID()))
+
+       for _, pc := range c.partitionConsumers()[1:] {
+               pc.Close()
+       }
+}
+
+type internalTopicPartitionTestConsumerOptions struct {
+       conn             internal.Connection
+       rpcClient        internal.RPCClient
+       partitions       int
+       log              plog.Logger
+       consumerOptions  ConsumerOptions
+       initialConsumers []*partitionConsumer
+       dlq              *dlqRouter
+}
+
+func newInternalTopicPartitionTestConsumer(opts 
internalTopicPartitionTestConsumerOptions) *consumer {
+       var consumers atomic.Value
+       consumers.Store(append([]*partitionConsumer(nil), 
opts.initialConsumers...))
+
+       return &consumer{
+               topic: "persistent://public/default/test-topic",
+               client: &client{
+                       cnxPool:       &blockingConnPool{cnx: opts.conn},
+                       rpcClient:     opts.rpcClient,
+                       lookupService: &partitionMetadataLookup{partitions: 
opts.partitions},
+                       log:           opts.log,
+               },
+               options:      opts.consumerOptions,
+               consumers:    consumers,
+               messageCh:    make(chan ConsumerMessage, 1),
+               closeCh:      make(chan struct{}),
+               errorCh:      make(chan error, 1),
+               consumerName: "test-consumer",
+               dlq:          opts.dlq,
+               log:          opts.log,
+               metrics:      newTestMetrics(),
+       }
+}
+
+type partitionMetadataLookup struct {
+       internal.LookupService
+       partitions int
+}
+
+func (l *partitionMetadataLookup) GetPartitionedTopicMetadata(_ string) 
(*internal.PartitionedTopicMetadata, error) {
+       return &internal.PartitionedTopicMetadata{Partitions: l.partitions}, nil
+}
+
+type blockingConnPool struct {
+       internal.ConnectionPool
+       cnx internal.Connection
+}
+
+func (p *blockingConnPool) GetConnection(_ *url.URL, _ *url.URL, _ int32) 
(internal.Connection, error) {
+       return p.cnx, nil
+}
+
+func (p *blockingConnPool) GetConnections() map[string]internal.Connection {
+       return map[string]internal.Connection{}
+}
+
+func (p *blockingConnPool) GenerateRoundRobinIndex() int32 {
+       return 0
+}
+
+func (p *blockingConnPool) Close() {}
+
+type partitionExpansionRaceConnection struct {
+       dummyConnection
+       mu       sync.Mutex
+       handlers map[uint64]internal.ConsumerHandler
+}
+
+func newPartitionExpansionRaceConnection() *partitionExpansionRaceConnection {
+       return &partitionExpansionRaceConnection{handlers: 
make(map[uint64]internal.ConsumerHandler)}
+}
+
+func (c *partitionExpansionRaceConnection) AddConsumeHandler(id uint64, 
handler internal.ConsumerHandler) error {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       c.handlers[id] = handler
+       return nil
+}
+
+func (c *partitionExpansionRaceConnection) DeleteConsumeHandler(id uint64) {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       delete(c.handlers, id)
+}
+
+func (c *partitionExpansionRaceConnection) handler(id uint64) 
internal.ConsumerHandler {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       return c.handlers[id]
+}
+
+type partitionExpansionRaceRPCClient struct {
+       internal.RPCClient
+       lookupResult           *internal.LookupResult
+       cnx                    *partitionExpansionRaceConnection
+       partitionOneSubscribed chan struct{}
+       partitionOneFlowed     chan struct{}
+       partitionTwoBlocked    chan struct{}
+       allowPartitionTwo      chan struct{}
+       requestID              atomic.Uint64
+       consumerID             atomic.Uint64
+       partitionOneConsumerID atomic.Uint64
+       partitionOneOnce       sync.Once
+       partitionOneFlowOnce   sync.Once
+       partitionTwoOnce       sync.Once
+}
+
+func (r *partitionExpansionRaceRPCClient) NewRequestID() uint64 {
+       return r.requestID.Add(1)
+}
+
+func (r *partitionExpansionRaceRPCClient) NewProducerID() uint64 {
+       return r.requestID.Add(1)
+}
+
+func (r *partitionExpansionRaceRPCClient) NewConsumerID() uint64 {
+       return r.consumerID.Add(1)
+}
+
+func (r *partitionExpansionRaceRPCClient) RequestOnCnxNoWait(
+       _ internal.Connection, cmdType pb.BaseCommand_Type, msg proto.Message,
+) error {
+       if cmdType == pb.BaseCommand_FLOW {
+               flow := msg.(*pb.CommandFlow)
+               if flow.GetConsumerId() == r.partitionOneConsumerID.Load() {
+                       r.partitionOneFlowOnce.Do(func() { 
close(r.partitionOneFlowed) })
+               }
+       }
+       return nil
+}
+
+func (r *partitionExpansionRaceRPCClient) RequestOnCnx(
+       _ internal.Connection, _ uint64, cmdType pb.BaseCommand_Type, msg 
proto.Message,
+) (*internal.RPCResult, error) {
+       switch cmdType {
+       case pb.BaseCommand_SUBSCRIBE:
+               return r.handleSubscribe(msg.(*pb.CommandSubscribe))
+       case pb.BaseCommand_ACK, pb.BaseCommand_CLOSE_CONSUMER:
+               return r.success(), nil
+       default:
+               return nil, fmt.Errorf("unexpected command type %v", cmdType)
+       }
+}
+
+func (r *partitionExpansionRaceRPCClient) handleSubscribe(cmd 
*pb.CommandSubscribe) (*internal.RPCResult, error) {
+       switch {
+       case strings.HasSuffix(cmd.GetTopic(), "-partition-1"):
+               r.partitionOneConsumerID.Store(cmd.GetConsumerId())
+               r.partitionOneOnce.Do(func() { close(r.partitionOneSubscribed) 
})
+               return r.success(), nil
+       case strings.HasSuffix(cmd.GetTopic(), "-partition-2"):
+               r.partitionTwoOnce.Do(func() { close(r.partitionTwoBlocked) })
+               <-r.allowPartitionTwo
+               return r.success(), nil
+       default:
+               return nil, fmt.Errorf("unexpected subscribe topic %s", 
cmd.GetTopic())
+       }
+}
+
+func (r *partitionExpansionRaceRPCClient) success() *internal.RPCResult {
+       successType := pb.BaseCommand_SUCCESS
+       return &internal.RPCResult{
+               Response: &pb.BaseCommand{Type: &successType},
+               Cnx:      r.cnx,
+       }
+}
+
+func (r *partitionExpansionRaceRPCClient) LookupService(_ string) 
(internal.LookupService, error) {
+       return &grabConnMockLookup{result: r.lookupResult}, nil
+}
+
+type blockingSubscribeRPCClient struct {
+       internal.RPCClient
+       lookupResult     *internal.LookupResult
+       subscribeStarted chan struct{}
+       allowSubscribe   chan struct{}
+       subscribeErr     error
+       nextConsumerID   uint64
+       startOnce        sync.Once
+}
+
+func (r *blockingSubscribeRPCClient) NewRequestID() uint64 {
+       return 1
+}
+
+func (r *blockingSubscribeRPCClient) NewProducerID() uint64 {
+       return 1
+}
+
+func (r *blockingSubscribeRPCClient) NewConsumerID() uint64 {
+       id := r.nextConsumerID
+       r.nextConsumerID++
+       return id
+}
+
+func (r *blockingSubscribeRPCClient) RequestOnCnxNoWait(
+       _ internal.Connection, _ pb.BaseCommand_Type, _ proto.Message) error {
+       return nil
+}
+
+func (r *blockingSubscribeRPCClient) RequestOnCnx(
+       _ internal.Connection, _ uint64, cmdType pb.BaseCommand_Type, _ 
proto.Message,
+) (*internal.RPCResult, error) {
+       switch cmdType {
+       case pb.BaseCommand_SUBSCRIBE:
+               r.startOnce.Do(func() { close(r.subscribeStarted) })
+               <-r.allowSubscribe
+               return nil, r.subscribeErr
+       case pb.BaseCommand_CLOSE_CONSUMER:
+               return nil, nil
+       default:
+               return nil, fmt.Errorf("unexpected command type %v", cmdType)
+       }
+}
+
+func (r *blockingSubscribeRPCClient) LookupService(_ string) 
(internal.LookupService, error) {
+       return &grabConnMockLookup{result: r.lookupResult}, nil
+}
+
 // closeInterceptor captures the (consumer, err) pair delivered to
 // ConsumerCloseInterceptor.OnConsumerClose and signals via fired.
 type closeInterceptor struct {
@@ -6006,7 +6384,7 @@ func TestConsumerOnCloseInterceptorOnMaxReconnect(t *testing.T) {
        assert.NotNil(t, interceptor.err, "interceptor should receive the cause of the close")
        assert.Equal(t, testConsumer, interceptor.consumer, "interceptor should receive the parent consumer")
 
-       pc := testConsumer.(*consumer).consumers[0]
+       pc := testConsumer.(*consumer).partitionConsumers()[0]
        require.Eventually(t, func() bool {
                return pc.getConsumerState() == consumerClosed
        }, 30*time.Second, 100*time.Millisecond, "consumer should be closed after exhausting max reconnect retries")
diff --git a/pulsar/consumer_zero_queue.go b/pulsar/consumer_zero_queue.go
index 4978fae2..20a0944e 100644
--- a/pulsar/consumer_zero_queue.go
+++ b/pulsar/consumer_zero_queue.go
@@ -80,7 +80,7 @@ func newZeroConsumer(client *client, options ConsumerOptions, topic string,
                        pc.availablePermits.inc()
                }
        }
-       pc, err := newPartitionConsumer(zc, zc.client, opts, zc.messageCh, zc.dlq, zc.metrics)
+       pc, err := newPartitionConsumer(zc, zc.client, opts, zc.messageCh, zc.dlq, zc.metrics, true)
        if err != nil {
                return nil, err
        }
diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go
index 12f0517c..2d6462e7 100644
--- a/pulsar/message_chunking_test.go
+++ b/pulsar/message_chunking_test.go
@@ -178,7 +178,7 @@ func TestMaxPendingChunkMessages(t *testing.T) {
        assert.NoError(t, err)
        assert.NotNil(t, c)
        defer c.Close()
-       pc := c.(*consumer).consumers[0]
+       pc := c.(*consumer).partitionConsumers()[0]
 
        sendSingleChunk(producer, "0", 0, 2)
        // MaxPendingChunkedMessage is 1, the chunked message with uuid 0 will be discarded
@@ -228,7 +228,7 @@ func TestExpireIncompleteChunks(t *testing.T) {
        defer c.Close()
 
        uuid := "test-uuid"
-       chunkCtxMap := c.(*consumer).consumers[0].chunkedMsgCtxMap
+       chunkCtxMap := c.(*consumer).partitionConsumers()[0].chunkedMsgCtxMap
        chunkCtxMap.addIfAbsent(uuid, 2, 100)
        ctx := chunkCtxMap.get(uuid)
        assert.NotNil(t, ctx)
diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go
index 55b05037..ad65c285 100644
--- a/pulsar/reader_impl.go
+++ b/pulsar/reader_impl.go
@@ -224,8 +224,9 @@ func (r *reader) SeekByTime(time time.Time) error {
 }
 
 func (r *reader) GetLastMessageID() (MessageID, error) {
-       if len(r.c.consumers) > 1 {
+       consumers := r.c.partitionConsumers()
+       if len(consumers) > 1 {
                return nil, fmt.Errorf("GetLastMessageID is not supported for multi-topics reader")
        }
-       return r.c.consumers[0].getLastMessageID()
+       return consumers[0].getLastMessageID()
 }
diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go
index 29d61b2b..d8ebc240 100644
--- a/pulsar/reader_test.go
+++ b/pulsar/reader_test.go
@@ -952,7 +952,7 @@ func TestReaderWithBackoffPolicy(t *testing.T) {
        assert.NotNil(t, _reader)
        assert.Nil(t, err)
 
-       partitionConsumerImp := _reader.(*reader).c.consumers[0]
+       partitionConsumerImp := _reader.(*reader).c.partitionConsumers()[0]
        // 1 s
        startTime := time.Now()
        partitionConsumerImp.reconnectToBroker(nil)
@@ -1061,7 +1061,7 @@ func TestReaderHasNextFailed(t *testing.T) {
                StartMessageID: EarliestMessageID(),
        })
        assert.Nil(t, err)
-       r.(*reader).c.consumers[0].state.Store(consumerClosing)
+       r.(*reader).c.partitionConsumers()[0].state.Store(consumerClosing)
        assert.False(t, r.HasNext())
 }
 
@@ -1082,7 +1082,7 @@ func TestReaderHasNextRetryFailed(t *testing.T) {
        defer close(c)
 
        // Close the consumer events loop and assign a mock eventsCh
-       pc := r.(*reader).c.consumers[0]
+       pc := r.(*reader).c.partitionConsumers()[0]
        pc.Close()
        pc.state.Store(consumerReady)
        pc.eventsCh = c

Reply via email to