This is an automated email from the ASF dual-hosted git repository. mattisonchao pushed a commit to branch codex/fix-consumer-ack-partition-lookup in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
commit 387d0f63f93e98b8cf76c549a50a05bf2072c1f4 Author: mattisonchao <[email protected]> AuthorDate: Mon May 11 22:55:35 2026 +0800 fix: guard partition consumer lookup during ack --- pulsar/consumer_impl.go | 42 +++++++++++++++++++++++++++++------------- pulsar/consumer_multitopic.go | 28 ++++++++++++++++++++++------ pulsar/consumer_test.go | 18 ++++++++++++++++++ 3 files changed, 69 insertions(+), 19 deletions(-) diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go index ece37370..b33cf545 100644 --- a/pulsar/consumer_impl.go +++ b/pulsar/consumer_impl.go @@ -49,7 +49,7 @@ type acker interface { } type consumer struct { - sync.Mutex + sync.RWMutex topic string client *client options ConsumerOptions @@ -549,11 +549,12 @@ func (c *consumer) Receive(ctx context.Context) (message Message, err error) { func (c *consumer) AckWithTxn(msg Message, txn Transaction) error { msgID := msg.ID() - if err := c.checkMsgIDPartition(msgID); err != nil { + consumer, err := c.getPartitionConsumer(msgID) + if err != nil { return err } - return c.consumers[msgID.PartitionIdx()].AckIDWithTxn(msgID, txn) + return consumer.AckIDWithTxn(msgID, txn) } // Chan return the message chan to users @@ -568,23 +569,21 @@ func (c *consumer) Ack(msg Message) error { // AckID the consumption of a single message, identified by its MessageID func (c *consumer) AckID(msgID MessageID) error { - if err := c.checkMsgIDPartition(msgID); err != nil { + consumer, err := c.getPartitionConsumer(msgID) + if err != nil { return err } if c.options.AckWithResponse { - return c.consumers[msgID.PartitionIdx()].AckIDWithResponse(msgID) + return consumer.AckIDWithResponse(msgID) } - return c.consumers[msgID.PartitionIdx()].AckID(msgID) + return consumer.AckID(msgID) } func (c *consumer) AckIDList(msgIDs []MessageID) error { return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) { - if err := c.checkMsgIDPartition(msgID); err != nil { - return nil, err - } - return c.consumers[msgID.PartitionIdx()], nil + return c.getPartitionConsumer(msgID) }) } @@ -597,15 +596,16 @@ func (c *consumer) AckCumulative(msg Message) error { // AckIDCumulative the reception of all the messages in the stream up to (and including) // the provided message, identified by its MessageID func (c *consumer) AckIDCumulative(msgID MessageID) error { - if err := c.checkMsgIDPartition(msgID); err != nil { + consumer, err := c.getPartitionConsumer(msgID) + if err != nil { return err } if c.options.AckWithResponse { - return c.consumers[msgID.PartitionIdx()].AckIDWithResponseCumulative(msgID) + return consumer.AckIDWithResponseCumulative(msgID) } - return c.consumers[msgID.PartitionIdx()].AckIDCumulative(msgID) + return consumer.AckIDCumulative(msgID) } // ReconsumeLater mark a message for redelivery after custom delay @@ -792,6 +792,22 @@ func (c *consumer) checkMsgIDPartition(msgID MessageID) error { return nil } +func (c *consumer) getPartitionConsumer(msgID MessageID) (*partitionConsumer, error) { + c.RLock() + defer c.RUnlock() + + if err := c.checkMsgIDPartition(msgID); err != nil { + return nil, err + } + + consumer := c.consumers[msgID.PartitionIdx()] + if consumer == nil { + return nil, fmt.Errorf("partition consumer is nil for partition %d", msgID.PartitionIdx()) + } + + return consumer, nil +} + func (c *consumer) hasNext() bool { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Make sure all paths cancel the context to avoid context leak diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go index 3e4c8ebd..59030216 100644 --- a/pulsar/consumer_multitopic.go +++ b/pulsar/consumer_multitopic.go @@ -214,26 +214,42 @@ func (c *multiTopicConsumer) AckIDList(msgIDs []MessageID) error { func ackIDListFromMultiTopics(log log.Logger, msgIDs []MessageID, findConsumer func(MessageID) (acker, error)) error { consumerToMsgIDs := make(map[acker][]MessageID) + ackError := AckError{} for _, msgID := range msgIDs { if consumer, err := findConsumer(msgID); err == nil { consumerToMsgIDs[consumer] = append(consumerToMsgIDs[consumer], msgID) } else { log.Warnf("Can not find consumer for %v", msgID) + ackError[msgID] = err } } - subErrCh := make(chan error, len(consumerToMsgIDs)) + type ackResult struct { + ids []MessageID + err error + } + subErrCh := make(chan ackResult, len(consumerToMsgIDs)) for consumer, ids := range consumerToMsgIDs { - go func() { - subErrCh <- consumer.AckIDList(ids) - }() + go func(consumer acker, ids []MessageID) { + subErrCh <- ackResult{ + ids: ids, + err: consumer.AckIDList(ids), + } + }(consumer, ids) } - ackError := AckError{} for i := 0; i < len(consumerToMsgIDs); i++ { - if topicAckError, ok := (<-subErrCh).(AckError); ok { + result := <-subErrCh + if result.err == nil { + continue + } + if topicAckError, ok := result.err.(AckError); ok { for id, err := range topicAckError { ackError[id] = err } + continue + } + for _, id := range result.ids { + ackError[id] = result.err } } if len(ackError) == 0 { diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go index 4e232b8a..a8bb4576 100644 --- a/pulsar/consumer_test.go +++ b/pulsar/consumer_test.go @@ -5583,6 +5583,24 @@ func TestAckIDList(t *testing.T) { } } +func TestAckIDListReturnsErrorForNilPartitionConsumer(t *testing.T) { + msgID := newMessageID(1, 2, -1, 0, 0) + consumer := &consumer{ + consumers: []*partitionConsumer{nil}, + log: plog.DefaultNopLogger(), + } + + require.NotPanics(t, func() { + err := consumer.AckIDList([]MessageID{msgID}) + require.Error(t, err) + + ackError, ok := err.(AckError) + require.True(t, ok) + require.Contains(t, ackError, msgID) + require.ErrorContains(t, ackError[msgID], "partition consumer is nil for partition 0") + }) +} + func getAckCount(registry *prometheus.Registry) (int, error) { metrics, err := registry.Gather() if err != nil {
