Skip to content

Commit

Permalink
Fix closed
Browse files Browse the repository at this point in the history
  • Loading branch information
mingyech committed Oct 19, 2023
1 parent c3c80b1 commit 82a82da
Showing 1 changed file with 53 additions and 26 deletions.
79 changes: 53 additions & 26 deletions pkg/dtls/heartbeat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -23,13 +24,17 @@ type mockStream struct {
recvCh <-chan []byte
closeOnce sync.Once
closeCh chan struct{}
closed uint32
}

func (*mockStream) BufferedAmount() uint64 { return 0 }
func (*mockStream) SetBufferedAmountLowThreshold(th uint64) {}
func (*mockStream) OnBufferedAmountLow(f func()) {}

func (s *mockStream) Read(b []byte) (int, error) {
if atomic.LoadUint32(&s.closed) == 1 {
return 0, net.ErrClosed
}
s.rddlMutex.RLock()
defer s.rddlMutex.RUnlock()

Expand All @@ -52,6 +57,10 @@ func (s *mockStream) Read(b []byte) (int, error) {
}
}
func (s *mockStream) Write(b []byte) (int, error) {
if atomic.LoadUint32(&s.closed) == 1 {
return 0, net.ErrClosed
}

s.wddlMutex.RLock()
defer s.wddlMutex.RUnlock()

Expand Down Expand Up @@ -88,6 +97,7 @@ func (s *mockStream) SetWriteDeadline(t time.Time) error {
}

func (s *mockStream) Close() error {
atomic.StoreUint32(&s.closed, 1)
s.closeOnce.Do(func() { close(s.closeCh) })

return nil
Expand Down Expand Up @@ -121,50 +131,67 @@ func TestHeartbeatReadWrite(t *testing.T) {
sent := 0
toSend := []byte("testtt")
sendTimes := 5
sleepInterval := 1000 * time.Millisecond
sleepInterval := conf.Interval / 3
var wg sync.WaitGroup

ctx, cancel := context.WithTimeout(context.Background(), sleepInterval*5+sleepInterval/2)
defer cancel()

wg.Add(1)
go func() {
go func(ctx1 context.Context) {
defer wg.Done()
defer client.Close()
defer server.Close()
for i := 0; i < sendTimes; i++ {
buffer := make([]byte, 4096)
err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
n, err := s.Read(buffer)
if err != nil {
t.Log(err)
return
}
if string(toSend) != string(buffer[:n]) {
t.Log("read incorrect value", toSend, buffer[:n])
t.Fail()
select {
case <-ctx1.Done():
server.Close()
return
default:
buffer := make([]byte, 4096)
err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
n, err := s.Read(buffer)
if err != nil {
t.Log(err)
return
}
if string(toSend) != string(buffer[:n]) {
t.Log("read incorrect value", toSend, buffer[:n])
t.Fail()
return
}
recvd++

}
recvd++
}
}()
}(ctx)

wg.Add(1)
go func() {
go func(ctx2 context.Context) {
defer wg.Done()

for i := 0; i < sendTimes; i++ {
err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
_, err = client.Write(toSend)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Log("encountered error writing", err)
t.Fail()
}
select {
case <-ctx2.Done():
client.Close()
return
default:
err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2))
require.Nil(t, err)
_, err = client.Write(toSend)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Log("encountered error writing", err)
t.Fail()
}
return
}
sent++
}
sent++
time.Sleep(sleepInterval)
}
}()
}(ctx)

wg.Wait()
require.Equal(t, sent, recvd)
Expand Down

0 comments on commit 82a82da

Please sign in to comment.