This is an automated email from the ASF dual-hosted git repository.
BewareMyPower pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push:
new eade6939 fix: deadlock when increasing partitioned consumers (#1500)
eade6939 is described below
commit eade69391f2119e966cc012aad803c13c674284d
Author: Yunze Xu <[email protected]>
AuthorDate: Thu May 21 16:29:26 2026 +0800
fix: deadlock when increasing partitioned consumers (#1500)
---
pulsar/consumer_impl.go | 120 ++++++------
pulsar/consumer_partition.go | 15 +-
pulsar/consumer_test.go | 394 +++++++++++++++++++++++++++++++++++++++-
pulsar/consumer_zero_queue.go | 2 +-
pulsar/message_chunking_test.go | 4 +-
pulsar/reader_impl.go | 5 +-
pulsar/reader_test.go | 6 +-
7 files changed, 457 insertions(+), 89 deletions(-)
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index d76108a2..5848b93c 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -25,6 +25,7 @@ import (
"math/rand"
"strconv"
"sync"
+ "sync/atomic"
"time"
"github.com/apache/pulsar-client-go/pulsar/crypto"
@@ -49,18 +50,11 @@ type acker interface {
}
type consumer struct {
- sync.Mutex
topic string
client *client
options ConsumerOptions
- // When accessing `consumers`, the lock must be acquired in case
partitions are being added
- // in the background by `internalTopicSubscribeToPartitions`.
Currently, when a new sub-consumer
- // is created, the current consumer can immediately receive messages
from the new partition. However,
- // before the new sub-consumers are visible in `consumers`, the Ack
related methods cannot find the
- // sub-consumer for the message's message ID, so we cannot simply
change `consumers` to `atomic.Value`
- // and perform copy-on-write when partitions are added.
- consumers []*partitionConsumer
+ consumers atomic.Value
consumerName string
disableForceTopicCreation bool
@@ -356,14 +350,10 @@ func (c *consumer) internalTopicSubscribeToPartitions()
error {
return err
}
- oldNumPartitions := 0
newNumPartitions := len(partitions)
- c.Lock()
- defer c.Unlock()
-
- oldConsumers := c.consumers
- oldNumPartitions = len(oldConsumers)
+ oldConsumers := c.partitionConsumers()
+ oldNumPartitions := len(oldConsumers)
if oldConsumers != nil {
if oldNumPartitions == newNumPartitions {
@@ -376,14 +366,14 @@ func (c *consumer) internalTopicSubscribeToPartitions()
error {
Info("Changed number of partitions in topic")
}
- c.consumers = make([]*partitionConsumer, newNumPartitions)
+ newConsumers := make([]*partitionConsumer, newNumPartitions)
// When for some reason (eg: forced deletion of sub partition) causes
oldNumPartitions> newNumPartitions,
// we need to rebuild the cache of new consumers, otherwise the array
will be out of bounds.
if oldConsumers != nil && oldNumPartitions < newNumPartitions {
// Copy over the existing consumer instances
for i := 0; i < oldNumPartitions; i++ {
- c.consumers[i] = oldConsumers[i]
+ newConsumers[i] = oldConsumers[i]
}
}
@@ -408,16 +398,16 @@ func (c *consumer) internalTopicSubscribeToPartitions()
error {
for partitionIdx := startPartition; partitionIdx < newNumPartitions;
partitionIdx++ {
partitionTopic := partitions[partitionIdx]
- go func() {
+ go func(partitionIdx int, partitionTopic string) {
defer wg.Done()
opts := newPartitionConsumerOpts(partitionTopic,
c.consumerName, partitionIdx, c.options)
- cons, err := newPartitionConsumer(c, c.client, opts,
c.messageCh, c.dlq, c.metrics)
+ cons, err := newPartitionConsumer(c, c.client, opts,
c.messageCh, c.dlq, c.metrics, false)
ch <- ConsumerError{
err: err,
partition: partitionIdx,
consumer: cons,
}
- }()
+ }(partitionIdx, partitionTopic)
}
go func() {
@@ -429,14 +419,14 @@ func (c *consumer) internalTopicSubscribeToPartitions()
error {
if ce.err != nil {
err = ce.err
} else {
- c.consumers[ce.partition] = ce.consumer
+ newConsumers[ce.partition] = ce.consumer
}
}
if err != nil {
// Since there were some failures,
// cleanup all the partitions that succeeded in creating the
consumer
- for _, c := range c.consumers {
+ for _, c := range newConsumers {
if c != nil {
c.Close()
}
@@ -444,6 +434,10 @@ func (c *consumer) internalTopicSubscribeToPartitions()
error {
return err
}
+ c.consumers.Store(append([]*partitionConsumer(nil), newConsumers...))
+ for partitionIdx := startPartition; partitionIdx < newNumPartitions;
partitionIdx++ {
+ newConsumers[partitionIdx].startDispatcher()
+ }
if newNumPartitions < oldNumPartitions {
c.metrics.ConsumersPartitions.Set(float64(newNumPartitions))
} else {
@@ -510,11 +504,9 @@ func (c *consumer) UnsubscribeForce() error {
}
func (c *consumer) unsubscribe(force bool) error {
- c.Lock()
- defer c.Unlock()
-
+ consumers := c.partitionConsumers()
var errMsg string
- for _, consumer := range c.consumers {
+ for _, consumer := range consumers {
if err := consumer.unsubscribe(force); err != nil {
errMsg += fmt.Sprintf("topic %s, subscription %s: %s",
consumer.topic, c.Subscription(), err)
}
@@ -526,8 +518,9 @@ func (c *consumer) unsubscribe(force bool) error {
}
func (c *consumer) GetLastMessageIDs() ([]TopicMessageID, error) {
+ consumers := c.partitionConsumers()
ids := make([]TopicMessageID, 0)
- for _, pc := range c.consumers {
+ for _, pc := range consumers {
id, err := pc.getLastMessageID()
tm := &topicMessageID{topic: pc.topic, track: id}
if err != nil {
@@ -556,7 +549,7 @@ func (c *consumer) Receive(ctx context.Context) (message
Message, err error) {
func (c *consumer) AckWithTxn(msg Message, txn Transaction) error {
msgID := msg.ID()
- consumer, err := c.findPartitionConsumer(msgID)
+ consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
@@ -575,7 +568,7 @@ func (c *consumer) Ack(msg Message) error {
// AckID the consumption of a single message, identified by its MessageID
func (c *consumer) AckID(msgID MessageID) error {
- consumer, err := c.findPartitionConsumer(msgID)
+ consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
@@ -587,7 +580,7 @@ func (c *consumer) AckID(msgID MessageID) error {
func (c *consumer) AckIDList(msgIDs []MessageID) error {
return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID)
(acker, error) {
- return c.findPartitionConsumer(msgID)
+ return findPartitionConsumer(c.partitionConsumers(), msgID)
})
}
@@ -600,7 +593,7 @@ func (c *consumer) AckCumulative(msg Message) error {
// AckIDCumulative the reception of all the messages in the stream up to (and
including)
// the provided message, identified by its MessageID
func (c *consumer) AckIDCumulative(msgID MessageID) error {
- consumer, err := c.findPartitionConsumer(msgID)
+ consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
@@ -697,7 +690,7 @@ func (c *consumer) Nack(msg Message) {
mid.NackByMsg(msg)
return
}
- if consumer, err := c.findPartitionConsumer(mid); err == nil {
+ if consumer, err :=
findPartitionConsumer(c.partitionConsumers(), mid); err == nil {
consumer.NackMsg(msg)
}
return
@@ -707,7 +700,7 @@ func (c *consumer) Nack(msg Message) {
}
func (c *consumer) NackID(msgID MessageID) {
- if consumer, err := c.findPartitionConsumer(msgID); err == nil {
+ if consumer, err := findPartitionConsumer(c.partitionConsumers(),
msgID); err == nil {
consumer.NackID(msgID)
}
}
@@ -724,16 +717,14 @@ func (c *consumer) closeWithCause(err error) {
c.closeOnce.Do(func() {
c.stopDiscovery()
- c.Lock()
- defer c.Unlock()
-
var wg sync.WaitGroup
- for i := range c.consumers {
+ consumers := c.partitionConsumers()
+ for i := range consumers {
wg.Add(1)
go func(pc *partitionConsumer) {
defer wg.Done()
pc.Close()
- }(c.consumers[i])
+ }(consumers[i])
}
wg.Wait()
close(c.closeCh)
@@ -741,20 +732,19 @@ func (c *consumer) closeWithCause(err error) {
c.dlq.close()
c.rlq.close()
c.metrics.ConsumersClosed.Inc()
- c.metrics.ConsumersPartitions.Sub(float64(len(c.consumers)))
+ c.metrics.ConsumersPartitions.Sub(float64(len(consumers)))
c.options.Interceptors.OnConsumerClose(c, err)
})
}
func (c *consumer) Seek(msgID MessageID) error {
- c.Lock()
- defer c.Unlock()
+ consumers := c.partitionConsumers()
- if len(c.consumers) > 1 {
+ if len(consumers) > 1 {
return newError(SeekFailed, "for partition topic, seek command
should perform on the individual partitions")
}
- consumer, err := c.unsafeFindPartitionConsumer(msgID)
+ consumer, err := findPartitionConsumer(consumers, msgID)
if err != nil {
return err
}
@@ -768,11 +758,10 @@ func (c *consumer) Seek(msgID MessageID) error {
}
func (c *consumer) SeekByTime(time time.Time) error {
- c.Lock()
- defer c.Unlock()
var errs error
+ consumers := c.partitionConsumers()
- for _, cons := range c.consumers {
+ for _, cons := range consumers {
cons.pauseDispatchMessage()
}
// clear messageCh
@@ -781,7 +770,7 @@ func (c *consumer) SeekByTime(time time.Time) error {
}
// run SeekByTime on every partition of topic
- for _, cons := range c.consumers {
+ for _, cons := range consumers {
if err := cons.SeekByTime(time); err != nil {
msg := fmt.Sprintf("unable to SeekByTime for topic=%s
subscription=%s", c.topic, c.Subscription())
errs = pkgerrors.Wrap(newError(SeekFailed,
err.Error()), msg)
@@ -791,35 +780,30 @@ func (c *consumer) SeekByTime(time time.Time) error {
return errs
}
-func (c *consumer) findPartitionConsumer(msgID MessageID) (*partitionConsumer,
error) {
- c.Lock()
- defer c.Unlock()
- return c.unsafeFindPartitionConsumer(msgID)
-}
-
-// NOTE: This method must be called when c.Lock is held
-func (c *consumer) unsafeFindPartitionConsumer(msgID MessageID)
(*partitionConsumer, error) {
+func findPartitionConsumer(consumers []*partitionConsumer, msgID MessageID)
(*partitionConsumer, error) {
partition := int(msgID.PartitionIdx())
- if partition < 0 || partition >= len(c.consumers) {
- c.log.Errorf("invalid partition index %d expected a partition
between [0-%d]",
- partition, len(c.consumers))
+ if partition < 0 || partition >= len(consumers) {
return nil, fmt.Errorf("invalid partition index %d expected a
partition between [0-%d]",
- partition, len(c.consumers))
+ partition, len(consumers)-1)
+ }
+ return consumers[partition], nil
+}
+
+func (c *consumer) partitionConsumers() []*partitionConsumer {
+ v := c.consumers.Load()
+ if v == nil {
+ return nil
}
- return c.consumers[partition], nil
+ // The slice stored in c.consumers is published via copy-on-write.
+ // Callers must treat the returned slice as immutable.
+ return v.([]*partitionConsumer)
}
func (c *consumer) hasNext() bool {
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // Make sure all paths cancel the context to avoid
context leak
- // We have to make a snapshot consumers, because we have to iterate
over all consumers in
- // other goroutines. But when this method returns, there might be still
other consumers
- // not completing the `hasNext` call, so we cannot just call defer
`c.Unlock()` after acquiring the lock.
- c.Lock()
- consumers := make([]*partitionConsumer, len(c.consumers))
- copy(consumers, c.consumers)
- c.Unlock()
+ consumers := c.partitionConsumers()
var wg sync.WaitGroup
wg.Add(len(consumers))
@@ -853,7 +837,7 @@ func (c *consumer) hasNext() bool {
}
func (c *consumer) setLastDequeuedMsg(msgID MessageID) error {
- consumer, err := c.findPartitionConsumer(msgID)
+ consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
@@ -920,7 +904,7 @@ func toProtoInitialPosition(p SubscriptionInitialPosition)
pb.CommandSubscribe_I
}
func (c *consumer) messageID(msgID MessageID) *trackingMessageID {
- if _, err := c.findPartitionConsumer(msgID); err != nil {
+ if _, err := findPartitionConsumer(c.partitionConsumers(), msgID); err
!= nil {
return nil
}
return toTrackingMessageID(msgID)
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index d7aba524..08e1b36f 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -401,8 +401,8 @@ func (s *schemaInfoCache) add(schemaVersionHash string,
schema Schema) {
}
func newPartitionConsumer(parent Consumer, client *client, options
*partitionConsumerOpts,
- messageCh chan ConsumerMessage, dlq *dlqRouter,
- metrics *internal.LeveledMetrics) (*partitionConsumer, error) {
+ messageCh chan ConsumerMessage, dlq *dlqRouter, metrics
*internal.LeveledMetrics,
+ startDispatcher bool) (*partitionConsumer, error) {
var boFunc func() backoff.Policy
if options.backOffPolicyFunc != nil {
boFunc = options.backOffPolicyFunc
@@ -425,7 +425,7 @@ func newPartitionConsumer(parent Consumer, client *client,
options *partitionCon
queueCh: make(chan []*message,
options.receiverQueueSize),
startMessageID: atomicMessageID{msgID:
options.startMessageID},
seekMessageID: atomicMessageID{msgID: nil},
- connectedCh: make(chan struct{}),
+ connectedCh: make(chan struct{}, 1),
messageCh: messageCh,
connectClosedCh: make(chan *connectionClosed, 1),
closeCh: make(chan struct{}),
@@ -512,13 +512,18 @@ func newPartitionConsumer(parent Consumer, client
*client, options *partitionCon
}
}
- go pc.dispatcher()
-
go pc.runEventsLoop()
+ if startDispatcher {
+ pc.startDispatcher()
+ }
return pc, nil
}
+func (pc *partitionConsumer) startDispatcher() {
+ go pc.dispatcher()
+}
+
func (pc *partitionConsumer) unsubscribe(force bool) error {
if state := pc.getConsumerState(); state == consumerClosed || state ==
consumerClosing {
pc.log.WithField("state", state).Error("Failed to unsubscribe
closing or closed consumer")
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index e1964287..c3fc82d4 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -22,7 +22,9 @@ import (
"errors"
"fmt"
"log"
+ "log/slog"
"net/http"
+ "net/url"
"os"
"regexp"
"strconv"
@@ -4723,7 +4725,7 @@ func TestConsumerWithBackoffPolicy(t *testing.T) {
assert.Nil(t, err)
defer _consumer.Close()
- partitionConsumerImp := _consumer.(*consumer).consumers[0]
+ partitionConsumerImp := _consumer.(*consumer).partitionConsumers()[0]
// 1 s
startTime := time.Now()
partitionConsumerImp.reconnectToBroker(nil)
@@ -4946,7 +4948,7 @@ func TestConsumerWithAutoScaledQueueReceive(t *testing.T)
{
EnableAutoScaledReceiverQueueSize: true,
})
assert.Nil(t, err)
- pc := c.(*consumer).consumers[0]
+ pc := c.(*consumer).partitionConsumers()[0]
assert.Equal(t, int32(1), pc.currentQueueSize.Load())
defer c.Close()
@@ -5161,7 +5163,7 @@ func TestConsumerMemoryLimit(t *testing.T) {
})
assert.Nil(t, err)
defer c1.Close()
- pc1 := c1.(*consumer).consumers[0]
+ pc1 := c1.(*consumer).partitionConsumers()[0]
// Fill up the messageCh of c1
for i := 0; i < 10; i++ {
@@ -5201,7 +5203,7 @@ func TestConsumerMemoryLimit(t *testing.T) {
})
assert.Nil(t, err)
defer c2.Close()
- pc2 := c2.(*consumer).consumers[0]
+ pc2 := c2.(*consumer).partitionConsumers()[0]
// Try to induce c2 receiver queue size expansion
for i := 0; i < 10; i++ {
@@ -5273,7 +5275,7 @@ func TestMultiConsumerMemoryLimit(t *testing.T) {
})
assert.Nil(t, err)
defer c1.Close()
- pc1 := c1.(*consumer).consumers[0]
+ pc1 := c1.(*consumer).partitionConsumers()[0]
// Use mem-limited client 2 to create consumer c1
c2, err := cli2.Subscribe(ConsumerOptions{
@@ -5284,7 +5286,7 @@ func TestMultiConsumerMemoryLimit(t *testing.T) {
})
assert.Nil(t, err)
defer c2.Close()
- pc2 := c2.(*consumer).consumers[0]
+ pc2 := c2.(*consumer).partitionConsumers()[0]
// Fill up the messageCh of c1 nad c2
for i := 0; i < 10; i++ {
@@ -5918,7 +5920,7 @@ func TestSelectConnectionForSameConsumer(t *testing.T) {
assert.NoError(t, err)
defer _consumer.Close()
- partitionConsumerImpl := _consumer.(*consumer).consumers[0]
+ partitionConsumerImpl := _consumer.(*consumer).partitionConsumers()[0]
conn := partitionConsumerImpl._getConn()
for i := 0; i < 5; i++ {
@@ -5928,6 +5930,382 @@ func TestSelectConnectionForSameConsumer(t *testing.T) {
}
}
+func
TestInternalTopicSubscribeToPartitionsDoesNotBlockExistingPartitionLookup(t
*testing.T) {
+ lookupURL, err := url.Parse("pulsar://localhost:6650")
+ require.NoError(t, err)
+
+ allowSubscribe := make(chan struct{})
+ subscribeStarted := make(chan struct{})
+ var releaseSubscribe sync.Once
+
+ logger := slog.New(slog.NewJSONHandler(os.Stdout,
&slog.HandlerOptions{Level: slog.LevelInfo}))
+ log := plog.NewLoggerWithSlog(logger)
+
+ rpcClient := &blockingSubscribeRPCClient{
+ lookupResult: &internal.LookupResult{LogicalAddr:
lookupURL, PhysicalAddr: lookupURL},
+ subscribeStarted: subscribeStarted,
+ allowSubscribe: allowSubscribe,
+ subscribeErr: errors.New("stop subscribe after lookup
check"),
+ nextConsumerID: 1,
+ }
+
+ c :=
newInternalTopicPartitionTestConsumer(internalTopicPartitionTestConsumerOptions{
+ conn: dummyConnection{},
+ rpcClient: rpcClient,
+ partitions: 2,
+ log: log,
+ consumerOptions: ConsumerOptions{SubscriptionName: "test-sub",
NackPrecisionBit: ptr(defaultNackPrecisionBit)},
+ initialConsumers: []*partitionConsumer{{topic:
"persistent://public/default/test-topic-partition-0"}},
+ })
+
+ go func() {
+ c.internalTopicSubscribeToPartitions()
+ }()
+
+ select {
+ case <-subscribeStarted:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for partition discovery to start
subscribing the new partition")
+ }
+
+ lookupErrCh := make(chan error, 1)
+ go func() {
+ _, err := findPartitionConsumer(c.partitionConsumers(),
&messageID{partitionIdx: 0})
+ lookupErrCh <- err
+ }()
+
+ select {
+ case err := <-lookupErrCh:
+ require.NoError(t, err)
+ case <-time.After(3 * time.Second):
+ releaseSubscribe.Do(func() { close(allowSubscribe) })
+ select {
+ case <-lookupErrCh:
+ case <-time.After(time.Second):
+ t.Fatal("existing partition lookup stayed blocked even
after partition discovery stopped")
+ }
+ t.Fatal("existing partition lookup blocked while a new
partition was being added")
+ }
+
+ releaseSubscribe.Do(func() { close(allowSubscribe) })
+}
+
+func
TestInternalTopicSubscribeToPartitionsPublishesConsumersBeforeDispatchingMessages(t
*testing.T) {
+ lookupURL, err := url.Parse("pulsar://localhost:6650")
+ require.NoError(t, err)
+
+ partitionOneSubscribed := make(chan struct{})
+ partitionOneFlowed := make(chan struct{})
+ partitionTwoBlocked := make(chan struct{})
+ allowPartitionTwo := make(chan struct{})
+ cnx := newPartitionExpansionRaceConnection()
+ rpcClient := &partitionExpansionRaceRPCClient{
+ lookupResult: &internal.LookupResult{LogicalAddr:
lookupURL, PhysicalAddr: lookupURL},
+ cnx: cnx,
+ partitionOneSubscribed: partitionOneSubscribed,
+ partitionOneFlowed: partitionOneFlowed,
+ partitionTwoBlocked: partitionTwoBlocked,
+ allowPartitionTwo: allowPartitionTwo,
+ }
+
+ c :=
newInternalTopicPartitionTestConsumer(internalTopicPartitionTestConsumerOptions{
+ conn: cnx,
+ rpcClient: rpcClient,
+ partitions: 3,
+ log: plog.DefaultNopLogger(),
+ consumerOptions: ConsumerOptions{
+ SubscriptionName: "test-sub",
+ ReceiverQueueSize: 1,
+ NackPrecisionBit: ptr(defaultNackPrecisionBit),
+ AckWithResponse: true,
+ },
+ initialConsumers: []*partitionConsumer{{topic:
"persistent://public/default/test-topic-partition-0"}},
+ dlq: &dlqRouter{},
+ })
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- c.internalTopicSubscribeToPartitions()
+ }()
+
+ select {
+ case <-partitionOneSubscribed:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for partition 1 to subscribe")
+ }
+
+ select {
+ case <-partitionTwoBlocked:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for partition 2 subscribe to block")
+ }
+
+ require.Len(t, c.partitionConsumers(), 1)
+ select {
+ case <-partitionOneFlowed:
+ t.Fatal("new partition dispatcher requested permits before
c.consumers contained the new partition")
+ case <-time.After(200 * time.Millisecond):
+ }
+
+ close(allowPartitionTwo)
+
+ select {
+ case err := <-errCh:
+ require.NoError(t, err)
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for partition discovery to finish")
+ }
+ require.Len(t, c.partitionConsumers(), 3)
+
+ select {
+ case <-partitionOneFlowed:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for partition 1 dispatcher to
request permits")
+ }
+
+ handler := cnx.handler(rpcClient.partitionOneConsumerID.Load())
+ require.NotNil(t, handler)
+ err = handler.MessageReceived(&pb.CommandMessage{
+ MessageId: &pb.MessageIdData{
+ LedgerId: proto.Uint64(1),
+ EntryId: proto.Uint64(1),
+ },
+ }, internal.NewBufferWrapper(rawCompatSingleMessage))
+ require.NoError(t, err)
+
+ var cm ConsumerMessage
+ select {
+ case cm = <-c.messageCh:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for the queued partition 1 message
to dispatch")
+ }
+ require.Equal(t, int32(1), cm.Message.ID().PartitionIdx())
+ require.NoError(t, c.AckID(cm.Message.ID()))
+
+ for _, pc := range c.partitionConsumers()[1:] {
+ pc.Close()
+ }
+}
+
+type internalTopicPartitionTestConsumerOptions struct {
+ conn internal.Connection
+ rpcClient internal.RPCClient
+ partitions int
+ log plog.Logger
+ consumerOptions ConsumerOptions
+ initialConsumers []*partitionConsumer
+ dlq *dlqRouter
+}
+
+func newInternalTopicPartitionTestConsumer(opts
internalTopicPartitionTestConsumerOptions) *consumer {
+ var consumers atomic.Value
+ consumers.Store(append([]*partitionConsumer(nil),
opts.initialConsumers...))
+
+ return &consumer{
+ topic: "persistent://public/default/test-topic",
+ client: &client{
+ cnxPool: &blockingConnPool{cnx: opts.conn},
+ rpcClient: opts.rpcClient,
+ lookupService: &partitionMetadataLookup{partitions:
opts.partitions},
+ log: opts.log,
+ },
+ options: opts.consumerOptions,
+ consumers: consumers,
+ messageCh: make(chan ConsumerMessage, 1),
+ closeCh: make(chan struct{}),
+ errorCh: make(chan error, 1),
+ consumerName: "test-consumer",
+ dlq: opts.dlq,
+ log: opts.log,
+ metrics: newTestMetrics(),
+ }
+}
+
+type partitionMetadataLookup struct {
+ internal.LookupService
+ partitions int
+}
+
+func (l *partitionMetadataLookup) GetPartitionedTopicMetadata(_ string)
(*internal.PartitionedTopicMetadata, error) {
+ return &internal.PartitionedTopicMetadata{Partitions: l.partitions}, nil
+}
+
+type blockingConnPool struct {
+ internal.ConnectionPool
+ cnx internal.Connection
+}
+
+func (p *blockingConnPool) GetConnection(_ *url.URL, _ *url.URL, _ int32)
(internal.Connection, error) {
+ return p.cnx, nil
+}
+
+func (p *blockingConnPool) GetConnections() map[string]internal.Connection {
+ return map[string]internal.Connection{}
+}
+
+func (p *blockingConnPool) GenerateRoundRobinIndex() int32 {
+ return 0
+}
+
+func (p *blockingConnPool) Close() {}
+
+type partitionExpansionRaceConnection struct {
+ dummyConnection
+ mu sync.Mutex
+ handlers map[uint64]internal.ConsumerHandler
+}
+
+func newPartitionExpansionRaceConnection() *partitionExpansionRaceConnection {
+ return &partitionExpansionRaceConnection{handlers:
make(map[uint64]internal.ConsumerHandler)}
+}
+
+func (c *partitionExpansionRaceConnection) AddConsumeHandler(id uint64,
handler internal.ConsumerHandler) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.handlers[id] = handler
+ return nil
+}
+
+func (c *partitionExpansionRaceConnection) DeleteConsumeHandler(id uint64) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ delete(c.handlers, id)
+}
+
+func (c *partitionExpansionRaceConnection) handler(id uint64)
internal.ConsumerHandler {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.handlers[id]
+}
+
+type partitionExpansionRaceRPCClient struct {
+ internal.RPCClient
+ lookupResult *internal.LookupResult
+ cnx *partitionExpansionRaceConnection
+ partitionOneSubscribed chan struct{}
+ partitionOneFlowed chan struct{}
+ partitionTwoBlocked chan struct{}
+ allowPartitionTwo chan struct{}
+ requestID atomic.Uint64
+ consumerID atomic.Uint64
+ partitionOneConsumerID atomic.Uint64
+ partitionOneOnce sync.Once
+ partitionOneFlowOnce sync.Once
+ partitionTwoOnce sync.Once
+}
+
+func (r *partitionExpansionRaceRPCClient) NewRequestID() uint64 {
+ return r.requestID.Add(1)
+}
+
+func (r *partitionExpansionRaceRPCClient) NewProducerID() uint64 {
+ return r.requestID.Add(1)
+}
+
+func (r *partitionExpansionRaceRPCClient) NewConsumerID() uint64 {
+ return r.consumerID.Add(1)
+}
+
+func (r *partitionExpansionRaceRPCClient) RequestOnCnxNoWait(
+ _ internal.Connection, cmdType pb.BaseCommand_Type, msg proto.Message,
+) error {
+ if cmdType == pb.BaseCommand_FLOW {
+ flow := msg.(*pb.CommandFlow)
+ if flow.GetConsumerId() == r.partitionOneConsumerID.Load() {
+ r.partitionOneFlowOnce.Do(func() {
close(r.partitionOneFlowed) })
+ }
+ }
+ return nil
+}
+
+func (r *partitionExpansionRaceRPCClient) RequestOnCnx(
+ _ internal.Connection, _ uint64, cmdType pb.BaseCommand_Type, msg
proto.Message,
+) (*internal.RPCResult, error) {
+ switch cmdType {
+ case pb.BaseCommand_SUBSCRIBE:
+ return r.handleSubscribe(msg.(*pb.CommandSubscribe))
+ case pb.BaseCommand_ACK, pb.BaseCommand_CLOSE_CONSUMER:
+ return r.success(), nil
+ default:
+ return nil, fmt.Errorf("unexpected command type %v", cmdType)
+ }
+}
+
+func (r *partitionExpansionRaceRPCClient) handleSubscribe(cmd
*pb.CommandSubscribe) (*internal.RPCResult, error) {
+ switch {
+ case strings.HasSuffix(cmd.GetTopic(), "-partition-1"):
+ r.partitionOneConsumerID.Store(cmd.GetConsumerId())
+ r.partitionOneOnce.Do(func() { close(r.partitionOneSubscribed)
})
+ return r.success(), nil
+ case strings.HasSuffix(cmd.GetTopic(), "-partition-2"):
+ r.partitionTwoOnce.Do(func() { close(r.partitionTwoBlocked) })
+ <-r.allowPartitionTwo
+ return r.success(), nil
+ default:
+ return nil, fmt.Errorf("unexpected subscribe topic %s",
cmd.GetTopic())
+ }
+}
+
+func (r *partitionExpansionRaceRPCClient) success() *internal.RPCResult {
+ successType := pb.BaseCommand_SUCCESS
+ return &internal.RPCResult{
+ Response: &pb.BaseCommand{Type: &successType},
+ Cnx: r.cnx,
+ }
+}
+
+func (r *partitionExpansionRaceRPCClient) LookupService(_ string)
(internal.LookupService, error) {
+ return &grabConnMockLookup{result: r.lookupResult}, nil
+}
+
+type blockingSubscribeRPCClient struct {
+ internal.RPCClient
+ lookupResult *internal.LookupResult
+ subscribeStarted chan struct{}
+ allowSubscribe chan struct{}
+ subscribeErr error
+ nextConsumerID uint64
+ startOnce sync.Once
+}
+
+func (r *blockingSubscribeRPCClient) NewRequestID() uint64 {
+ return 1
+}
+
+func (r *blockingSubscribeRPCClient) NewProducerID() uint64 {
+ return 1
+}
+
+func (r *blockingSubscribeRPCClient) NewConsumerID() uint64 {
+ id := r.nextConsumerID
+ r.nextConsumerID++
+ return id
+}
+
+func (r *blockingSubscribeRPCClient) RequestOnCnxNoWait(
+ _ internal.Connection, _ pb.BaseCommand_Type, _ proto.Message) error {
+ return nil
+}
+
+func (r *blockingSubscribeRPCClient) RequestOnCnx(
+ _ internal.Connection, _ uint64, cmdType pb.BaseCommand_Type, _
proto.Message,
+) (*internal.RPCResult, error) {
+ switch cmdType {
+ case pb.BaseCommand_SUBSCRIBE:
+ r.startOnce.Do(func() { close(r.subscribeStarted) })
+ <-r.allowSubscribe
+ return nil, r.subscribeErr
+ case pb.BaseCommand_CLOSE_CONSUMER:
+ return nil, nil
+ default:
+ return nil, fmt.Errorf("unexpected command type %v", cmdType)
+ }
+}
+
+func (r *blockingSubscribeRPCClient) LookupService(_ string)
(internal.LookupService, error) {
+ return &grabConnMockLookup{result: r.lookupResult}, nil
+}
+
// closeInterceptor captures the (consumer, err) pair delivered to
// ConsumerCloseInterceptor.OnConsumerClose and signals via fired.
type closeInterceptor struct {
@@ -6006,7 +6384,7 @@ func TestConsumerOnCloseInterceptorOnMaxReconnect(t
*testing.T) {
assert.NotNil(t, interceptor.err, "interceptor should receive the cause
of the close")
assert.Equal(t, testConsumer, interceptor.consumer, "interceptor should
receive the parent consumer")
- pc := testConsumer.(*consumer).consumers[0]
+ pc := testConsumer.(*consumer).partitionConsumers()[0]
require.Eventually(t, func() bool {
return pc.getConsumerState() == consumerClosed
}, 30*time.Second, 100*time.Millisecond, "consumer should be closed
after exhausting max reconnect retries")
diff --git a/pulsar/consumer_zero_queue.go b/pulsar/consumer_zero_queue.go
index 4978fae2..20a0944e 100644
--- a/pulsar/consumer_zero_queue.go
+++ b/pulsar/consumer_zero_queue.go
@@ -80,7 +80,7 @@ func newZeroConsumer(client *client, options ConsumerOptions,
topic string,
pc.availablePermits.inc()
}
}
- pc, err := newPartitionConsumer(zc, zc.client, opts, zc.messageCh,
zc.dlq, zc.metrics)
+ pc, err := newPartitionConsumer(zc, zc.client, opts, zc.messageCh,
zc.dlq, zc.metrics, true)
if err != nil {
return nil, err
}
diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go
index 12f0517c..2d6462e7 100644
--- a/pulsar/message_chunking_test.go
+++ b/pulsar/message_chunking_test.go
@@ -178,7 +178,7 @@ func TestMaxPendingChunkMessages(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, c)
defer c.Close()
- pc := c.(*consumer).consumers[0]
+ pc := c.(*consumer).partitionConsumers()[0]
sendSingleChunk(producer, "0", 0, 2)
// MaxPendingChunkedMessage is 1, the chunked message with uuid 0 will
be discarded
@@ -228,7 +228,7 @@ func TestExpireIncompleteChunks(t *testing.T) {
defer c.Close()
uuid := "test-uuid"
- chunkCtxMap := c.(*consumer).consumers[0].chunkedMsgCtxMap
+ chunkCtxMap := c.(*consumer).partitionConsumers()[0].chunkedMsgCtxMap
chunkCtxMap.addIfAbsent(uuid, 2, 100)
ctx := chunkCtxMap.get(uuid)
assert.NotNil(t, ctx)
diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go
index 55b05037..ad65c285 100644
--- a/pulsar/reader_impl.go
+++ b/pulsar/reader_impl.go
@@ -224,8 +224,9 @@ func (r *reader) SeekByTime(time time.Time) error {
}
func (r *reader) GetLastMessageID() (MessageID, error) {
- if len(r.c.consumers) > 1 {
+ consumers := r.c.partitionConsumers()
+ if len(consumers) > 1 {
return nil, fmt.Errorf("GetLastMessageID is not supported for
multi-topics reader")
}
- return r.c.consumers[0].getLastMessageID()
+ return consumers[0].getLastMessageID()
}
diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go
index 29d61b2b..d8ebc240 100644
--- a/pulsar/reader_test.go
+++ b/pulsar/reader_test.go
@@ -952,7 +952,7 @@ func TestReaderWithBackoffPolicy(t *testing.T) {
assert.NotNil(t, _reader)
assert.Nil(t, err)
- partitionConsumerImp := _reader.(*reader).c.consumers[0]
+ partitionConsumerImp := _reader.(*reader).c.partitionConsumers()[0]
// 1 s
startTime := time.Now()
partitionConsumerImp.reconnectToBroker(nil)
@@ -1061,7 +1061,7 @@ func TestReaderHasNextFailed(t *testing.T) {
StartMessageID: EarliestMessageID(),
})
assert.Nil(t, err)
- r.(*reader).c.consumers[0].state.Store(consumerClosing)
+ r.(*reader).c.partitionConsumers()[0].state.Store(consumerClosing)
assert.False(t, r.HasNext())
}
@@ -1082,7 +1082,7 @@ func TestReaderHasNextRetryFailed(t *testing.T) {
defer close(c)
// Close the consumer events loop and assign a mock eventsCh
- pc := r.(*reader).c.consumers[0]
+ pc := r.(*reader).c.partitionConsumers()[0]
pc.Close()
pc.state.Store(consumerReady)
pc.eventsCh = c