This is an automated email from the ASF dual-hosted git repository.

zixuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new ba83732a Fix wrong result of reader.hasNext/Next after seeking by id 
or time (#1340)
ba83732a is described below

commit ba83732a596a78ddda42a02cadd74a0bb77123de
Author: Baodi Shi <[email protected]>
AuthorDate: Mon Mar 10 14:46:21 2025 +0800

    Fix wrong result of reader.hasNext/Next after seeking by id or time (#1340)
    
    * Fix wrong result of hasNext after seeking by id or time
    
    * fix unit test
    
    * Address code reviews.
    
    * Address code review
    
    * Add annotation to StartMessageIDInclusive
---
 pulsar/consumer.go                 |   1 +
 pulsar/consumer_partition.go       |  98 +++++++++++++++++-----
 pulsar/consumer_test.go            |   5 +-
 pulsar/consumer_zero_queue_test.go |   1 +
 pulsar/impl_message.go             |   8 ++
 pulsar/reader.go                   |   1 +
 pulsar/reader_impl.go              |  22 ++---
 pulsar/reader_test.go              | 167 +++++++++++++++++++++++++++++++++++++
 8 files changed, 262 insertions(+), 41 deletions(-)

diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index 7aee9645..d611c691 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -261,6 +261,7 @@ type ConsumerOptions struct {
        SubscriptionMode SubscriptionMode
 
        // StartMessageIDInclusive, if true, the consumer will start at the 
`StartMessageID`, included.
+       // Note: This configuration also affects the seek operation.
        // Default is `false` and the consumer will start from the "next" 
message
        StartMessageIDInclusive bool
 
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index aa922af2..cc2554c0 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -190,9 +190,13 @@ type partitionConsumer struct {
        backoffPolicyFunc    func() backoff.Policy
 
        dispatcherSeekingControlCh chan struct{}
-       isSeeking                  atomic.Bool
-       ctx                        context.Context
-       cancelFunc                 context.CancelFunc
+       // handle to the dispatcher goroutine
+       isSeeking atomic.Bool
+       // After executing seekByTime, the client is unaware of the 
startMessageId.
+       // Use this flag to compare markDeletePosition with BrokerLastMessageId 
when checking hasMoreMessages.
+       hasSoughtByTime atomic.Bool
+       ctx             context.Context
+       cancelFunc      context.CancelFunc
 }
 
 // pauseDispatchMessage used to discard the message in the dispatcher 
goroutine.
@@ -429,11 +433,12 @@ func newPartitionConsumer(parent Consumer, client 
*client, options *partitionCon
 
        startingMessageID := pc.startMessageID.get()
        if pc.options.startMessageIDInclusive && startingMessageID != nil && 
startingMessageID.equal(latestMessageID) {
-               msgID, err := pc.requestGetLastMessageID()
+               msgIDResp, err := pc.requestGetLastMessageID()
                if err != nil {
                        pc.Close()
                        return nil, err
                }
+               msgID := convertToMessageID(msgIDResp.GetLastMessageId())
                if msgID.entryID != noMessageEntry {
                        pc.startMessageID.set(msgID)
 
@@ -616,18 +621,27 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub 
*unsubscribeRequest) {
 }
 
 func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) {
+       res, err := pc.getLastMessageIDAndMarkDeletePosition()
+       if err != nil {
+               return nil, err
+       }
+       return res.msgID, err
+}
+
+func (pc *partitionConsumer) getLastMessageIDAndMarkDeletePosition() 
(*getLastMsgIDResult, error) {
        if state := pc.getConsumerState(); state == consumerClosed || state == 
consumerClosing {
                pc.log.WithField("state", state).Error("Failed to 
getLastMessageID for the closing or closed consumer")
                return nil, errors.New("failed to getLastMessageID for the 
closing or closed consumer")
        }
        bo := pc.backoffPolicyFunc()
-       request := func() (*trackingMessageID, error) {
+       request := func() (*getLastMsgIDResult, error) {
                req := &getLastMsgIDRequest{doneCh: make(chan struct{})}
                pc.eventsCh <- req
 
                // wait for the request to complete
                <-req.doneCh
-               return req.msgID, req.err
+               res := &getLastMsgIDResult{req.msgID, req.markDeletePosition}
+               return res, req.err
        }
 
        ctx, cancel := context.WithTimeout(context.Background(), 
pc.client.operationTimeout)
@@ -647,10 +661,16 @@ func (pc *partitionConsumer) getLastMessageID() 
(*trackingMessageID, error) {
 
 func (pc *partitionConsumer) internalGetLastMessageID(req 
*getLastMsgIDRequest) {
        defer close(req.doneCh)
-       req.msgID, req.err = pc.requestGetLastMessageID()
+       rsp, err := pc.requestGetLastMessageID()
+       if err != nil {
+               req.err = err
+               return
+       }
+       req.msgID = convertToMessageID(rsp.GetLastMessageId())
+       req.markDeletePosition = 
convertToMessageID(rsp.GetConsumerMarkDeletePosition())
 }
 
-func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID, 
error) {
+func (pc *partitionConsumer) requestGetLastMessageID() 
(*pb.CommandGetLastMessageIdResponse, error) {
        if state := pc.getConsumerState(); state == consumerClosed || state == 
consumerClosing {
                pc.log.WithField("state", state).Error("Failed to 
getLastMessageID closing or closed consumer")
                return nil, errors.New("failed to getLastMessageID closing or 
closed consumer")
@@ -667,8 +687,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() 
(*trackingMessageID, erro
                pc.log.WithError(err).Error("Failed to get last message id")
                return nil, err
        }
-       id := res.Response.GetLastMessageIdResponse.GetLastMessageId()
-       return convertToMessageID(id), nil
+       return res.Response.GetLastMessageIdResponse, nil
 }
 
 func (pc *partitionConsumer) sendIndividualAck(msgID MessageID) *ackRequest {
@@ -997,7 +1016,15 @@ func (pc *partitionConsumer) requestSeek(msgID 
*messageID) error {
        if err := pc.requestSeekWithoutClear(msgID); err != nil {
                return err
        }
-       pc.clearReceiverQueue()
+       // When the seek operation is successful, it indicates:
+       // 1. The broker has reset the cursor and sent a request to close the 
consumer on the client side.
+       //    Since this method is in the same goroutine as the 
reconnectToBroker,
+       //    we can safely clear the messages in the queue (at this point, it 
won't contain messages after the seek).
+       // 2. The startMessageID is reset to ensure accurate judgment when 
calling hasNext next time.
+       //    Since the messages in the queue are cleared here, reconnection 
won't reset startMessageId.
+       pc.lastDequeuedMsg = nil
+       pc.startMessageID.set(toTrackingMessageID(msgID))
+       pc.clearQueueAndGetNextMessage()
        return nil
 }
 
@@ -1069,7 +1096,9 @@ func (pc *partitionConsumer) internalSeekByTime(seek 
*seekByTimeRequest) {
                seek.err = err
                return
        }
-       pc.clearReceiverQueue()
+       pc.lastDequeuedMsg = nil
+       pc.hasSoughtByTime.Store(true)
+       pc.clearQueueAndGetNextMessage()
 }
 
 func (pc *partitionConsumer) internalAck(req *ackRequest) {
@@ -1451,10 +1480,6 @@ func (pc *partitionConsumer) 
messageShouldBeDiscarded(msgID *trackingMessageID)
        if pc.startMessageID.get() == nil {
                return false
        }
-       // if we start at latest message, we should never discard
-       if pc.options.startMessageID != nil && 
pc.options.startMessageID.equal(latestMessageID) {
-               return false
-       }
 
        if pc.options.startMessageIDInclusive {
                return pc.startMessageID.get().greater(msgID.messageID)
@@ -1709,9 +1734,15 @@ type redeliveryRequest struct {
 }
 
 type getLastMsgIDRequest struct {
-       doneCh chan struct{}
-       msgID  *trackingMessageID
-       err    error
+       doneCh             chan struct{}
+       msgID              *trackingMessageID
+       markDeletePosition *trackingMessageID
+       err                error
+}
+
+type getLastMsgIDResult struct {
+       msgID              *trackingMessageID
+       markDeletePosition *trackingMessageID
 }
 
 type seekRequest struct {
@@ -2200,6 +2231,25 @@ func (pc *partitionConsumer) 
discardCorruptedMessage(msgID *pb.MessageIdData,
 }
 
 func (pc *partitionConsumer) hasNext() bool {
+
+       // If a seek by time has been performed, then the `startMessageId` 
becomes irrelevant.
+       // We need to compare `markDeletePosition` and `lastMessageId`,
+       // and then reset `startMessageID` to `markDeletePosition`.
+       if pc.lastDequeuedMsg == nil && pc.hasSoughtByTime.CompareAndSwap(true, 
false) {
+               res, err := pc.getLastMessageIDAndMarkDeletePosition()
+               if err != nil {
+                       pc.log.WithError(err).Error("Failed to get last message 
id")
+                       pc.hasSoughtByTime.CompareAndSwap(false, true)
+                       return false
+               }
+               pc.lastMessageInBroker = res.msgID
+               pc.startMessageID.set(res.markDeletePosition)
+               // We only care about comparing ledger ids and entry ids as 
mark delete position
+               // doesn't have other ids such as batch index
+               compareResult := 
pc.lastMessageInBroker.messageID.compareLedgerAndEntryID(pc.startMessageID.get().messageID)
+               return compareResult > 0 || (pc.options.startMessageIDInclusive 
&& compareResult == 0)
+       }
+
        if pc.lastMessageInBroker != nil && pc.hasMoreMessages() {
                return true
        }
@@ -2261,12 +2311,14 @@ func convertToMessageID(id *pb.MessageIdData) 
*trackingMessageID {
 
        msgID := &trackingMessageID{
                messageID: &messageID{
-                       ledgerID: int64(*id.LedgerId),
-                       entryID:  int64(*id.EntryId),
+                       ledgerID:  int64(id.GetLedgerId()),
+                       entryID:   int64(id.GetEntryId()),
+                       batchIdx:  id.GetBatchIndex(),
+                       batchSize: id.GetBatchSize(),
                },
        }
-       if id.BatchIndex != nil {
-               msgID.batchIdx = *id.BatchIndex
+       if msgID.batchIdx == -1 {
+               msgID.batchIdx = 0
        }
 
        return msgID
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index edef4404..43ce3c75 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -1262,8 +1262,9 @@ func TestConsumerSeek(t *testing.T) {
        defer producer.Close()
 
        consumer, err := client.Subscribe(ConsumerOptions{
-               Topic:            topicName,
-               SubscriptionName: "sub-1",
+               Topic:                   topicName,
+               SubscriptionName:        "sub-1",
+               StartMessageIDInclusive: true,
        })
        assert.Nil(t, err)
        defer consumer.Close()
diff --git a/pulsar/consumer_zero_queue_test.go 
b/pulsar/consumer_zero_queue_test.go
index 34e9df9f..2e7c4b27 100644
--- a/pulsar/consumer_zero_queue_test.go
+++ b/pulsar/consumer_zero_queue_test.go
@@ -474,6 +474,7 @@ func TestZeroQueueConsumer_Seek(t *testing.T) {
                Topic:                   topicName,
                EnableZeroQueueConsumer: true,
                SubscriptionName:        "sub-1",
+               StartMessageIDInclusive: true,
        })
        assert.Nil(t, err)
        _, ok := consumer.(*zeroQueueConsumer)
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 0acd782b..f32cc7fc 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -18,6 +18,7 @@
 package pulsar
 
 import (
+       "cmp"
        "errors"
        "fmt"
        "math"
@@ -147,6 +148,13 @@ func (id *messageID) equal(other *messageID) bool {
                id.batchIdx == other.batchIdx
 }
 
+func (id *messageID) compareLedgerAndEntryID(other *messageID) int {
+       if result := cmp.Compare(id.ledgerID, other.ledgerID); result != 0 {
+               return result
+       }
+       return cmp.Compare(id.entryID, other.entryID)
+}
+
 func (id *messageID) greaterEqual(other *messageID) bool {
        return id.equal(other) || id.greater(other)
 }
diff --git a/pulsar/reader.go b/pulsar/reader.go
index 98bde4e3..95aed0e5 100644
--- a/pulsar/reader.go
+++ b/pulsar/reader.go
@@ -53,6 +53,7 @@ type ReaderOptions struct {
        StartMessageID MessageID
 
        // StartMessageIDInclusive, if true, the reader will start at the 
`StartMessageID`, included.
+       // Note: This configuration also affects the seek operation.
        // Default is `false` and the reader will start from the "next" message
        StartMessageIDInclusive bool
 
diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go
index f76255e2..55b05037 100644
--- a/pulsar/reader_impl.go
+++ b/pulsar/reader_impl.go
@@ -196,19 +196,6 @@ func (r *reader) Close() {
        r.metrics.ReadersClosed.Inc()
 }
 
-func (r *reader) messageID(msgID MessageID) *trackingMessageID {
-       mid := toTrackingMessageID(msgID)
-
-       partition := int(mid.partitionIdx)
-       // did we receive a valid partition index?
-       if partition < 0 {
-               r.log.Warnf("invalid partition index %d expected", partition)
-               return nil
-       }
-
-       return mid
-}
-
 func (r *reader) Seek(msgID MessageID) error {
        r.Lock()
        defer r.Unlock()
@@ -218,9 +205,12 @@ func (r *reader) Seek(msgID MessageID) error {
                return fmt.Errorf("invalid message id type %T", msgID)
        }
 
-       mid := r.messageID(msgID)
-       if mid == nil {
-               return nil
+       mid := toTrackingMessageID(msgID)
+
+       partition := int(mid.partitionIdx)
+       if partition < 0 {
+               r.log.Warnf("invalid partition index %d expected", partition)
+               return fmt.Errorf("seek msgId must include partitoinIndex")
        }
 
        return r.c.Seek(mid)
diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go
index 83653570..2fbe89e4 100644
--- a/pulsar/reader_test.go
+++ b/pulsar/reader_test.go
@@ -1070,3 +1070,170 @@ func TestReaderNextReturnsOnClosedConsumer(t 
*testing.T) {
        assert.ErrorAs(t, err, &e)
        assert.Equal(t, ConsumerClosed, e.Result())
 }
+
+func testReaderSeekByIDWithHasNext(t *testing.T, startMessageID MessageID, 
startMessageIDInclusive bool) {
+       client, err := NewClient(ClientOptions{
+               URL: lookupURL,
+       })
+
+       assert.Nil(t, err)
+       defer client.Close()
+
+       topic := newTopicName()
+       ctx := context.Background()
+
+       // create producer
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:           topic,
+               DisableBatching: true,
+       })
+       assert.Nil(t, err)
+       defer producer.Close()
+
+       // send 10 messages
+       var lastMsgID MessageID
+       for i := 0; i < 10; i++ {
+               lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+                       Payload: []byte(fmt.Sprintf("hello-%d", i)),
+               })
+               assert.NoError(t, err)
+               assert.NotNil(t, lastMsgID)
+       }
+
+       reader, err := client.CreateReader(ReaderOptions{
+               Topic:                   topic,
+               StartMessageID:          startMessageID,
+               StartMessageIDInclusive: startMessageIDInclusive,
+       })
+       assert.Nil(t, err)
+       defer reader.Close()
+
+       // Seek to last message ID
+       err = reader.Seek(lastMsgID)
+       assert.NoError(t, err)
+
+       if startMessageIDInclusive {
+               assert.True(t, reader.HasNext())
+               ctx, cancel := context.WithTimeout(context.Background(), 
1*time.Second)
+               msg, err := reader.Next(ctx)
+               assert.NoError(t, err)
+               assert.NotNil(t, msg)
+               assert.True(t, messageIDCompare(lastMsgID, msg.ID()) == 0)
+               cancel()
+       } else {
+               assert.False(t, reader.HasNext())
+               ctx, cancel := context.WithTimeout(context.Background(), 
1*time.Second)
+               msg, err := reader.Next(ctx)
+               assert.Error(t, err)
+               assert.Nil(t, msg)
+               cancel()
+       }
+
+}
+
+func TestReaderWithSeekByID(t *testing.T) {
+       params := []struct {
+               messageID               MessageID
+               startMessageIDInclusive bool
+       }{
+               {EarliestMessageID(), false},
+               {EarliestMessageID(), true},
+               {LatestMessageID(), false},
+               {LatestMessageID(), true},
+       }
+
+       for _, c := range params {
+               t.Run(fmt.Sprintf("TestReaderSeekByID_%v_%v", c.messageID, 
c.startMessageIDInclusive),
+                       func(t *testing.T) {
+                               testReaderSeekByIDWithHasNext(t, c.messageID, 
c.startMessageIDInclusive)
+                       })
+       }
+}
+
+func testReaderSeekByTimeWithHasNext(t *testing.T, startMessageID MessageID) {
+       client, err := NewClient(ClientOptions{
+               URL: lookupURL,
+       })
+
+       assert.Nil(t, err)
+       defer client.Close()
+
+       topic := newTopicName()
+       ctx := context.Background()
+
+       // create producer
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:           topic,
+               DisableBatching: true,
+       })
+       assert.Nil(t, err)
+       defer producer.Close()
+
+       // 1. send 10 messages
+       var lastMsgID MessageID
+       for i := 0; i < 10; i++ {
+               lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+                       Payload: []byte(fmt.Sprintf("hello-%d", i)),
+               })
+               assert.NoError(t, err)
+
+               assert.NotNil(t, lastMsgID)
+       }
+
+       // 2. create reader
+       reader, err := client.CreateReader(ReaderOptions{
+               Topic:                   topic,
+               StartMessageID:          startMessageID,
+               StartMessageIDInclusive: false,
+       })
+       assert.Nil(t, err)
+       defer reader.Close()
+
+       // 3. Seek time to now
+       reader.SeekByTime(time.Now())
+
+       // 4. Should not receive msg
+       {
+               assert.False(t, reader.HasNext())
+               timeoutCtx, cancel := context.WithTimeout(context.Background(), 
1*time.Second)
+               msg, err := reader.Next(timeoutCtx)
+               assert.Error(t, err)
+               assert.Nil(t, msg)
+               cancel()
+       }
+
+       // 5. send more 10 messages
+       for i := 0; i < 10; i++ {
+               lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+                       Payload: []byte(fmt.Sprintf("hello2-%d", i)),
+               })
+               assert.NoError(t, err)
+               assert.NotNil(t, lastMsgID)
+       }
+
+       // 6. Assert these messages are received
+       for i := 0; i < 10; i++ {
+               assert.True(t, reader.HasNext())
+               msg, err := reader.Next(context.Background())
+               assert.NoError(t, err)
+               assert.Equal(t, fmt.Sprintf("hello2-%d", i), 
string(msg.Payload()))
+       }
+
+       // assert not more msg
+       {
+               assert.False(t, reader.HasNext())
+               timeoutCtx, cancel := context.WithTimeout(context.Background(), 
1*time.Second)
+               msg, err := reader.Next(timeoutCtx)
+               assert.Error(t, err)
+               assert.Nil(t, msg)
+               cancel()
+       }
+}
+func TestReaderWithSeekByTime(t *testing.T) {
+       startMessageIDs := []MessageID{EarliestMessageID(), LatestMessageID()}
+       for _, startMsgID := range startMessageIDs {
+               t.Run(fmt.Sprintf("TestReaderSeekByTime_%v", startMsgID), 
func(t *testing.T) {
+                       testReaderSeekByTimeWithHasNext(t, startMsgID)
+               })
+       }
+}

Reply via email to