diff --git a/pkg/dtls/heartbeat_test.go b/pkg/dtls/heartbeat_test.go index ffd8464d..1b1e040d 100644 --- a/pkg/dtls/heartbeat_test.go +++ b/pkg/dtls/heartbeat_test.go @@ -5,6 +5,7 @@ import ( "errors" "net" "sync" + "sync/atomic" "testing" "time" @@ -23,6 +24,7 @@ type mockStream struct { recvCh <-chan []byte closeOnce sync.Once closeCh chan struct{} + closed uint32 } func (*mockStream) BufferedAmount() uint64 { return 0 } @@ -30,6 +32,9 @@ func (*mockStream) SetBufferedAmountLowThreshold(th uint64) {} func (*mockStream) OnBufferedAmountLow(f func()) {} func (s *mockStream) Read(b []byte) (int, error) { + if atomic.LoadUint32(&s.closed) == 1 { + return 0, net.ErrClosed + } s.rddlMutex.RLock() defer s.rddlMutex.RUnlock() @@ -52,6 +57,10 @@ func (s *mockStream) Read(b []byte) (int, error) { } } func (s *mockStream) Write(b []byte) (int, error) { + if atomic.LoadUint32(&s.closed) == 1 { + return 0, net.ErrClosed + } + s.wddlMutex.RLock() defer s.wddlMutex.RUnlock() @@ -88,6 +97,7 @@ func (s *mockStream) SetWriteDeadline(t time.Time) error { } func (s *mockStream) Close() error { + atomic.StoreUint32(&s.closed, 1) s.closeOnce.Do(func() { close(s.closeCh) }) return nil @@ -121,50 +131,67 @@ func TestHeartbeatReadWrite(t *testing.T) { sent := 0 toSend := []byte("testtt") sendTimes := 5 - sleepInterval := 1000 * time.Millisecond + sleepInterval := conf.Interval / 3 var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), sleepInterval*5+sleepInterval/2) + defer cancel() + wg.Add(1) - go func() { + go func(ctx1 context.Context) { defer wg.Done() defer client.Close() defer server.Close() for i := 0; i < sendTimes; i++ { - buffer := make([]byte, 4096) - err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2)) - require.Nil(t, err) - n, err := s.Read(buffer) - if err != nil { - t.Log(err) - return - } - if string(toSend) != string(buffer[:n]) { - t.Log("read incorrect value", toSend, buffer[:n]) - t.Fail() + select { + case <-ctx1.Done(): + server.Close() return + default: + buffer := make([]byte, 4096) + err := s.SetReadDeadline(time.Now().Add(sleepInterval * 2)) + require.Nil(t, err) + n, err := s.Read(buffer) + if err != nil { + t.Log(err) + return + } + if string(toSend) != string(buffer[:n]) { + t.Log("read incorrect value", toSend, buffer[:n]) + t.Fail() + return + } + recvd++ + } - recvd++ } - }() + }(ctx) wg.Add(1) - go func() { + go func(ctx2 context.Context) { defer wg.Done() + for i := 0; i < sendTimes; i++ { - err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2)) - require.Nil(t, err) - _, err = client.Write(toSend) - if err != nil { - if !errors.Is(err, net.ErrClosed) { - t.Log("encountered error writing", err) - t.Fail() - } + select { + case <-ctx2.Done(): + client.Close() return + default: + err := server.SetWriteDeadline(time.Now().Add(sleepInterval * 2)) + require.Nil(t, err) + _, err = client.Write(toSend) + if err != nil { + if !errors.Is(err, net.ErrClosed) { + t.Log("encountered error writing", err) + t.Fail() + } + return + } + sent++ } - sent++ time.Sleep(sleepInterval) } - }() + }(ctx) wg.Wait() require.Equal(t, sent, recvd)