This is an automated email from the ASF dual-hosted git repository.

mattisonchao pushed a commit to branch codex/fix-consumer-ack-partition-lookup
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git

commit 387d0f63f93e98b8cf76c549a50a05bf2072c1f4
Author: mattisonchao <[email protected]>
AuthorDate: Mon May 11 22:55:35 2026 +0800

    fix: guard partition consumer lookup during ack
---
 pulsar/consumer_impl.go       | 42 +++++++++++++++++++++++++++++-------------
 pulsar/consumer_multitopic.go | 28 ++++++++++++++++++++++------
 pulsar/consumer_test.go       | 18 ++++++++++++++++++
 3 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index ece37370..b33cf545 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -49,7 +49,7 @@ type acker interface {
 }
 
 type consumer struct {
-       sync.Mutex
+       sync.RWMutex
        topic                     string
        client                    *client
        options                   ConsumerOptions
@@ -549,11 +549,12 @@ func (c *consumer) Receive(ctx context.Context) (message Message, err error) {
 
 func (c *consumer) AckWithTxn(msg Message, txn Transaction) error {
        msgID := msg.ID()
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       pc, err := c.getPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
 
-       return c.consumers[msgID.PartitionIdx()].AckIDWithTxn(msgID, txn)
+       return pc.AckIDWithTxn(msgID, txn)
 }
 
 // Chan return the message chan to users
@@ -568,23 +569,21 @@ func (c *consumer) Ack(msg Message) error {
 
 // AckID the consumption of a single message, identified by its MessageID
 func (c *consumer) AckID(msgID MessageID) error {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       pc, err := c.getPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
 
        if c.options.AckWithResponse {
-               return c.consumers[msgID.PartitionIdx()].AckIDWithResponse(msgID)
+               return pc.AckIDWithResponse(msgID)
        }
 
-       return c.consumers[msgID.PartitionIdx()].AckID(msgID)
+       return pc.AckID(msgID)
 }
 
 func (c *consumer) AckIDList(msgIDs []MessageID) error {
        return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
-               if err := c.checkMsgIDPartition(msgID); err != nil {
-                       return nil, err
-               }
-               return c.consumers[msgID.PartitionIdx()], nil
+               return c.getPartitionConsumer(msgID)
        })
 }
 
@@ -597,15 +596,16 @@ func (c *consumer) AckCumulative(msg Message) error {
 // AckIDCumulative the reception of all the messages in the stream up to (and including)
 // the provided message, identified by its MessageID
 func (c *consumer) AckIDCumulative(msgID MessageID) error {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       pc, err := c.getPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
 
        if c.options.AckWithResponse {
-               return c.consumers[msgID.PartitionIdx()].AckIDWithResponseCumulative(msgID)
+               return pc.AckIDWithResponseCumulative(msgID)
        }
 
-       return c.consumers[msgID.PartitionIdx()].AckIDCumulative(msgID)
+       return pc.AckIDCumulative(msgID)
 }
 
 // ReconsumeLater mark a message for redelivery after custom delay
@@ -792,6 +792,22 @@ func (c *consumer) checkMsgIDPartition(msgID MessageID) error {
        return nil
 }
 
+func (c *consumer) getPartitionConsumer(msgID MessageID) (*partitionConsumer, error) {
+       c.RLock()
+       defer c.RUnlock()
+
+       if err := c.checkMsgIDPartition(msgID); err != nil {
+               return nil, err
+       }
+
+       pc := c.consumers[msgID.PartitionIdx()]
+       if pc == nil {
+               return nil, fmt.Errorf("partition consumer is nil for partition %d", msgID.PartitionIdx())
+       }
+
+       return pc, nil
+}
+
 func (c *consumer) hasNext() bool {
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel() // Make sure all paths cancel the context to avoid context leak
diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go
index 3e4c8ebd..59030216 100644
--- a/pulsar/consumer_multitopic.go
+++ b/pulsar/consumer_multitopic.go
@@ -214,26 +214,42 @@ func (c *multiTopicConsumer) AckIDList(msgIDs []MessageID) error {
 
 func ackIDListFromMultiTopics(log log.Logger, msgIDs []MessageID, findConsumer func(MessageID) (acker, error)) error {
        consumerToMsgIDs := make(map[acker][]MessageID)
+       ackError := AckError{}
        for _, msgID := range msgIDs {
                if consumer, err := findConsumer(msgID); err == nil {
                        consumerToMsgIDs[consumer] = append(consumerToMsgIDs[consumer], msgID)
                } else {
                        log.Warnf("Can not find consumer for %v", msgID)
+                       ackError[msgID] = err
                }
        }
 
-       subErrCh := make(chan error, len(consumerToMsgIDs))
+       type ackResult struct {
+               ids []MessageID
+               err error
+       }
+       resultCh := make(chan ackResult, len(consumerToMsgIDs))
        for consumer, ids := range consumerToMsgIDs {
-               go func() {
-                       subErrCh <- consumer.AckIDList(ids)
-               }()
+               go func(consumer acker, ids []MessageID) {
+                       resultCh <- ackResult{
+                               ids: ids,
+                               err: consumer.AckIDList(ids),
+                       }
+               }(consumer, ids)
        }
-       ackError := AckError{}
        for i := 0; i < len(consumerToMsgIDs); i++ {
-               if topicAckError, ok := (<-subErrCh).(AckError); ok {
+               result := <-resultCh
+               if result.err == nil {
+                       continue
+               }
+               if topicAckError, ok := result.err.(AckError); ok {
                        for id, err := range topicAckError {
                                ackError[id] = err
                        }
+                       continue
+               }
+               for _, id := range result.ids {
+                       ackError[id] = result.err
                }
        }
        if len(ackError) == 0 {
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index 4e232b8a..a8bb4576 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -5583,6 +5583,24 @@ func TestAckIDList(t *testing.T) {
        }
 }
 
+func TestAckIDListReturnsErrorForNilPartitionConsumer(t *testing.T) {
+       msgID := newMessageID(1, 2, -1, 0, 0)
+       c := &consumer{
+               consumers: []*partitionConsumer{nil},
+               log:       plog.DefaultNopLogger(),
+       }
+
+       require.NotPanics(t, func() {
+               err := c.AckIDList([]MessageID{msgID})
+               require.Error(t, err)
+
+               ackError, ok := err.(AckError)
+               require.True(t, ok)
+               require.Contains(t, ackError, msgID)
+               require.ErrorContains(t, ackError[msgID], "partition consumer is nil for partition 0")
+       })
+}
+
 func getAckCount(registry *prometheus.Registry) (int, error) {
        metrics, err := registry.Gather()
        if err != nil {

Reply via email to