This is an automated email from the ASF dual-hosted git repository.

zixuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 274a10fe Fix: Potential data race (#1338)
274a10fe is described below

commit 274a10fe9898df7d0ef3375f7982c23e6bcb9560
Author: gunli <[email protected]>
AuthorDate: Mon Mar 10 14:46:38 2025 +0800

    Fix: Potential data race (#1338)
    
    * fix: potential data race
    
    * skip the write if the context is already done
    
    * pass a non-nil context
    
    * add a nil-context check
    
    * revert
    
    * remove the nil-context check
    
    * revert pendingItem.done() to its old position
---
 pulsar/consumer_multitopic_test.go |  6 +++--
 pulsar/internal/connection.go      | 48 +++++++++++++++++++++++++-------------
 pulsar/producer_partition.go       | 19 +++++++++++----
 3 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/pulsar/consumer_multitopic_test.go b/pulsar/consumer_multitopic_test.go
index cd236ecc..30ae5ccd 100644
--- a/pulsar/consumer_multitopic_test.go
+++ b/pulsar/consumer_multitopic_test.go
@@ -18,18 +18,20 @@
 package pulsar
 
 import (
+       "context"
        "errors"
        "fmt"
        "strings"
        "testing"
        "time"
 
+       "github.com/stretchr/testify/assert"
+
        "github.com/apache/pulsar-client-go/pulsar/internal"
        pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
        "github.com/apache/pulsar-client-go/pulsaradmin"
        "github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
        "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils"
-       "github.com/stretchr/testify/assert"
 )
 
 func TestMultiTopicConsumerReceive(t *testing.T) {
@@ -330,7 +332,7 @@ func (dummyConnection) SendRequestNoWait(_ *pb.BaseCommand) error {
        return nil
 }
 
-func (dummyConnection) WriteData(_ internal.Buffer) {
+func (dummyConnection) WriteData(_ context.Context, _ internal.Buffer) {
 }
 
 func (dummyConnection) RegisterListener(_ uint64, _ 
internal.ConnectionListener) error {
diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go
index 57fc7241..2ad1acb5 100644
--- a/pulsar/internal/connection.go
+++ b/pulsar/internal/connection.go
@@ -18,6 +18,7 @@
 package internal
 
 import (
+       "context"
        "crypto/tls"
        "crypto/x509"
        "errors"
@@ -78,7 +79,7 @@ type ConnectionListener interface {
 type Connection interface {
        SendRequest(requestID uint64, req *pb.BaseCommand, callback 
func(*pb.BaseCommand, error))
        SendRequestNoWait(req *pb.BaseCommand) error
-       WriteData(data Buffer)
+       WriteData(ctx context.Context, data Buffer)
        RegisterListener(id uint64, listener ConnectionListener) error
        UnregisterListener(id uint64)
        AddConsumeHandler(id uint64, handler ConsumerHandler) error
@@ -129,6 +130,11 @@ type request struct {
        callback func(command *pb.BaseCommand, err error)
 }
 
+type dataRequest struct {
+       ctx  context.Context
+       data Buffer
+}
+
 type connection struct {
        started           int32
        connectionTimeout time.Duration
@@ -157,7 +163,7 @@ type connection struct {
        incomingRequestsCh chan *request
        closeCh            chan struct{}
        readyCh            chan struct{}
-       writeRequestsCh    chan Buffer
+       writeRequestsCh    chan *dataRequest
 
        pendingLock sync.Mutex
        pendingReqs map[uint64]*request
@@ -209,7 +215,7 @@ func newConnection(opts connectionOptions) *connection {
                // partition produces writing on a single connection. In 
general it's
                // good to keep this above the number of partition producers 
assigned
                // to a single connection.
-               writeRequestsCh:  make(chan Buffer, 256),
+               writeRequestsCh:  make(chan *dataRequest, 256),
                listeners:        make(map[uint64]ConnectionListener),
                consumerHandlers: make(map[uint64]ConsumerHandler),
                metrics:          opts.metrics,
@@ -421,11 +427,11 @@ func (c *connection) run() {
                                return // TODO: this never gonna be happen
                        }
                        c.internalSendRequest(req)
-               case data := <-c.writeRequestsCh:
-                       if data == nil {
+               case req := <-c.writeRequestsCh:
+                       if req == nil {
                                return
                        }
-                       c.internalWriteData(data)
+                       c.internalWriteData(req.ctx, req.data)
 
                case <-pingSendTicker.C:
                        c.sendPing()
@@ -450,22 +456,26 @@ func (c *connection) runPingCheck(pingCheckTicker *time.Ticker) {
        }
 }
 
-func (c *connection) WriteData(data Buffer) {
+func (c *connection) WriteData(ctx context.Context, data Buffer) {
        select {
-       case c.writeRequestsCh <- data:
+       case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}:
                // Channel is not full
                return
-
+       case <-ctx.Done():
+               c.log.Debug("Write data context cancelled")
+               return
        default:
                // Channel full, fallback to probe if connection is closed
        }
 
        for {
                select {
-               case c.writeRequestsCh <- data:
+               case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}:
                        // Successfully wrote on the channel
                        return
-
+               case <-ctx.Done():
+                       c.log.Debug("Write data context cancelled")
+                       return
                case <-time.After(100 * time.Millisecond):
                        // The channel is either:
                        // 1. blocked, in which case we need to wait until we 
have space
@@ -481,11 +491,17 @@ func (c *connection) WriteData(data Buffer) {
 
 }
 
-func (c *connection) internalWriteData(data Buffer) {
+func (c *connection) internalWriteData(ctx context.Context, data Buffer) {
        c.log.Debug("Write data: ", data.ReadableBytes())
-       if _, err := c.cnx.Write(data.ReadableSlice()); err != nil {
-               c.log.WithError(err).Warn("Failed to write on connection")
-               c.Close()
+
+       select {
+       case <-ctx.Done():
+               return
+       default:
+               if _, err := c.cnx.Write(data.ReadableSlice()); err != nil {
+                       c.log.WithError(err).Warn("Failed to write on 
connection")
+                       c.Close()
+               }
        }
 }
 
@@ -510,7 +526,7 @@ func (c *connection) writeCommand(cmd *pb.BaseCommand) {
        }
 
        c.writeBuffer.WrittenBytes(cmdSize)
-       c.internalWriteData(c.writeBuffer)
+       c.internalWriteData(context.Background(), c.writeBuffer)
 }
 
 func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload 
Buffer) {
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 448f780c..f6523124 100755
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -394,7 +394,7 @@ func (p *partitionProducer) grabCnx(assignedBrokerURL string) error {
                        pi.sentAt = time.Now()
                        pi.Unlock()
                        p.pendingQueue.Put(pi)
-                       p._getConn().WriteData(pi.buffer)
+                       p._getConn().WriteData(pi.ctx, pi.buffer)
 
                        if pi == lastViewItem {
                                break
@@ -837,6 +837,8 @@ func (p *partitionProducer) internalSingleSend(
 
 type pendingItem struct {
        sync.Mutex
+       ctx           context.Context
+       cancel        context.CancelFunc
        buffer        internal.Buffer
        sequenceID    uint64
        createdAt     time.Time
@@ -895,14 +897,17 @@ func (p *partitionProducer) writeData(buffer internal.Buffer, sequenceID uint64,
                return
        default:
                now := time.Now()
+               ctx, cancel := context.WithCancel(context.Background())
                p.pendingQueue.Put(&pendingItem{
+                       ctx:          ctx,
+                       cancel:       cancel,
                        createdAt:    now,
                        sentAt:       now,
                        buffer:       buffer,
                        sequenceID:   sequenceID,
                        sendRequests: callbacks,
                })
-               p._getConn().WriteData(buffer)
+               p._getConn().WriteData(ctx, buffer)
        }
 }
 
@@ -1579,14 +1584,14 @@ type sendRequest struct {
        uuid             string
        chunkRecorder    *chunkRecorder
 
-       /// resource management
+       // resource management
 
        memLimit          internal.MemoryLimitController
        reservedMem       int64
        semaphore         internal.Semaphore
        reservedSemaphore int
 
-       /// convey settable state
+       // convey settable state
 
        sendAsBatch         bool
        transaction         *transaction
@@ -1659,7 +1664,7 @@ func (sr *sendRequest) done(msgID MessageID, err error) {
 }
 
 func (p *partitionProducer) blockIfQueueFull() bool {
-       //DisableBlockIfQueueFull == false means enable block
+       // DisableBlockIfQueueFull == false means enable block
        return !p.options.DisableBlockIfQueueFull
 }
 
@@ -1741,6 +1746,10 @@ func (i *pendingItem) done(err error) {
        if i.flushCallback != nil {
                i.flushCallback(err)
        }
+
+       if i.cancel != nil {
+               i.cancel()
+       }
 }
 
 // _setConn sets the internal connection field of this partition producer 
atomically.

Reply via email to