From 0df526a524e4b0250187c82efb1fd636ba60873a Mon Sep 17 00:00:00 2001 From: Mingye Chen Date: Thu, 19 Oct 2023 11:57:19 -0400 Subject: [PATCH] Fix data race --- pkg/dtls/heartbeat.go | 25 +++---- pkg/dtls/heartbeat_test.go | 131 +++++++++++++++---------------------- pkg/dtls/server_test.go | 3 - 3 files changed, 67 insertions(+), 92 deletions(-) diff --git a/pkg/dtls/heartbeat.go b/pkg/dtls/heartbeat.go index 5209ec41..4ac6b75b 100644 --- a/pkg/dtls/heartbeat.go +++ b/pkg/dtls/heartbeat.go @@ -12,11 +12,11 @@ var ErrInsufficientBuffer = errors.New("buffer too small to hold the received da type hbConn struct { stream msgStream - recvCh chan errBytes - waiting uint32 - hb []byte - timeout time.Duration - buffer []byte + recvCh chan errBytes + waiting uint32 + hb []byte + timeout time.Duration + maxMessageSize int } type errBytes struct { @@ -29,10 +29,10 @@ func heartbeatServer(stream msgStream, config *heartbeatConfig, maxMessageSize i conf := validate(config) c := &hbConn{stream: stream, - recvCh: make(chan errBytes), - timeout: conf.Interval, - hb: conf.Heartbeat, - buffer: make([]byte, maxMessageSize), + recvCh: make(chan errBytes), + timeout: conf.Interval, + hb: conf.Heartbeat, + maxMessageSize: maxMessageSize, } atomic.StoreUint32(&c.waiting, 2) @@ -58,15 +58,16 @@ func (c *hbConn) hbLoop() { func (c *hbConn) recvLoop() { for { + buffer := make([]byte, c.maxMessageSize) - n, err := c.stream.Read(c.buffer) + n, err := c.stream.Read(buffer) - if bytes.Equal(c.hb, c.buffer[:n]) { + if bytes.Equal(c.hb, buffer[:n]) { atomic.AddUint32(&c.waiting, 1) continue } - c.recvCh <- errBytes{c.buffer[:n], err} + c.recvCh <- errBytes{buffer[:n], err} } } diff --git a/pkg/dtls/heartbeat_test.go b/pkg/dtls/heartbeat_test.go index 65b11762..2f183f29 100644 --- a/pkg/dtls/heartbeat_test.go +++ b/pkg/dtls/heartbeat_test.go @@ -5,7 +5,6 @@ import ( "errors" "net" "sync" - "sync/atomic" "testing" "time" @@ -16,12 +15,14 @@ var maxMsgSize = 65535 var conf = &heartbeatConfig{Interval: 1 * time.Second, Heartbeat: []byte("hihihihihihihihihi")} type mockStream struct { - rddl time.Time - wddl time.Time - sendCh chan<- []byte - recvCh <-chan []byte - closeCh chan struct{} - closed bool + rddl time.Time + wddl time.Time + rddlMutex sync.RWMutex + wddlMutex sync.RWMutex + sendCh chan<- []byte + recvCh <-chan []byte + closeOnce sync.Once + closeCh chan struct{} } func (*mockStream) BufferedAmount() uint64 { return 0 } @@ -29,70 +30,65 @@ func (*mockStream) SetBufferedAmountLowThreshold(th uint64) {} func (*mockStream) OnBufferedAmountLow(f func()) {} func (s *mockStream) Read(b []byte) (int, error) { - if s.closed { - return 0, net.ErrClosed - } + s.rddlMutex.RLock() + defer s.rddlMutex.RUnlock() if s.rddl.IsZero() { select { - case buf := <-s.recvCh: - return copy(b, buf), nil case <-s.closeCh: return 0, net.ErrClosed + case buf := <-s.recvCh: + return copy(b, buf), nil } } select { + case <-s.closeCh: + return 0, net.ErrClosed case buf := <-s.recvCh: return copy(b, buf), nil case <-time.After(time.Until(s.rddl)): return 0, net.ErrClosed - case <-s.closeCh: - return 0, net.ErrClosed } } func (s *mockStream) Write(b []byte) (int, error) { - - if s.closed { - return 0, net.ErrClosed - } + s.wddlMutex.RLock() + defer s.wddlMutex.RUnlock() if s.wddl.IsZero() { select { - case s.sendCh <- b: - return len(b), nil case <-s.closeCh: return 0, net.ErrClosed + case s.sendCh <- b: + return len(b), nil } } select { + case <-s.closeCh: + return 0, net.ErrClosed case s.sendCh <- b: return len(b), nil case <-time.After(time.Until(s.wddl)): return 0, net.ErrClosed - case <-s.closeCh: - return 0, net.ErrClosed } } func (s *mockStream) SetReadDeadline(t time.Time) error { + s.rddlMutex.Lock() + defer s.rddlMutex.Unlock() s.rddl = t return nil } func (s *mockStream) SetWriteDeadline(t time.Time) error { + s.wddlMutex.Lock() + defer s.wddlMutex.Unlock() s.wddl = t return nil } func (s *mockStream) Close() error { - if s.closed { - return errors.New("mockStream already closed") - } - - s.closed = true - - close(s.closeCh) + s.closeOnce.Do(func() { close(s.closeCh) }) return nil } @@ -121,72 +117,54 @@ func TestHeartbeatReadWrite(t *testing.T) { err = heartbeatClient(client, conf) require.Nil(t, err) - sent := uint32(0) - recvd := uint32(0) + recvd := 0 toSend := []byte("testtt") + sendTimes := 5 sleepInterval := 400 * time.Millisecond var wg sync.WaitGroup - ctx, cancel := context.WithTimeout( - context.Background(), - time.Duration(10*sleepInterval+sleepInterval/2)) - - defer cancel() - wg.Add(1) - go func(ctx1 context.Context) { + go func() { defer wg.Done() defer client.Close() defer server.Close() - for { - select { - case <-ctx1.Done(): + for i := 0; i < sendTimes; i++ { + buffer := make([]byte, 4096) + err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2)) + require.Nil(t, err) + n, err := s.Read(buffer) + if err != nil { + return + } + if string(toSend) != string(buffer[:n]) { + t.Log("read incorrect value", toSend, buffer[:n]) + t.Fail() return - default: - buffer := make([]byte, 4096) - err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2)) - require.Nil(t, err) - n, err := s.Read(buffer) - if err != nil { - return - } - if string(toSend) != string(buffer[:n]) { - t.Log("read incorrect value", toSend, buffer[:n]) - t.Fail() - return - } - atomic.AddUint32(&recvd, 1) } + recvd++ } - }(ctx) + }() wg.Add(1) - go func(ctx2 context.Context) { + go func() { defer wg.Done() - for { - select { - case <-ctx2.Done(): - client.Close() - return - default: - err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2)) - require.Nil(t, err) - _, err = client.Write(toSend) - if err != nil { - if !errors.Is(err, net.ErrClosed) { - t.Log("encountered error writing", err) - t.Fail() - } - return + for i := 0; i < sendTimes; i++ { + err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2)) + require.Nil(t, err) + _, err = client.Write(toSend) + if err != nil { + if !errors.Is(err, net.ErrClosed) { + t.Log("encountered error writing", err) + t.Fail() } - atomic.AddUint32(&sent, 1) + return } time.Sleep(sleepInterval) } - }(ctx) + }() wg.Wait() - require.Equal(t, atomic.LoadUint32(&sent), atomic.LoadUint32(&recvd)) + require.Equal(t, sendTimes, recvd) } func TestHeartbeatSend(t *testing.T) { @@ -248,8 +226,7 @@ func TestHeartbeatTimeout(t *testing.T) { _, err = s.Write([]byte("123")) require.Nil(t, err) - stop := time.After(conf.Interval + 100*time.Millisecond) - <-stop + time.Sleep(2 * conf.Interval) _, err = s.Write([]byte("123")) require.NotNil(t, err) diff --git a/pkg/dtls/server_test.go b/pkg/dtls/server_test.go index febf05a8..23de97a6 100644 --- a/pkg/dtls/server_test.go +++ b/pkg/dtls/server_test.go @@ -5,7 +5,6 @@ import ( "net" "sync" "testing" - "time" "github.com/stretchr/testify/require" ) @@ -36,8 +35,6 @@ func TestSend(t *testing.T) { require.Equal(t, toSend, received) }() - time.Sleep(1 * time.Second) - c, err := Client(client, &Config{PSK: sharedSecret, SCTP: ClientOpen}) require.Nil(t, err)