This is an automated email from the ASF dual-hosted git repository.

xiangying pushed a commit to branch txn_API
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git

commit 94d776554c5cd2863c1e043984c4988b6f4e8626
Author: xiangying <1984997...@qq.com>
AuthorDate: Wed Apr 5 14:27:19 2023 +0800

    [feat][txn]Implement transactional consumer/producer API
---
 pulsar/client.go                                   |  11 ++
 pulsar/client_impl.go                              |   5 +
 pulsar/consumer.go                                 |   3 +
 pulsar/consumer_impl.go                            |  10 ++
 pulsar/consumer_multitopic.go                      |  17 ++++
 pulsar/consumer_partition.go                       | 108 ++++++++++++++++++++
 pulsar/consumer_regex.go                           |  18 ++++
 pulsar/helper_for_test.go                          |   2 +-
 pulsar/internal/batch_builder.go                   |   8 ++
 pulsar/internal/commands.go                        |  12 ++-
 pulsar/internal/http_client.go                     |   2 +-
 pulsar/internal/key_based_batch_builder.go         |   5 +-
 .../pulsartracing/consumer_interceptor_test.go     |   4 +
 pulsar/message.go                                  |   4 +
 pulsar/producer_partition.go                       |  95 +++++++++++++++---
 pulsar/transaction.go                              |   8 +-
 pulsar/transaction_coordinator_client.go           |  18 ++--
 pulsar/transaction_impl.go                         |   2 +
 pulsar/transaction_test.go                         | 111 ++++++++++++++++++++-
 19 files changed, 405 insertions(+), 38 deletions(-)

diff --git a/pulsar/client.go b/pulsar/client.go
index bc3f4f5..7e6725d 100644
--- a/pulsar/client.go
+++ b/pulsar/client.go
@@ -184,6 +184,17 @@ type Client interface {
        // {@link Consumer} or {@link Producer} instances directly on a 
particular partition.
        TopicPartitions(topic string) ([]string, error)
 
+       // NewTransaction creates a new Transaction instance.
+       //
+       // This function is used to initiate a new transaction for performing
+       // atomic operations on the message broker. It returns a Transaction
+       // object that can be used to produce, consume and commit messages in a
+       // transactional manner.
+       //
+       // In case of any errors while creating the transaction, an error will
+       // be returned.
+       NewTransaction(duration time.Duration) (Transaction, error)
+
        // Close Closes the Client and free associated resources
        Close()
 }
diff --git a/pulsar/client_impl.go b/pulsar/client_impl.go
index 5322597..ba040ae 100644
--- a/pulsar/client_impl.go
+++ b/pulsar/client_impl.go
@@ -195,6 +195,11 @@ func newClient(options ClientOptions) (Client, error) {
        return c, nil
 }
 
