This is an automated email from the ASF dual-hosted git repository.
zixuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push:
new ba83732a Fix wrong result of reader.hasNext/Next after seeking by id or time (#1340)
ba83732a is described below
commit ba83732a596a78ddda42a02cadd74a0bb77123de
Author: Baodi Shi <[email protected]>
AuthorDate: Mon Mar 10 14:46:21 2025 +0800
Fix wrong result of reader.hasNext/Next after seeking by id or time (#1340)
* Fix wrong result of hasNext after seeking by id or time
* fix unit test
* Address code reviews.
* Address code review
* Add a doc comment note to StartMessageIDInclusive
---
pulsar/consumer.go | 1 +
pulsar/consumer_partition.go | 98 +++++++++++++++++-----
pulsar/consumer_test.go | 5 +-
pulsar/consumer_zero_queue_test.go | 1 +
pulsar/impl_message.go | 8 ++
pulsar/reader.go | 1 +
pulsar/reader_impl.go | 22 ++---
pulsar/reader_test.go | 167 +++++++++++++++++++++++++++++++++++++
8 files changed, 262 insertions(+), 41 deletions(-)
diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index 7aee9645..d611c691 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -261,6 +261,7 @@ type ConsumerOptions struct {
SubscriptionMode SubscriptionMode
// StartMessageIDInclusive, if true, the consumer will start at the
`StartMessageID`, included.
+ // Note: This configuration also affects the seek operation.
// Default is `false` and the consumer will start from the "next"
message
StartMessageIDInclusive bool
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index aa922af2..cc2554c0 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -190,9 +190,13 @@ type partitionConsumer struct {
backoffPolicyFunc func() backoff.Policy
dispatcherSeekingControlCh chan struct{}
- isSeeking atomic.Bool
- ctx context.Context
- cancelFunc context.CancelFunc
+ // handle to the dispatcher goroutine
+ isSeeking atomic.Bool
+ // After executing seekByTime, the client is unaware of the
startMessageId.
+ // Use this flag to compare markDeletePosition with BrokerLastMessageId
when checking hasMoreMessages.
+ hasSoughtByTime atomic.Bool
+ ctx context.Context
+ cancelFunc context.CancelFunc
}
// pauseDispatchMessage used to discard the message in the dispatcher
goroutine.
@@ -429,11 +433,12 @@ func newPartitionConsumer(parent Consumer, client
*client, options *partitionCon
startingMessageID := pc.startMessageID.get()
if pc.options.startMessageIDInclusive && startingMessageID != nil &&
startingMessageID.equal(latestMessageID) {
- msgID, err := pc.requestGetLastMessageID()
+ msgIDResp, err := pc.requestGetLastMessageID()
if err != nil {
pc.Close()
return nil, err
}
+ msgID := convertToMessageID(msgIDResp.GetLastMessageId())
if msgID.entryID != noMessageEntry {
pc.startMessageID.set(msgID)
@@ -616,18 +621,27 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub
*unsubscribeRequest) {
}
func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) {
+ res, err := pc.getLastMessageIDAndMarkDeletePosition()
+ if err != nil {
+ return nil, err
+ }
+ return res.msgID, err
+}
+
+func (pc *partitionConsumer) getLastMessageIDAndMarkDeletePosition()
(*getLastMsgIDResult, error) {
if state := pc.getConsumerState(); state == consumerClosed || state ==
consumerClosing {
pc.log.WithField("state", state).Error("Failed to
getLastMessageID for the closing or closed consumer")
return nil, errors.New("failed to getLastMessageID for the
closing or closed consumer")
}
bo := pc.backoffPolicyFunc()
- request := func() (*trackingMessageID, error) {
+ request := func() (*getLastMsgIDResult, error) {
req := &getLastMsgIDRequest{doneCh: make(chan struct{})}
pc.eventsCh <- req
// wait for the request to complete
<-req.doneCh
- return req.msgID, req.err
+ res := &getLastMsgIDResult{req.msgID, req.markDeletePosition}
+ return res, req.err
}
ctx, cancel := context.WithTimeout(context.Background(),
pc.client.operationTimeout)
@@ -647,10 +661,16 @@ func (pc *partitionConsumer) getLastMessageID()
(*trackingMessageID, error) {
func (pc *partitionConsumer) internalGetLastMessageID(req
*getLastMsgIDRequest) {
defer close(req.doneCh)
- req.msgID, req.err = pc.requestGetLastMessageID()
+ rsp, err := pc.requestGetLastMessageID()
+ if err != nil {
+ req.err = err
+ return
+ }
+ req.msgID = convertToMessageID(rsp.GetLastMessageId())
+ req.markDeletePosition =
convertToMessageID(rsp.GetConsumerMarkDeletePosition())
}
-func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID,
error) {
+func (pc *partitionConsumer) requestGetLastMessageID()
(*pb.CommandGetLastMessageIdResponse, error) {
if state := pc.getConsumerState(); state == consumerClosed || state ==
consumerClosing {
pc.log.WithField("state", state).Error("Failed to
getLastMessageID closing or closed consumer")
return nil, errors.New("failed to getLastMessageID closing or
closed consumer")
@@ -667,8 +687,7 @@ func (pc *partitionConsumer) requestGetLastMessageID()
(*trackingMessageID, erro
pc.log.WithError(err).Error("Failed to get last message id")
return nil, err
}
- id := res.Response.GetLastMessageIdResponse.GetLastMessageId()
- return convertToMessageID(id), nil
+ return res.Response.GetLastMessageIdResponse, nil
}
func (pc *partitionConsumer) sendIndividualAck(msgID MessageID) *ackRequest {
@@ -997,7 +1016,15 @@ func (pc *partitionConsumer) requestSeek(msgID
*messageID) error {
if err := pc.requestSeekWithoutClear(msgID); err != nil {
return err
}
- pc.clearReceiverQueue()
+ // When the seek operation is successful, it indicates:
+ // 1. The broker has reset the cursor and sent a request to close the
consumer on the client side.
+ // Since this method is in the same goroutine as the
reconnectToBroker,
+ // we can safely clear the messages in the queue (at this point, it
won't contain messages after the seek).
+ // 2. The startMessageID is reset to ensure accurate judgment when
calling hasNext next time.
+ // Since the messages in the queue are cleared here reconnection
won't reset startMessageId.
+ pc.lastDequeuedMsg = nil
+ pc.startMessageID.set(toTrackingMessageID(msgID))
+ pc.clearQueueAndGetNextMessage()
return nil
}
@@ -1069,7 +1096,9 @@ func (pc *partitionConsumer) internalSeekByTime(seek
*seekByTimeRequest) {
seek.err = err
return
}
- pc.clearReceiverQueue()
+ pc.lastDequeuedMsg = nil
+ pc.hasSoughtByTime.Store(true)
+ pc.clearQueueAndGetNextMessage()
}
func (pc *partitionConsumer) internalAck(req *ackRequest) {
@@ -1451,10 +1480,6 @@ func (pc *partitionConsumer)
messageShouldBeDiscarded(msgID *trackingMessageID)
if pc.startMessageID.get() == nil {
return false
}
- // if we start at latest message, we should never discard
- if pc.options.startMessageID != nil &&
pc.options.startMessageID.equal(latestMessageID) {
- return false
- }
if pc.options.startMessageIDInclusive {
return pc.startMessageID.get().greater(msgID.messageID)
@@ -1709,9 +1734,15 @@ type redeliveryRequest struct {
}
type getLastMsgIDRequest struct {
- doneCh chan struct{}
- msgID *trackingMessageID
- err error
+ doneCh chan struct{}
+ msgID *trackingMessageID
+ markDeletePosition *trackingMessageID
+ err error
+}
+
+type getLastMsgIDResult struct {
+ msgID *trackingMessageID
+ markDeletePosition *trackingMessageID
}
type seekRequest struct {
@@ -2200,6 +2231,25 @@ func (pc *partitionConsumer)
discardCorruptedMessage(msgID *pb.MessageIdData,
}
func (pc *partitionConsumer) hasNext() bool {
+
+ // If a seek by time has been performed, then the `startMessageId`
becomes irrelevant.
+ // We need to compare `markDeletePosition` and `lastMessageId`,
+ // and then reset `startMessageID` to `markDeletePosition`.
+ if pc.lastDequeuedMsg == nil && pc.hasSoughtByTime.CompareAndSwap(true,
false) {
+ res, err := pc.getLastMessageIDAndMarkDeletePosition()
+ if err != nil {
+ pc.log.WithError(err).Error("Failed to get last message
id")
+ pc.hasSoughtByTime.CompareAndSwap(false, true)
+ return false
+ }
+ pc.lastMessageInBroker = res.msgID
+ pc.startMessageID.set(res.markDeletePosition)
+ // We only care about comparing ledger ids and entry ids as
mark delete position
+ // doesn't have other ids such as batch index
+ compareResult :=
pc.lastMessageInBroker.messageID.compareLedgerAndEntryID(pc.startMessageID.get().messageID)
+ return compareResult > 0 || (pc.options.startMessageIDInclusive
&& compareResult == 0)
+ }
+
if pc.lastMessageInBroker != nil && pc.hasMoreMessages() {
return true
}
@@ -2261,12 +2311,14 @@ func convertToMessageID(id *pb.MessageIdData)
*trackingMessageID {
msgID := &trackingMessageID{
messageID: &messageID{
- ledgerID: int64(*id.LedgerId),
- entryID: int64(*id.EntryId),
+ ledgerID: int64(id.GetLedgerId()),
+ entryID: int64(id.GetEntryId()),
+ batchIdx: id.GetBatchIndex(),
+ batchSize: id.GetBatchSize(),
},
}
- if id.BatchIndex != nil {
- msgID.batchIdx = *id.BatchIndex
+ if msgID.batchIdx == -1 {
+ msgID.batchIdx = 0
}
return msgID
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index edef4404..43ce3c75 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -1262,8 +1262,9 @@ func TestConsumerSeek(t *testing.T) {
defer producer.Close()
consumer, err := client.Subscribe(ConsumerOptions{
- Topic: topicName,
- SubscriptionName: "sub-1",
+ Topic: topicName,
+ SubscriptionName: "sub-1",
+ StartMessageIDInclusive: true,
})
assert.Nil(t, err)
defer consumer.Close()
diff --git a/pulsar/consumer_zero_queue_test.go
b/pulsar/consumer_zero_queue_test.go
index 34e9df9f..2e7c4b27 100644
--- a/pulsar/consumer_zero_queue_test.go
+++ b/pulsar/consumer_zero_queue_test.go
@@ -474,6 +474,7 @@ func TestZeroQueueConsumer_Seek(t *testing.T) {
Topic: topicName,
EnableZeroQueueConsumer: true,
SubscriptionName: "sub-1",
+ StartMessageIDInclusive: true,
})
assert.Nil(t, err)
_, ok := consumer.(*zeroQueueConsumer)
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 0acd782b..f32cc7fc 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -18,6 +18,7 @@
package pulsar
import (
+ "cmp"
"errors"
"fmt"
"math"
@@ -147,6 +148,13 @@ func (id *messageID) equal(other *messageID) bool {
id.batchIdx == other.batchIdx
}
+func (id *messageID) compareLedgerAndEntryID(other *messageID) int {
+ if result := cmp.Compare(id.ledgerID, other.ledgerID); result != 0 {
+ return result
+ }
+ return cmp.Compare(id.entryID, other.entryID)
+}
+
func (id *messageID) greaterEqual(other *messageID) bool {
return id.equal(other) || id.greater(other)
}
diff --git a/pulsar/reader.go b/pulsar/reader.go
index 98bde4e3..95aed0e5 100644
--- a/pulsar/reader.go
+++ b/pulsar/reader.go
@@ -53,6 +53,7 @@ type ReaderOptions struct {
StartMessageID MessageID
// StartMessageIDInclusive, if true, the reader will start at the
`StartMessageID`, included.
+ // Note: This configuration also affects the seek operation.
// Default is `false` and the reader will start from the "next" message
StartMessageIDInclusive bool
diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go
index f76255e2..55b05037 100644
--- a/pulsar/reader_impl.go
+++ b/pulsar/reader_impl.go
@@ -196,19 +196,6 @@ func (r *reader) Close() {
r.metrics.ReadersClosed.Inc()
}
-func (r *reader) messageID(msgID MessageID) *trackingMessageID {
- mid := toTrackingMessageID(msgID)
-
- partition := int(mid.partitionIdx)
- // did we receive a valid partition index?
- if partition < 0 {
- r.log.Warnf("invalid partition index %d expected", partition)
- return nil
- }
-
- return mid
-}
-
func (r *reader) Seek(msgID MessageID) error {
r.Lock()
defer r.Unlock()
@@ -218,9 +205,12 @@ func (r *reader) Seek(msgID MessageID) error {
return fmt.Errorf("invalid message id type %T", msgID)
}
- mid := r.messageID(msgID)
- if mid == nil {
- return nil
+ mid := toTrackingMessageID(msgID)
+
+ partition := int(mid.partitionIdx)
+ if partition < 0 {
+ r.log.Warnf("invalid partition index %d expected", partition)
+ return fmt.Errorf("seek msgId must include partitoinIndex")
}
return r.c.Seek(mid)
diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go
index 83653570..2fbe89e4 100644
--- a/pulsar/reader_test.go
+++ b/pulsar/reader_test.go
@@ -1070,3 +1070,170 @@ func TestReaderNextReturnsOnClosedConsumer(t
*testing.T) {
assert.ErrorAs(t, err, &e)
assert.Equal(t, ConsumerClosed, e.Result())
}
+
+func testReaderSeekByIDWithHasNext(t *testing.T, startMessageID MessageID,
startMessageIDInclusive bool) {
+ client, err := NewClient(ClientOptions{
+ URL: lookupURL,
+ })
+
+ assert.Nil(t, err)
+ defer client.Close()
+
+ topic := newTopicName()
+ ctx := context.Background()
+
+ // create producer
+ producer, err := client.CreateProducer(ProducerOptions{
+ Topic: topic,
+ DisableBatching: true,
+ })
+ assert.Nil(t, err)
+ defer producer.Close()
+
+	// send 10 messages
+ var lastMsgID MessageID
+ for i := 0; i < 10; i++ {
+ lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+ Payload: []byte(fmt.Sprintf("hello-%d", i)),
+ })
+ assert.NoError(t, err)
+ assert.NotNil(t, lastMsgID)
+ }
+
+ reader, err := client.CreateReader(ReaderOptions{
+ Topic: topic,
+ StartMessageID: startMessageID,
+ StartMessageIDInclusive: startMessageIDInclusive,
+ })
+ assert.Nil(t, err)
+ defer reader.Close()
+
+ // Seek to last message ID
+ err = reader.Seek(lastMsgID)
+ assert.NoError(t, err)
+
+ if startMessageIDInclusive {
+ assert.True(t, reader.HasNext())
+ ctx, cancel := context.WithTimeout(context.Background(),
1*time.Second)
+ msg, err := reader.Next(ctx)
+ assert.NoError(t, err)
+ assert.NotNil(t, msg)
+ assert.True(t, messageIDCompare(lastMsgID, msg.ID()) == 0)
+ cancel()
+ } else {
+ assert.False(t, reader.HasNext())
+ ctx, cancel := context.WithTimeout(context.Background(),
1*time.Second)
+ msg, err := reader.Next(ctx)
+ assert.Error(t, err)
+ assert.Nil(t, msg)
+ cancel()
+ }
+
+}
+
+func TestReaderWithSeekByID(t *testing.T) {
+ params := []struct {
+ messageID MessageID
+ startMessageIDInclusive bool
+ }{
+ {EarliestMessageID(), false},
+ {EarliestMessageID(), true},
+ {LatestMessageID(), false},
+ {LatestMessageID(), true},
+ }
+
+ for _, c := range params {
+ t.Run(fmt.Sprintf("TestReaderSeekByID_%v_%v", c.messageID,
c.startMessageIDInclusive),
+ func(t *testing.T) {
+ testReaderSeekByIDWithHasNext(t, c.messageID,
c.startMessageIDInclusive)
+ })
+ }
+}
+
+func testReaderSeekByTimeWithHasNext(t *testing.T, startMessageID MessageID) {
+ client, err := NewClient(ClientOptions{
+ URL: lookupURL,
+ })
+
+ assert.Nil(t, err)
+ defer client.Close()
+
+ topic := newTopicName()
+ ctx := context.Background()
+
+ // create producer
+ producer, err := client.CreateProducer(ProducerOptions{
+ Topic: topic,
+ DisableBatching: true,
+ })
+ assert.Nil(t, err)
+ defer producer.Close()
+
+ // 1. send 10 messages
+ var lastMsgID MessageID
+ for i := 0; i < 10; i++ {
+ lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+ Payload: []byte(fmt.Sprintf("hello-%d", i)),
+ })
+ assert.NoError(t, err)
+
+ assert.NotNil(t, lastMsgID)
+ }
+
+ // 2. create reader
+ reader, err := client.CreateReader(ReaderOptions{
+ Topic: topic,
+ StartMessageID: startMessageID,
+ StartMessageIDInclusive: false,
+ })
+ assert.Nil(t, err)
+ defer reader.Close()
+
+ // 3. Seek time to now
+ reader.SeekByTime(time.Now())
+
+ // 4. Should not receive msg
+ {
+ assert.False(t, reader.HasNext())
+ timeoutCtx, cancel := context.WithTimeout(context.Background(),
1*time.Second)
+ msg, err := reader.Next(timeoutCtx)
+ assert.Error(t, err)
+ assert.Nil(t, msg)
+ cancel()
+ }
+
+ // 5. send more 10 messages
+ for i := 0; i < 10; i++ {
+ lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+ Payload: []byte(fmt.Sprintf("hello2-%d", i)),
+ })
+ assert.NoError(t, err)
+ assert.NotNil(t, lastMsgID)
+ }
+
+ // 6. Assert these messages are received
+ for i := 0; i < 10; i++ {
+ assert.True(t, reader.HasNext())
+ msg, err := reader.Next(context.Background())
+ assert.NoError(t, err)
+ assert.Equal(t, fmt.Sprintf("hello2-%d", i),
string(msg.Payload()))
+ }
+
+ // assert not more msg
+ {
+ assert.False(t, reader.HasNext())
+ timeoutCtx, cancel := context.WithTimeout(context.Background(),
1*time.Second)
+ msg, err := reader.Next(timeoutCtx)
+ assert.Error(t, err)
+ assert.Nil(t, msg)
+ cancel()
+ }
+}
+func TestReaderWithSeekByTime(t *testing.T) {
+ startMessageIDs := []MessageID{EarliestMessageID(), LatestMessageID()}
+ for _, startMsgID := range startMessageIDs {
+ t.Run(fmt.Sprintf("TestReaderSeekByTime_%v", startMsgID),
func(t *testing.T) {
+ testReaderSeekByTimeWithHasNext(t, startMsgID)
+ })
+ }
+}