commit ddcdfc4f0922e00c672797b0d6544371423f2989
Author: Cecylia Bocovich <coh...@torproject.org>
Date:   Thu Jun 17 16:36:50 2021 -0400

    Fix data race on WebRTCPeer.closed
    
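    The race condition occurs because concurrent goroutines are intermixing
    unsynchronized reads and writes of `WebRTCPeer.closed`. Replace the bool
    with a channel that is closed exactly once (inside the existing
    sync.Once in Close) and add a non-blocking `Closed()` accessor, so that
    readers observe the close through a channel operation rather than a
    racy field access.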
    
    Spotted when integrating Snowflake inside OONI in
    https://github.com/ooni/probe-cli/pull/373.
---
 client/lib/lib_test.go |  6 +++---
 client/lib/peers.go    |  4 ++--
 client/lib/webrtc.go   | 24 ++++++++++++++++++------
 3 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/client/lib/lib_test.go b/client/lib/lib_test.go
index e742e06..55ea7b9 100644
--- a/client/lib/lib_test.go
+++ b/client/lib/lib_test.go
@@ -33,7 +33,7 @@ type FakeDialer struct {
 
 func (w FakeDialer) Catch() (*WebRTCPeer, error) {
        fmt.Println("Caught a dummy snowflake.")
-       return &WebRTCPeer{}, nil
+       return &WebRTCPeer{closed: make(chan struct{})}, nil
 }
 
 func (w FakeDialer) GetMax() int {
@@ -97,7 +97,7 @@ func TestSnowflakeClient(t *testing.T) {
                        So(err, ShouldNotBeNil)
                        So(p.Count(), ShouldEqual, c)
 
-                       // But popping and closing allows it to continue.
+                       // But popping allows it to continue.
                        s := p.Pop()
                        s.Close()
                        So(s, ShouldNotBeNil)
@@ -127,7 +127,7 @@ func TestSnowflakeClient(t *testing.T) {
                        cnt := 5
                        p, _ := NewPeers(FakeDialer{max: cnt})
                        for i := 0; i < cnt; i++ {
-                               p.activePeers.PushBack(&WebRTCPeer{})
+                               p.activePeers.PushBack(&WebRTCPeer{closed: make(chan struct{})})
                        }
                        So(p.Count(), ShouldEqual, cnt)
                        p.End()
diff --git a/client/lib/peers.go b/client/lib/peers.go
index d02eed3..6fa2d29 100644
--- a/client/lib/peers.go
+++ b/client/lib/peers.go
@@ -83,7 +83,7 @@ func (p *Peers) Pop() *WebRTCPeer {
                if !ok {
                        return nil
                }
-               if snowflake.closed {
+               if snowflake.Closed() {
                        continue
                }
                // Set to use the same rate-limited traffic logger to keep consistency.
@@ -110,7 +110,7 @@ func (p *Peers) purgeClosedPeers() {
                next := e.Next()
                conn := e.Value.(*WebRTCPeer)
                // Purge those marked for deletion.
-               if conn.closed {
+               if conn.Closed() {
                        p.activePeers.Remove(e)
                }
                e = next
diff --git a/client/lib/webrtc.go b/client/lib/webrtc.go
index 6a42ebd..234f53c 100644
--- a/client/lib/webrtc.go
+++ b/client/lib/webrtc.go
@@ -28,7 +28,7 @@ type WebRTCPeer struct {
        lastReceive time.Time
 
        open   chan struct{} // Channel to notify when datachannel opens
-       closed bool
+       closed chan struct{}
 
        once sync.Once // Synchronization for PeerConnection destruction
 
@@ -46,6 +46,7 @@ func NewWebRTCPeer(config *webrtc.Configuration,
                }
                connection.id = "snowflake-" + hex.EncodeToString(buf[:])
        }
+       connection.closed = make(chan struct{})
 
        // Override with something that's not NullLogger to have real logging.
        connection.BytesLogger = &BytesNullLogger{}
@@ -78,9 +79,19 @@ func (c *WebRTCPeer) Write(b []byte) (int, error) {
        return len(b), nil
 }
 
+// Closed reports whether the peer has been closed.
+func (c *WebRTCPeer) Closed() bool {
+       select {
+       case <-c.closed:
+               return true
+       default:
+       }
+       return false
+}
+
 func (c *WebRTCPeer) Close() error {
        c.once.Do(func() {
-               c.closed = true
+               close(c.closed)
                c.cleanup()
                log.Printf("WebRTC: Closing")
        })
@@ -95,9 +106,6 @@ func (c *WebRTCPeer) checkForStaleness() {
        c.lastReceive = time.Now()
        c.mu.Unlock()
        for {
-               if c.closed {
-                       return
-               }
                c.mu.Lock()
                lastReceive := c.lastReceive
                c.mu.Unlock()
@@ -107,7 +115,11 @@ func (c *WebRTCPeer) checkForStaleness() {
                        c.Close()
                        return
                }
-               <-time.After(time.Second)
+               select {
+               case <-c.closed:
+                       return
+               case <-time.After(time.Second):
+               }
        }
 }
 



_______________________________________________
tor-commits mailing list
tor-commits@lists.torproject.org
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to