+// NewTransaction asks the transaction coordinator client for a new transaction
+// ID and wraps it in a Transaction.
+//
+// timeout is passed both to the coordinator request and to the returned
+// transaction. When the coordinator request fails, the returned Transaction is
+// nil and the error is returned unchanged.
+func (c *client) NewTransaction(timeout time.Duration) (Transaction, error) {
+       // NOTE(review): c.tcClient is presumably only set up when transactions
+       // are enabled in ClientOptions - confirm, and guard against nil if so.
+       id, err := c.tcClient.newTransaction(timeout)
+       if err != nil {
+               // id is not meaningful when err is non-nil; do not dereference it.
+               return nil, err
+       }
+       return newTransaction(*id, c.tcClient, timeout), nil
+}
+
 func (c *client) CreateProducer(options ProducerOptions) (Producer, error) {
        producer, err := newProducer(c, &options)
        if err == nil {
diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index 64a096d..3ef72c7 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -265,6 +265,9 @@ type Consumer interface {
        // AckID the consumption of a single message, identified by its 
MessageID
        AckID(MessageID) error
 
+       // AckWithTxn acknowledges the consumption of a single message as part of the given transaction
+       AckWithTxn(Message, Transaction) error
+
        // AckCumulative the reception of all the messages in the stream up to 
(and including)
        // the provided message.
        AckCumulative(msg Message) error
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index fd7fa57..07f38c3 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -38,6 +38,7 @@ type acker interface {
        // AckID does not handle errors returned by the Broker side, so no need 
to wait for doneCh to finish.
        AckID(id MessageID) error
        AckIDWithResponse(id MessageID) error
+       AckIDWithTxn(msgID MessageID, txn Transaction) error
        AckIDCumulative(msgID MessageID) error
        AckIDWithResponseCumulative(msgID MessageID) error
        NackID(id MessageID)
@@ -478,6 +479,15 @@ func (c *consumer) Receive(ctx context.Context) (message 
Message, err error) {
        }
 }
 
+// AckWithTxn acknowledges msg inside txn. The message ID is first validated
+// with checkMsgIDPartition; the ack is then delegated to the partition consumer
+// selected by the ID's partition index.
+func (c *consumer) AckWithTxn(msg Message, txn Transaction) error {
+       id := msg.ID()
+       err := c.checkMsgIDPartition(id)
+       if err != nil {
+               return err
+       }
+
+       partition := c.consumers[id.PartitionIdx()]
+       return partition.AckIDWithTxn(id, txn)
+}
+
 // Chan return the message chan to users
 func (c *consumer) Chan() <-chan ConsumerMessage {
        return c.messageCh
diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go
index 8108c29..f6630dd 100644
--- a/pulsar/consumer_multitopic.go
+++ b/pulsar/consumer_multitopic.go
@@ -143,6 +143,23 @@ func (c *multiTopicConsumer) AckID(msgID MessageID) error {
        return mid.consumer.AckID(msgID)
 }
 
+// AckWithTxn acknowledges a single message inside txn. The ack is forwarded to
+// the consumer recorded on the message's tracking ID; an error is returned when
+// the ID has an unsupported type or carries no consumer.
+func (c *multiTopicConsumer) AckWithTxn(msg Message, txn Transaction) error {
+       id := msg.ID()
+       if ok := checkMessageIDType(id); !ok {
+               c.log.Warnf("invalid message id type %T", id)
+               return errors.New("invalid message id type in multi_consumer")
+       }
+
+       owner := toTrackingMessageID(id).consumer
+       if owner != nil {
+               return owner.AckIDWithTxn(id, txn)
+       }
+
+       // Without a recorded consumer there is no way to route the ack.
+       c.log.Warnf("unable to ack messageID=%+v can not determine topic", id)
+       return errors.New("unable to ack message because consumer is nil")
+}
+
 // AckCumulative the reception of all the messages in the stream up to (and 
including)
 // the provided message
 func (c *multiTopicConsumer) AckCumulative(msg Message) error {
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index fb77d0d..be28bce 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -417,6 +417,93 @@ func (pc *partitionConsumer) Unsubscribe() error {
        return req.err
 }
 
+// AckIDWithTxn acknowledges msgID on this partition as part of txn and blocks
+// until the events loop has processed the request.
+//
+// It returns an error when the consumer is closing or closed, when txn is not a
+// *transaction, when no tracking ID can be derived from msgID, or when the
+// transactional ack request fails. When trackingID.ack() reports that the batch
+// is not complete and batch index ack is disabled, it returns nil without
+// sending anything.
+func (pc *partitionConsumer) AckIDWithTxn(msgID MessageID, txn Transaction) error {
+       if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
+               pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer")
+               return errors.New("consumer state is closed")
+       }
+
+       if cmid, ok := msgID.(*chunkMessageID); ok {
+               // NOTE(review): txn is not passed on here, so chunked messages look
+               // like they are acknowledged outside the transaction - confirm.
+               return pc.unAckChunksTracker.ack(cmid)
+       }
+
+       // Validate txn before the batch tracker is touched, so a nil or foreign
+       // Transaction yields an error instead of a type-assertion panic.
+       txnImpl, ok := txn.(*transaction)
+       if !ok || txnImpl == nil {
+               return errors.New("invalid transaction type")
+       }
+
+       trackingID := toTrackingMessageID(msgID)
+
+       if trackingID != nil && trackingID.ack() {
+               // Read receivedTime before trackingID is replaced: the replacement
+               // below only carries ledgerID/entryID, so its receivedTime is zero.
+               receivedTime := trackingID.receivedTime
+               // All messages in the same batch have been acknowledged, we only need to acknowledge the
+               // MessageID that represents the entry that stores the whole batch
+               trackingID = &trackingMessageID{
+                       messageID: &messageID{
+                               ledgerID: trackingID.ledgerID,
+                               entryID:  trackingID.entryID,
+                       },
+               }
+               pc.metrics.AcksCounter.Inc()
+               pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-receivedTime.UnixNano()) / 1.0e9)
+       } else if !pc.options.enableBatchIndexAck {
+               return nil
+       }
+
+       if trackingID == nil {
+               // sendIndividualAckWithTxn dereferences the tracking ID.
+               return errors.New("invalid message id type")
+       }
+
+       ackReq := pc.sendIndividualAckWithTxn(trackingID, txnImpl)
+       <-ackReq.doneCh
+       pc.options.interceptors.OnAcknowledge(pc.parentConsumer, msgID)
+       return ackReq.err
+}
+
+// internalAckWithTxn handles one ackWithTxnRequest on the events loop. It
+// builds an individual CommandAck carrying the transaction ID, registers the
+// topic/subscription and the pending operation with the transaction, sends the
+// command on the consumer connection, and stores the outcome in req.err.
+// req.doneCh is closed on every return path.
+func (pc *partitionConsumer) internalAckWithTxn(req *ackWithTxnRequest) {
+       defer close(req.doneCh)
+
+       state := pc.getConsumerState()
+       if state == consumerClosed || state == consumerClosing {
+               pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer")
+               req.err = newError(ConsumerClosed, "Failed to ack by closing or closed consumer")
+               return
+       }
+
+       idData := &pb.MessageIdData{
+               LedgerId: proto.Uint64(uint64(req.msgID.ledgerID)),
+               EntryId:  proto.Uint64(uint64(req.msgID.entryID)),
+       }
+       // Attach the ack set only when batch index ack is on and a tracker exists.
+       if tracker := req.msgID.tracker; pc.options.enableBatchIndexAck && tracker != nil {
+               if bits := tracker.toAckSet(); bits != nil {
+                       idData.AckSet = bits
+               }
+       }
+
+       requestID := pc.client.rpcClient.NewRequestID()
+       txnID := req.Transaction.GetTxnID()
+       ack := &pb.CommandAck{
+               ConsumerId:     proto.Uint64(pc.consumerID),
+               MessageId:      []*pb.MessageIdData{idData},
+               AckType:        pb.CommandAck_Individual.Enum(),
+               TxnidMostBits:  proto.Uint64(txnID.MostSigBits),
+               TxnidLeastBits: proto.Uint64(txnID.LeastSigBits),
+       }
+
+       if err := req.Transaction.registerAckTopic(pc.options.topic, pc.options.subscription); err != nil {
+               req.err = err
+               return
+       }
+       if err := req.Transaction.registerSendOrAckOp(); err != nil {
+               req.err = err
+               return
+       }
+
+       ack.RequestId = proto.Uint64(requestID)
+       _, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_ACK, ack)
+       if err != nil {
+               pc.log.WithError(err).Error("Ack with response error")
+       }
+       // The operation registered above is always ended, with the request outcome.
+       req.Transaction.endSendOrAckOp(err)
+       req.err = err
+}
+
 func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
        defer close(unsub.doneCh)
 
