diff --git a/pkg/transports/connecting/dtls/client.go b/pkg/transports/connecting/dtls/client.go index 5264d94f..4be24674 100644 --- a/pkg/transports/connecting/dtls/client.go +++ b/pkg/transports/connecting/dtls/client.go @@ -44,7 +44,6 @@ type ClientTransport struct { psk []byte stunServer string disableIRWorkaround bool - listenTimeout *time.Duration } type ClientConfig struct { @@ -132,7 +131,6 @@ func (t *ClientTransport) SetParams(p any) error { case *ClientConfig: t.stunServer = params.STUNServer t.disableIRWorkaround = params.DisableIRWorkaround - t.listenTimeout = params.ListenTimeout } return nil @@ -205,27 +203,45 @@ func (t *ClientTransport) GetDstPort(seed []byte) (uint16, error) { func (t *ClientTransport) WrapDial(dialer dialFunc) (dialFunc, error) { dtlsDialer := func(ctx context.Context, network, localAddr, address string) (net.Conn, error) { - // Create a context that will automatically cancel after 5 seconds or when the existing context is cancelled, whichever comes first. - timeout := t.listenTimeout - if timeout == nil { - time := defaultListenTime - timeout = &time - } - ctxtimeout, cancel := context.WithTimeout(ctx, *timeout) + + dialCtx, cancel := context.WithCancel(ctx) defer cancel() - conn, errListen := t.listen(ctxtimeout, dialer, address) - if errListen != nil { - // fallback to dial - conn, errDial := t.dial(ctx, dialer, address) - if errDial != nil { - return nil, fmt.Errorf("error listening: %v, error dialing: %v", errListen, errDial) + type result struct { + conn net.Conn + err error + } + + results := make(chan result, 2) + + go func() { + conn, err := t.listen(dialCtx, dialer, address) + results <- result{conn, err} + }() + + go func() { + conn, err := t.dial(dialCtx, dialer, address) + results <- result{conn, err} + }() + + first := <-results + if first.err == nil { + // Interrupt the other dial + cancel() + second := <-results + if second.conn != nil { + _ = second.conn.Close() } + return first.conn, nil + } - return conn, nil + second := <-results + if second.err == nil { + return second.conn, nil } - return conn, nil + // TODO: once our minimum golang version is >= 1.20 change this to "%w; %w" + return nil, fmt.Errorf("%w; %s", first.err, second.err) } return dtlsDialer, nil