Skip to content

Commit

Permalink
Fix data race
Browse files Browse the repository at this point in the history
  • Loading branch information
mingyech committed Oct 19, 2023
1 parent 585fa82 commit 0df526a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 92 deletions.
25 changes: 13 additions & 12 deletions pkg/dtls/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ var ErrInsufficientBuffer = errors.New("buffer too small to hold the received da
type hbConn struct {
stream msgStream

recvCh chan errBytes
waiting uint32
hb []byte
timeout time.Duration
buffer []byte
recvCh chan errBytes
waiting uint32
hb []byte
timeout time.Duration
maxMessageSize int
}

type errBytes struct {
Expand All @@ -29,10 +29,10 @@ func heartbeatServer(stream msgStream, config *heartbeatConfig, maxMessageSize i
conf := validate(config)

c := &hbConn{stream: stream,
recvCh: make(chan errBytes),
timeout: conf.Interval,
hb: conf.Heartbeat,
buffer: make([]byte, maxMessageSize),
recvCh: make(chan errBytes),
timeout: conf.Interval,
hb: conf.Heartbeat,
maxMessageSize: maxMessageSize,
}

atomic.StoreUint32(&c.waiting, 2)
Expand All @@ -58,15 +58,16 @@ func (c *hbConn) hbLoop() {

func (c *hbConn) recvLoop() {
for {
buffer := make([]byte, c.maxMessageSize)

n, err := c.stream.Read(c.buffer)
n, err := c.stream.Read(buffer)

if bytes.Equal(c.hb, c.buffer[:n]) {
if bytes.Equal(c.hb, buffer[:n]) {
atomic.AddUint32(&c.waiting, 1)
continue
}

c.recvCh <- errBytes{c.buffer[:n], err}
c.recvCh <- errBytes{buffer[:n], err}
}

}
Expand Down
131 changes: 54 additions & 77 deletions pkg/dtls/heartbeat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -16,83 +15,80 @@ var maxMsgSize = 65535
var conf = &heartbeatConfig{Interval: 1 * time.Second, Heartbeat: []byte("hihihihihihihihihi")}

type mockStream struct {
rddl time.Time
wddl time.Time
sendCh chan<- []byte
recvCh <-chan []byte
closeCh chan struct{}
closed bool
rddl time.Time
wddl time.Time
rddlMutex sync.RWMutex
wddlMutex sync.RWMutex
sendCh chan<- []byte
recvCh <-chan []byte
closeOnce sync.Once
closeCh chan struct{}
}

func (*mockStream) BufferedAmount() uint64 { return 0 }
func (*mockStream) SetBufferedAmountLowThreshold(th uint64) {}
func (*mockStream) OnBufferedAmountLow(f func()) {}

func (s *mockStream) Read(b []byte) (int, error) {
if s.closed {
return 0, net.ErrClosed
}
s.rddlMutex.RLock()
defer s.rddlMutex.RUnlock()

if s.rddl.IsZero() {
select {
case buf := <-s.recvCh:
return copy(b, buf), nil
case <-s.closeCh:
return 0, net.ErrClosed
case buf := <-s.recvCh:
return copy(b, buf), nil
}
}

select {
case <-s.closeCh:
return 0, net.ErrClosed
case buf := <-s.recvCh:
return copy(b, buf), nil
case <-time.After(time.Until(s.rddl)):
return 0, net.ErrClosed
case <-s.closeCh:
return 0, net.ErrClosed
}
}
func (s *mockStream) Write(b []byte) (int, error) {

if s.closed {
return 0, net.ErrClosed
}
s.wddlMutex.RLock()
defer s.wddlMutex.RUnlock()

if s.wddl.IsZero() {
select {
case s.sendCh <- b:
return len(b), nil
case <-s.closeCh:
return 0, net.ErrClosed
case s.sendCh <- b:
return len(b), nil
}
}

select {
case <-s.closeCh:
return 0, net.ErrClosed
case s.sendCh <- b:
return len(b), nil
case <-time.After(time.Until(s.wddl)):
return 0, net.ErrClosed
case <-s.closeCh:
return 0, net.ErrClosed
}
}

func (s *mockStream) SetReadDeadline(t time.Time) error {
s.rddlMutex.Lock()
defer s.rddlMutex.Unlock()
s.rddl = t
return nil
}
func (s *mockStream) SetWriteDeadline(t time.Time) error {
s.wddlMutex.Lock()
defer s.wddlMutex.Unlock()
s.wddl = t
return nil
}

func (s *mockStream) Close() error {
if s.closed {
return errors.New("mockStream already closed")
}

s.closed = true

close(s.closeCh)
s.closeOnce.Do(func() { close(s.closeCh) })

return nil
}
Expand Down Expand Up @@ -121,72 +117,54 @@ func TestHeartbeatReadWrite(t *testing.T) {
err = heartbeatClient(client, conf)
require.Nil(t, err)

sent := uint32(0)
recvd := uint32(0)
recvd := 0
toSend := []byte("testtt")
sendTimes := 5
sleepInterval := 400 * time.Millisecond
var wg sync.WaitGroup

ctx, cancel := context.WithTimeout(
context.Background(),
time.Duration(10*sleepInterval+sleepInterval/2))

defer cancel()

wg.Add(1)
go func(ctx1 context.Context) {
go func() {
defer wg.Done()
defer client.Close()
defer server.Close()
for {
select {
case <-ctx1.Done():
for i := 0; i < sendTimes; i++ {
buffer := make([]byte, 4096)
err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
n, err := s.Read(buffer)
if err != nil {
return
}
if string(toSend) != string(buffer[:n]) {
t.Log("read incorrect value", toSend, buffer[:n])
t.Fail()
return
default:
buffer := make([]byte, 4096)
err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
n, err := s.Read(buffer)
if err != nil {
return
}
if string(toSend) != string(buffer[:n]) {
t.Log("read incorrect value", toSend, buffer[:n])
t.Fail()
return
}
atomic.AddUint32(&recvd, 1)
}
recvd++
}
}(ctx)
}()

wg.Add(1)
go func(ctx2 context.Context) {
go func() {
defer wg.Done()
for {
select {
case <-ctx2.Done():
client.Close()
return
default:
err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
_, err = client.Write(toSend)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Log("encountered error writing", err)
t.Fail()
}
return
for i := 0; i < sendTimes; i++ {
err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
_, err = client.Write(toSend)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Log("encountered error writing", err)
t.Fail()
}
atomic.AddUint32(&sent, 1)
return
}
time.Sleep(sleepInterval)
}
}(ctx)
}()

wg.Wait()
require.Equal(t, atomic.LoadUint32(&sent), atomic.LoadUint32(&recvd))
require.Equal(t, sendTimes, recvd)
}

func TestHeartbeatSend(t *testing.T) {
Expand Down Expand Up @@ -248,8 +226,7 @@ func TestHeartbeatTimeout(t *testing.T) {
_, err = s.Write([]byte("123"))
require.Nil(t, err)

stop := time.After(conf.Interval + 100*time.Millisecond)
<-stop
time.Sleep(2 * conf.Interval)
_, err = s.Write([]byte("123"))
require.NotNil(t, err)

Expand Down
3 changes: 0 additions & 3 deletions pkg/dtls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -36,8 +35,6 @@ func TestSend(t *testing.T) {
require.Equal(t, toSend, received)
}()

time.Sleep(1 * time.Second)

c, err := Client(client, &Config{PSK: sharedSecret, SCTP: ClientOpen})
require.Nil(t, err)

Expand Down

0 comments on commit 0df526a

Please sign in to comment.