@@ -539,6 +626,17 @@ func (pc *partitionConsumer) sendIndividualAck(msgID 
MessageID) *ackRequest {
        return ackReq
 }
 
+// sendIndividualAckWithTxn queues an individual transactional ack for msgID on
+// the events channel and returns the request so the caller can wait on doneCh.
+// msgID must hold a *trackingMessageID; its value is copied into the request.
+func (pc *partitionConsumer) sendIndividualAckWithTxn(msgID MessageID, txn *transaction) *ackWithTxnRequest {
+       tracked := msgID.(*trackingMessageID)
+
+       request := new(ackWithTxnRequest)
+       request.doneCh = make(chan struct{})
+       request.msgID = *tracked
+       request.Transaction = txn
+       request.ackType = individualAck
+
+       pc.eventsCh <- request
+       return request
+}
+
 func (pc *partitionConsumer) AckIDWithResponse(msgID MessageID) error {
        if !checkMessageIDType(msgID) {
                pc.log.Errorf("invalid message id type %T", msgID)
@@ -1389,6 +1487,14 @@ type ackRequest struct {
        err     error
 }
 
+// ackWithTxnRequest is the event handed to the partition consumer's events
+// loop to acknowledge one message inside a transaction.
+type ackWithTxnRequest struct {
+       // Transaction is the transaction the acknowledgement is registered with.
+       Transaction *transaction
+       // msgID is a copy of the tracking message ID being acknowledged.
+       msgID trackingMessageID
+       // ackType is the kind of acknowledgement requested.
+       ackType int
+
+       // doneCh is closed once the events loop has finished with the request.
+       doneCh chan struct{}
+       // err holds the outcome; it is meant to be read after doneCh is closed.
+       err error
+}
+
 type unsubscribeRequest struct {
        doneCh chan struct{}
        err    error
@@ -1444,6 +1550,8 @@ func (pc *partitionConsumer) runEventsLoop() {
                        switch v := i.(type) {
                        case *ackRequest:
                                pc.internalAck(v)
+                       case *ackWithTxnRequest:
+                               pc.internalAckWithTxn(v)
                        case []*pb.MessageIdData:
                                pc.internalAckList(v)
                        case *redeliveryRequest:
diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go
index 2520af5..79e2293 100644
--- a/pulsar/consumer_regex.go
+++ b/pulsar/consumer_regex.go
@@ -193,6 +193,24 @@ func (c *regexConsumer) AckID(msgID MessageID) error {
        return mid.consumer.AckID(msgID)
 }
 
+// AckWithTxn acknowledges a single message inside txn. The ack is forwarded to
+// the consumer recorded on the message's tracking ID; an error is returned when
+// the ID has an unsupported type or carries no consumer.
+func (c *regexConsumer) AckWithTxn(msg Message, txn Transaction) error {
+       id := msg.ID()
+       if ok := checkMessageIDType(id); !ok {
+               c.log.Warnf("invalid message id type %T", id)
+               return fmt.Errorf("invalid message id type %T", id)
+       }
+
+       owner := toTrackingMessageID(id).consumer
+       if owner != nil {
+               return owner.AckIDWithTxn(id, txn)
+       }
+
+       // Without a recorded consumer there is no way to route the ack.
+       c.log.Warnf("unable to ack messageID=%+v can not determine topic", id)
+       return errors.New("consumer is nil in consumer_regex")
+}
+
 // AckCumulative the reception of all the messages in the stream up to (and 
including)
 // the provided message.
 func (c *regexConsumer) AckCumulative(msg Message) error {
diff --git a/pulsar/helper_for_test.go b/pulsar/helper_for_test.go
index 7bbf66e..426855b 100644
--- a/pulsar/helper_for_test.go
+++ b/pulsar/helper_for_test.go
@@ -159,7 +159,7 @@ func topicStats(topic string) (map[string]interface{}, 
error) {
 
 func transactionStats(id *TxnID) (map[string]interface{}, error) {
        var metadata map[string]interface{}
-       path := fmt.Sprintf("admin/v3/transactions/transactionMetadata/%d/%d", 
id.mostSigBits, id.leastSigBits)
+       path := fmt.Sprintf("admin/v3/transactions/transactionMetadata/%d/%d", 
id.MostSigBits, id.LeastSigBits)
        err := httpGet(path, &metadata)
        return metadata, err
 }
diff --git a/pulsar/internal/batch_builder.go b/pulsar/internal/batch_builder.go
index 649aba4..6df3a61 100644
--- a/pulsar/internal/batch_builder.go
+++ b/pulsar/internal/batch_builder.go
@@ -51,6 +51,9 @@ type BatchBuilder interface {
                payload []byte,
                callback interface{}, replicateTo []string, deliverAt time.Time,
                schemaVersion []byte, multiSchemaEnabled bool,
+               useTxn bool,
+               mostSigBits uint64,
+               leastSigBits uint64,
        ) bool
 
        // Flush all the messages buffered in the client and wait until all 
messages have been successfully persisted.
@@ -185,6 +188,7 @@ func (bc *batchContainer) Add(
        payload []byte,
        callback interface{}, replicateTo []string, deliverAt time.Time,
        schemaVersion []byte, multiSchemaEnabled bool,
+       useTxn bool, mostSigBits uint64, leastSigBits uint64,
 ) bool {
 
        if replicateTo != nil && bc.numMessages != 0 {
@@ -223,6 +227,10 @@ func (bc *batchContainer) Add(
                }
 
                bc.cmdSend.Send.SequenceId = proto.Uint64(sequenceID)
+               if useTxn {
+                       bc.cmdSend.Send.TxnidMostBits = 
proto.Uint64(mostSigBits)
+                       bc.cmdSend.Send.TxnidLeastBits = 
proto.Uint64(leastSigBits)
+               }
        }
        addSingleMessageToBatch(bc.buffer, metadata, payload)
 
diff --git a/pulsar/internal/commands.go b/pulsar/internal/commands.go
index 00e075b..7471ee0 100644
--- a/pulsar/internal/commands.go
+++ b/pulsar/internal/commands.go
@@ -22,11 +22,10 @@ import (
        "errors"
        "fmt"
 
-       "google.golang.org/protobuf/proto"
-
        "github.com/apache/pulsar-client-go/pulsar/internal/compression"
        "github.com/apache/pulsar-client-go/pulsar/internal/crypto"
        pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
+       "google.golang.org/protobuf/proto"
 )
 
 const (
@@ -332,7 +331,10 @@ func SingleSend(wb Buffer,
        msgMetadata *pb.MessageMetadata,
        compressedPayload Buffer,
        encryptor crypto.Encryptor,
-       maxMassageSize uint32) error {
+       maxMassageSize uint32,
+       useTxn bool,
+       mostSigBits uint64,
+       leastSigBits uint64) error {
        cmdSend := baseCommand(
                pb.BaseCommand_SEND,
                &pb.CommandSend{
@@ -344,6 +346,10 @@ func SingleSend(wb Buffer,
                isChunk := true
                cmdSend.Send.IsChunk = &isChunk
        }
+       if useTxn {
+               cmdSend.Send.TxnidMostBits = proto.Uint64(mostSigBits)
+               cmdSend.Send.TxnidLeastBits = proto.Uint64(leastSigBits)
+       }
        // payload has been compressed so compressionProvider can be nil
        return serializeMessage(wb, cmdSend, msgMetadata, compressedPayload,
                nil, encryptor, maxMassageSize, false)
diff --git a/pulsar/internal/http_client.go b/pulsar/internal/http_client.go
index dccc143..7cb7e8e 100644
--- a/pulsar/internal/http_client.go
+++ b/pulsar/internal/http_client.go
@@ -190,10 +190,10 @@ func (c *httpClient) GetWithOptions(endpoint string, obj 
interface{}, params map
        }
 
        resp, err := checkSuccessful(c.doRequest(req))
+       defer safeRespClose(resp)
        if err != nil {
                return nil, err
        }
-       defer safeRespClose(resp)
 
        if obj != nil {
                if err := decodeJSONBody(resp, &obj); err != nil {
diff --git a/pulsar/internal/key_based_batch_builder.go 
b/pulsar/internal/key_based_batch_builder.go
index 334e674..88a4d5e 100644
--- a/pulsar/internal/key_based_batch_builder.go
+++ b/pulsar/internal/key_based_batch_builder.go
@@ -132,6 +132,9 @@ func (bc *keyBasedBatchContainer) Add(
        payload []byte,
        callback interface{}, replicateTo []string, deliverAt time.Time,
        schemaVersion []byte, multiSchemaEnabled bool,
+       useTxn bool,
+       mostSigBits uint64,
+       leastSigBits uint64,
 ) bool {
        if replicateTo != nil && bc.numMessages != 0 {
                // If the current batch is not empty and we're trying to set 
the replication clusters,
@@ -162,7 +165,7 @@ func (bc *keyBasedBatchContainer) Add(
        add := batchPart.Add(
                metadata, sequenceIDGenerator, payload, callback, replicateTo,
                deliverAt,
-               schemaVersion, multiSchemaEnabled,
+               schemaVersion, multiSchemaEnabled, useTxn, mostSigBits, 
leastSigBits,
        )
        if !add {
                return false
diff --git a/pulsar/internal/pulsartracing/consumer_interceptor_test.go 
b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
index 34e09d5..06c9a58 100644
--- a/pulsar/internal/pulsartracing/consumer_interceptor_test.go
+++ b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
@@ -52,6 +52,10 @@ func (c *mockConsumer) Subscription() string {
        return ""
 }
 
+// AckWithTxn is a stub on the mock consumer: it ignores both arguments and
+// always returns nil.
+func (c *mockConsumer) AckWithTxn(_ pulsar.Message, _ pulsar.Transaction) error {
+       return nil
+}
+
 func (c *mockConsumer) Unsubscribe() error {
        return nil
 }
diff --git a/pulsar/message.go b/pulsar/message.go
index 98190e9..83afd3f 100644
--- a/pulsar/message.go
+++ b/pulsar/message.go
@@ -69,6 +69,10 @@ type ProducerMessage struct {
        //Schema assign to the current message
        //Note: messages may have a different schema from producer schema, use 
it instead of producer schema when assigned
        Schema Schema
+
+       //Transaction assigned to the current message
+       //Note: the message is not visible to consumers until the transaction is committed.
+       Transaction Transaction
 }
 
 // Message abstraction used in Pulsar
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 744df79..428a8dd 100644
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -663,6 +663,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                                        chunkID:          chunkID,
                                        uuid:             uuid,
                                        chunkRecorder:    cr,
+                                       transaction:      request.transaction,
                                }
                                // the permit of first chunk has acquired
                                if chunkID != 0 && !p.canAddToQueue(nsr, 0) {
@@ -681,16 +682,16 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
        } else {
                smm := p.genSingleMessageMetadataInBatch(msg, uncompressedSize)
                multiSchemaEnabled := !p.options.DisableMultiSchema
-               added := p.batchBuilder.Add(smm, p.sequenceIDGenerator, 
uncompressedPayload, request,
-                       msg.ReplicationClusters, deliverAt, schemaVersion, 
multiSchemaEnabled)
+               added := addRequestToBatch(smm, p, uncompressedPayload, 
request, msg, deliverAt, schemaVersion,
+                       multiSchemaEnabled)
                if !added {
-                       // The current batch is full.. flush it and retry
+                       // The current batch is full. flush it and retry
 
                        p.internalFlushCurrentBatch()
 
                        // after flushing try again to add the current payload
-                       if ok := p.batchBuilder.Add(smm, p.sequenceIDGenerator, 
uncompressedPayload, request,
-                               msg.ReplicationClusters, deliverAt, 
schemaVersion, multiSchemaEnabled); !ok {
+                       if ok := addRequestToBatch(smm, p, uncompressedPayload, 
request, msg, deliverAt, schemaVersion,
+                               multiSchemaEnabled); !ok {
                                
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                                request.callback(nil, request.msg, 
errFailAddToBatch)
                                p.log.WithField("size", uncompressedSize).
@@ -707,6 +708,23 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
        }
 }
 
+// addRequestToBatch adds one message to the producer's batch builder and
+// reports whether the builder accepted it. When the request carries a
+// transaction, its ID bits are passed along with useTxn=true; otherwise
+// useTxn is false and both ID halves are zero.
+func addRequestToBatch(smm *pb.SingleMessageMetadata, p *partitionProducer,
+       uncompressedPayload []byte,
+       request *sendRequest, msg *ProducerMessage, deliverAt time.Time,
+       schemaVersion []byte, multiSchemaEnabled bool) bool {
+       var mostSigBits, leastSigBits uint64
+       inTxn := request.transaction != nil
+       if inTxn {
+               id := request.transaction.GetTxnID()
+               mostSigBits, leastSigBits = id.MostSigBits, id.LeastSigBits
+       }
+
+       return p.batchBuilder.Add(smm, p.sequenceIDGenerator, uncompressedPayload, request,
+               msg.ReplicationClusters, deliverAt, schemaVersion, multiSchemaEnabled,
+               inTxn, mostSigBits, leastSigBits)
+}
+
 func (p *partitionProducer) genMetadata(msg *ProducerMessage,
        uncompressedSize int,
        deliverAt time.Time) (mm *pb.MessageMetadata) {
@@ -789,16 +807,36 @@ func (p *partitionProducer) internalSingleSend(mm 
*pb.MessageMetadata,
        }
 
        sid := *mm.SequenceId
-
-       if err := internal.SingleSend(
-               buffer,
-               p.producerID,
-               sid,
-               mm,
-               payloadBuf,
-               p.encryptor,
-               maxMessageSize,
-       ); err != nil {
+       var err error
+       if request.transaction != nil {
+               txnID := request.transaction.GetTxnID()
+               err = internal.SingleSend(
+                       buffer,
+                       p.producerID,
+                       sid,
+                       mm,
+                       payloadBuf,
+                       p.encryptor,
+                       maxMessageSize,
+                       true,
+                       txnID.MostSigBits,
+                       txnID.LeastSigBits,
+               )
+       } else {
+               err = internal.SingleSend(
+                       buffer,
+                       p.producerID,
+                       sid,
+                       mm,
+                       payloadBuf,
+                       p.encryptor,
+                       maxMessageSize,
+                       false,
+                       0,
+                       0,
+               )
+       }
+       if err != nil {
                request.callback(nil, request.msg, err)
                p.releaseSemaphoreAndMem(int64(len(msg.Payload)))
                p.log.WithError(err).Errorf("Single message serialize failed 
%s", msg.Value)
@@ -952,6 +990,9 @@ func (p *partitionProducer) failTimeoutMessages() {
                                                sr.callback(nil, sr.msg, 
errSendTimeout)
                                        })
                                }
+                               if sr.transaction != nil {
+                                       sr.transaction.endSendOrAckOp(nil)
+                               }
                        }
 
                        // flag the send has completed with error, flush make 
no effect
@@ -1067,6 +1108,20 @@ func (p *partitionProducer) SendAsync(ctx 
context.Context, msg *ProducerMessage,
 
 func (p *partitionProducer) internalSendAsync(ctx context.Context, msg 
*ProducerMessage,
        callback func(MessageID, *ProducerMessage, error), flushImmediately 
bool) {
+       // Register the send operation with the transaction and the transaction
+       // coordinator. A failure in either step is reported through callback
+       // and the send is abandoned.
+       if msg.Transaction != nil {
+               transactionImpl := (msg.Transaction).(*transaction)
+
+               err := transactionImpl.registerProducerTopic(p.topic)
+               if err != nil {
+                       callback(nil, msg, err)
+                       return
+               }
+               err = transactionImpl.registerSendOrAckOp()
+               if err != nil {
+                       callback(nil, msg, err)
+                       // Stop here: without this return, execution fell through to
+                       // the normal send path after the error was already reported.
+                       return
+               }
+       }
        if p.getProducerState() != producerReady {
                // Producer is closing
                callback(nil, msg, errProducerClosed)
@@ -1078,7 +1133,10 @@ func (p *partitionProducer) internalSendAsync(ctx 
context.Context, msg *Producer
 
        // callbackOnce make sure the callback is only invoked once in chunking
        callbackOnce := &sync.Once{}
-
+       var txn *transaction
+       if msg.Transaction != nil {
+               txn = (msg.Transaction).(*transaction)
+       }
        sr := &sendRequest{
                ctx:              ctx,
                msg:              msg,
@@ -1088,6 +1146,7 @@ func (p *partitionProducer) internalSendAsync(ctx 
context.Context, msg *Producer
                publishTime:      time.Now(),
                blockCh:          bc,
                closeBlockChOnce: &sync.Once{},
+               transaction:      txn,
        }
        p.options.Interceptors.BeforeSend(p, msg)
 
@@ -1191,6 +1250,9 @@ func (p *partitionProducer) ReceivedSendReceipt(response 
*pb.CommandSendReceipt)
                                        
p.options.Interceptors.OnSendAcknowledgement(p, sr.msg, msgID)
                                }
                        }
+                       if sr.transaction != nil {
+                               sr.transaction.endSendOrAckOp(nil)
+                       }
                }
 
                // Mark this pending item as done
@@ -1287,6 +1349,7 @@ type sendRequest struct {
        chunkID          int
        uuid             string
        chunkRecorder    *chunkRecorder
+       transaction      *transaction
 }
 
 // stopBlock can be invoked multiple times safety
diff --git a/pulsar/transaction.go b/pulsar/transaction.go
index 60e1d2b..944c7e3 100644
--- a/pulsar/transaction.go
+++ b/pulsar/transaction.go
@@ -49,10 +49,10 @@ const (
 
 // TxnID An identifier for representing a transaction.
 type TxnID struct {
-       // mostSigBits The most significant 64 bits of this TxnID.
-       mostSigBits uint64
-       // leastSigBits The least significant 64 bits of this TxnID.
-       leastSigBits uint64
+       // MostSigBits The most significant 64 bits of this TxnID.
+       MostSigBits uint64
+       // LeastSigBits The least significant 64 bits of this TxnID.
+       LeastSigBits uint64
 }
 
 // Transaction used to guarantee exactly-once
diff --git a/pulsar/transaction_coordinator_client.go 
b/pulsar/transaction_coordinator_client.go
index 1535fad..96cca87 100644
--- a/pulsar/transaction_coordinator_client.go
+++ b/pulsar/transaction_coordinator_client.go
@@ -135,11 +135,11 @@ func (tc *transactionCoordinatorClient) 
addPublishPartitionToTxn(id *TxnID, part
        requestID := tc.client.rpcClient.NewRequestID()
        cmdAddPartitions := &pb.CommandAddPartitionToTxn{
                RequestId:      proto.Uint64(requestID),
-               TxnidMostBits:  proto.Uint64(id.mostSigBits),
-               TxnidLeastBits: proto.Uint64(id.leastSigBits),
+               TxnidMostBits:  proto.Uint64(id.MostSigBits),
+               TxnidLeastBits: proto.Uint64(id.LeastSigBits),
                Partitions:     partitions,
        }
-       res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.mostSigBits], 
requestID,
+       res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.MostSigBits], 
requestID,
                pb.BaseCommand_ADD_PARTITION_TO_TXN, cmdAddPartitions)
        tc.semaphore.Release()
        if err != nil {
@@ -163,11 +163,11 @@ func (tc *transactionCoordinatorClient) 
addSubscriptionToTxn(id *TxnID, topic st
        }
        cmdAddSubscription := &pb.CommandAddSubscriptionToTxn{
                RequestId:      proto.Uint64(requestID),
-               TxnidMostBits:  proto.Uint64(id.mostSigBits),
-               TxnidLeastBits: proto.Uint64(id.leastSigBits),
+               TxnidMostBits:  proto.Uint64(id.MostSigBits),
+               TxnidLeastBits: proto.Uint64(id.LeastSigBits),
                Subscription:   []*pb.Subscription{sub},
        }
-       res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.mostSigBits], 
requestID,
+       res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.MostSigBits], 
requestID,
                pb.BaseCommand_ADD_SUBSCRIPTION_TO_TXN, cmdAddSubscription)
        tc.semaphore.Release()
        if err != nil {
@@ -187,10 +187,10 @@ func (tc *transactionCoordinatorClient) endTxn(id *TxnID, 
action pb.TxnAction) e
        cmdEndTxn := &pb.CommandEndTxn{
                RequestId:      proto.Uint64(requestID),
                TxnAction:      &action,
-               TxnidMostBits:  proto.Uint64(id.mostSigBits),
-               TxnidLeastBits: proto.Uint64(id.leastSigBits),
+               TxnidMostBits:  proto.Uint64(id.MostSigBits),
+               TxnidLeastBits: proto.Uint64(id.LeastSigBits),
        }
-       res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.mostSigBits], 
requestID, pb.BaseCommand_END_TXN, cmdEndTxn)
+       res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.MostSigBits], 
requestID, pb.BaseCommand_END_TXN, cmdEndTxn)
        tc.semaphore.Release()
        if err != nil {
                return err
diff --git a/pulsar/transaction_impl.go b/pulsar/transaction_impl.go
index 7cc93ec..8a24c46 100644
--- a/pulsar/transaction_impl.go
+++ b/pulsar/transaction_impl.go
@@ -231,6 +231,8 @@ func (state TxnState) string() string {
                return "TxnAborted"
        case TxnTimeout:
                return "TxnTimeout"
+       case TxnError:
+               return "TxnError"
        default:
                return "Unknown"
        }
diff --git a/pulsar/transaction_test.go b/pulsar/transaction_test.go
index 362e4d2..53edbb8 100644
--- a/pulsar/transaction_test.go
+++ b/pulsar/transaction_test.go
@@ -103,8 +103,8 @@ func TestTxnImplCommitOrAbort(t *testing.T) {
        //The operations of committing txn1 should success at the first time 
and fail at the second time.
        txn1 := createTxn(tc, t)
        err := txn1.Commit(context.Background())
-       require.Nil(t, err, fmt.Sprintf("Failed to commit the transaction 
%d:%d\n", txn1.txnID.mostSigBits,
-               txn1.txnID.leastSigBits))
+       require.Nil(t, err, fmt.Sprintf("Failed to commit the transaction 
%d:%d\n", txn1.txnID.MostSigBits,
+               txn1.txnID.LeastSigBits))
        txn1.state = TxnOpen
        txn1.opsFlow <- true
        err = txn1.Commit(context.Background())
@@ -117,7 +117,7 @@ func TestTxnImplCommitOrAbort(t *testing.T) {
        txn2 := newTransaction(*id2, tc, time.Hour)
        err = txn2.Abort(context.Background())
        require.Nil(t, err, fmt.Sprintf("Failed to abort the transaction 
%d:%d\n",
-               id2.mostSigBits, id2.leastSigBits))
+               id2.MostSigBits, id2.LeastSigBits))
        txn2.state = TxnOpen
        txn2.opsFlow <- true
        err = txn2.Abort(context.Background())
@@ -209,6 +209,7 @@ func createTcClient(t *testing.T) 
(*transactionCoordinatorClient, *client) {
                URL:                   webServiceURLTLS,
                TLSTrustCertsFilePath: caCertsPath,
                Authentication:        NewAuthenticationTLS(tlsClientCertPath, 
tlsClientKeyPath),
+               EnableTransaction:     true,
        })
        require.Nil(t, err, "Failed to create client.")
        tcClient := newTransactionCoordinatorClientImpl(c.(*client))
@@ -217,3 +218,107 @@ func createTcClient(t *testing.T) 
(*transactionCoordinatorClient, *client) {
 
        return tcClient, c.(*client)
 }
+
+// TestConsumeAndProduceWithTxn validates producing and consuming messages with and
+// without transactions. It consists of the following steps:
+//
+//  1. Prepare: create a client with transactions enabled via createTcClient.
+//  2. Prepare: create a topic and a subscription.
+//  3. Produce 10 messages without a transaction and 10 messages with a transaction.
+//     Expectation: the consumer receives the 10 non-transactional messages, and
+//     nothing else arrives within one second while the transaction is still open.
+//  4. Commit the transaction and receive the 10 transactional messages. One is
+//     acknowledged by the receiver goroutine started in step 3; the other 9 are
+//     acknowledged with a second transaction that this test leaves open.
+//     Expectation: a new consumer on the same subscription receives nothing within
+//     one second.
+//  5. Clean up: close the consumer and producer.
+//
+// All assertions run on the test goroutine; helper goroutines only report through
+// channels, because require.* ends in t.FailNow, which must not be called from
+// any other goroutine.
+func TestConsumeAndProduceWithTxn(t *testing.T) {
+       // ctx bounds the lifetime of the helper goroutines: it is cancelled when the
+       // test returns, so a goroutine still blocked in Receive is released.
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       // Step 1: Prepare - Create PulsarClient and initialize the transaction coordinator client.
+       topic := newTopicName()
+       sub := "my-sub"
+       _, client := createTcClient(t)
+       // Step 2: Prepare - Create Topic and Subscription.
+       consumer, err := client.Subscribe(ConsumerOptions{
+               Topic:            topic,
+               SubscriptionName: sub,
+       })
+       // require (not assert): continuing with a nil consumer would only panic later.
+       require.NoError(t, err)
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:       topic,
+               SendTimeout: 0,
+       })
+       require.NoError(t, err)
+       // Step 3: Open a transaction, send 10 messages without the transaction and
+       // 10 messages with the transaction.
+       // Expectation: We can receive the 10 messages sent without a transaction and
+       // cannot receive the 10 messages sent with the transaction.
+       txn, err := client.NewTransaction(time.Hour)
+       require.Nil(t, err)
+       for i := 0; i < 10; i++ {
+               _, err = producer.Send(context.Background(), &ProducerMessage{
+                       Payload: make([]byte, 1024),
+               })
+               require.Nil(t, err)
+       }
+       for i := 0; i < 10; i++ {
+               _, err = producer.Send(context.Background(), &ProducerMessage{
+                       Transaction: txn,
+                       Payload:     make([]byte, 1024),
+               })
+               require.Nil(t, err)
+       }
+       // Receive and acknowledge the 10 messages sent without a transaction.
+       for i := 0; i < 10; i++ {
+               msg, err := consumer.Receive(context.Background())
+               require.NoError(t, err)
+               require.NotNil(t, msg)
+               require.NoError(t, consumer.Ack(msg))
+       }
+       // Start a goroutine that receives and acknowledges exactly one more message and
+       // reports the outcome on ackErr. The consumer is captured in its own variable
+       // because the outer 'consumer' is reassigned later in the test. The channel is
+       // buffered so the goroutine can always finish, even if nobody reads the result.
+       firstConsumer := consumer
+       ackErr := make(chan error, 1)
+       go func() {
+               msg, err := firstConsumer.Receive(ctx)
+               if err == nil {
+                       err = firstConsumer.AckID(msg.ID())
+               }
+               ackErr <- err
+       }()
+       // Expectation: The consumer should not receive uncommitted messages.
+       select {
+       case err := <-ackErr:
+               // Report a receive/ack error as such before flagging the unexpected delivery.
+               require.NoError(t, err)
+               require.Fail(t, "The consumer should not receive uncommitted message")
+       case <-time.After(time.Second):
+       }
+       // Step 4: After committing the transaction, the 10 transactional messages become
+       // receivable. A failed commit is fatal here: without it the Receive calls below
+       // would block indefinitely.
+       require.NoError(t, txn.Commit(context.Background()))
+       txn, err = client.NewTransaction(time.Hour)
+       require.Nil(t, err)
+       // Only 9 iterations: the goroutine started in step 3 is still blocked in Receive
+       // and takes one of the 10 messages.
+       for i := 0; i < 9; i++ {
+               msg, err := consumer.Receive(context.Background())
+               require.NoError(t, err)
+               require.NotNil(t, msg)
+               require.NoError(t, consumer.AckWithTxn(msg, txn))
+       }
+       // Wait for the goroutine's message to be acknowledged before closing the consumer
+       // it is using.
+       select {
+       case err := <-ackErr:
+               require.NoError(t, err)
+       case <-time.After(10 * time.Second):
+               require.Fail(t, "Timed out waiting for the committed message to be received and acknowledged")
+       }
+       consumer.Close()
+       consumer, err = client.Subscribe(ConsumerOptions{
+               Topic:            topic,
+               SubscriptionName: sub,
+       })
+       require.NoError(t, err)
+       // Start a goroutine that signals on 'received' only if a message is actually
+       // delivered. A Receive that returns an error (for example when ctx is cancelled
+       // at the end of the test) is not a delivery.
+       secondConsumer := consumer
+       received := make(chan struct{})
+       go func() {
+               if msg, err := secondConsumer.Receive(ctx); err == nil && msg != nil {
+                       close(received)
+               }
+       }()
+
+       // Expectation: nothing is delivered to the new consumer within one second.
+       // NOTE(review): presumably because the one message acknowledged by the goroutine
+       // is gone and the 9 acknowledged with the still-open transaction are pending
+       // rather than redelivered - confirm against the broker's transactional ack
+       // semantics.
+       select {
+       case <-received:
+               require.Fail(t, "The consumer should not receive uncommitted message")
+       case <-time.After(time.Second):
+       }
+
+       // Step 5: Clean up - Close the consumer and producer instances.
+       consumer.Close()
+       producer.Close()
+}

Reply via email to