diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 123293911a..367fca05f2 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -27,6 +27,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/protocol/ping" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/prometheus/client_golang/prometheus" @@ -786,7 +787,9 @@ func (h *BasicHost) Addrs() []ma.Multiaddr { copy(addrs, addrsOld) for i, addr := range addrs { - if ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr); ok && n == 0 { + wtOK, wtN := libp2pwebtransport.IsWebtransportMultiaddr(addr) + webrtcOK, webrtcN := libp2pwebrtc.IsWebRTCDirectMultiaddr(addr) + if (wtOK && wtN == 0) || (webrtcOK && webrtcN == 0) { t := s.TransportForListening(addr) tpt, ok := t.(addCertHasher) if !ok { @@ -794,7 +797,7 @@ func (h *BasicHost) Addrs() []ma.Multiaddr { } addrWithCerthash, added := tpt.AddCertHashes(addr) if !added { - log.Debug("Couldn't add certhashes to webtransport multiaddr because we aren't listening on webtransport") + log.Debugf("Couldn't add certhashes to multiaddr: %s", addr) continue } addrs[i] = addrWithCerthash diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go index 98d8cf45f1..e6cd7ea9d9 100644 --- a/p2p/test/basichost/basic_host_test.go +++ b/p2p/test/basichost/basic_host_test.go @@ -3,6 +3,7 @@ package basichost import ( "context" "fmt" + "strings" "testing" "time" @@ -12,6 +13,8 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -158,3 +161,41 @@ func TestNewStreamTransientConnection(t *testing.T) { <-done <-done } + +func TestAddrFactorCertHashAppend(t *testing.T) { + wtAddr := "/ip4/1.2.3.4/udp/1/quic-v1/webtransport" + webrtcAddr := "/ip4/1.2.3.4/udp/2/webrtc-direct" + addrsFactory := func(addrs []ma.Multiaddr) []ma.Multiaddr { + return append(addrs, + ma.StringCast(wtAddr), + ma.StringCast(webrtcAddr), + ) + } + h, err := libp2p.New( + libp2p.AddrsFactory(addrsFactory), + libp2p.Transport(libp2pwebrtc.New), + libp2p.Transport(libp2pwebtransport.New), + libp2p.ListenAddrStrings( + "/ip4/0.0.0.0/udp/0/quic-v1/webtransport", + "/ip4/0.0.0.0/udp/0/webrtc-direct", + ), + ) + require.NoError(t, err) + require.Eventually(t, func() bool { + addrs := h.Addrs() + var hasWebRTC, hasWebTransport bool + for _, addr := range addrs { + if strings.HasPrefix(addr.String(), webrtcAddr) { + if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err == nil { + hasWebRTC = true + } + } + if strings.HasPrefix(addr.String(), wtAddr) { + if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err == nil { + hasWebTransport = true + } + } + } + return hasWebRTC && hasWebTransport + }, 5*time.Second, 100*time.Millisecond) +} diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index ae7ae339fe..b04753ecab 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -40,7 +40,6 @@ import ( "github.com/libp2p/go-msgio" ma "github.com/multiformats/go-multiaddr" - mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multihash" @@ -48,8 +47,6 @@ import ( "github.com/pion/webrtc/v3" ) -var dialMatcher = mafmt.And(mafmt.UDP, mafmt.Base(ma.P_WEBRTC_DIRECT), mafmt.Base(ma.P_CERTHASH)) - var webrtcComponent *ma.Component func init() { @@ -179,7 +176,8 @@ func (t *WebRTCTransport) Proxy() bool { } func (t *WebRTCTransport) CanDial(addr ma.Multiaddr) bool { - return dialMatcher.Matches(addr) + isValid, n := IsWebRTCDirectMultiaddr(addr) + return isValid && n > 0 } // Listen returns a listener for addr. @@ -514,6 +512,24 @@ func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerCon return secureConn.RemotePublicKey(), nil } +func (t *WebRTCTransport) AddCertHashes(addr ma.Multiaddr) (ma.Multiaddr, bool) { + listenerFingerprint, err := t.getCertificateFingerprint() + if err != nil { + return nil, false + } + + encodedLocalFingerprint, err := encodeDTLSFingerprint(listenerFingerprint) + if err != nil { + return nil, false + } + + certComp, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, encodedLocalFingerprint) + if err != nil { + return nil, false + } + return addr.Encapsulate(certComp), true +} + type netConnWrapper struct { *stream } @@ -601,3 +617,35 @@ func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configurat IncomingDataChannels: incomingDataChannels, }, nil } + +// IsWebRTCDirectMultiaddr returns whether addr is a /webrtc-direct multiaddr with the count of certhashes +// in addr +func IsWebRTCDirectMultiaddr(addr ma.Multiaddr) (bool, int) { + var foundUDP, foundWebRTC bool + certHashCount := 0 + ma.ForEach(addr, func(c ma.Component) bool { + if !foundUDP { + if c.Protocol().Code == ma.P_UDP { + foundUDP = true + } + return true + } + if !foundWebRTC && foundUDP { + // protocol after udp must be webrtc-direct + if c.Protocol().Code != ma.P_WEBRTC_DIRECT { + return false + } + foundWebRTC = true + return true + } + if foundWebRTC { + if c.Protocol().Code == ma.P_CERTHASH { + certHashCount++ + } else { + return false + } + } + return true + }) + return foundUDP && foundWebRTC, certHashCount +} diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index e5ef3ca1d0..a3054a82df 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -40,6 +40,58 @@ func getTransport(t *testing.T, opts ...Option) (*WebRTCTransport, peer.ID) { return transport, peerID } +func TestIsWebRTCDirectMultiaddr(t *testing.T) { + invalid := []string{ + "/ip4/1.2.3.4/tcp/10/", + "/ip6/1::3/udp/100/quic-v1/", + "/ip4/1.2.3.4/udp/1/quic-v1/webrtc-direct", + } + + valid := []struct { + addr string + count int + }{ + { + addr: "/ip4/1.2.3.4/udp/1234/webrtc-direct", + count: 0, + }, + { + addr: "/dns/test.test/udp/1234/webrtc-direct", + count: 0, + }, + { + addr: "/ip4/1.2.3.4/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg", + count: 1, + }, + { + addr: "/ip6/0:0:0:0:0:0:0:1/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg", + count: 1, + }, + { + addr: "/dns/test.test/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg", + count: 1, + }, + { + addr: "/dns/test.test/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7ZGrV4VZ3hpEKTd_zg", + count: 2, + }, + } + + for _, addr := range invalid { + a := ma.StringCast(addr) + isValid, n := IsWebRTCDirectMultiaddr(a) + require.Equal(t, 0, n) + require.False(t, isValid) + } + + for _, tc := range valid { + a := ma.StringCast(tc.addr) + isValid, n := IsWebRTCDirectMultiaddr(a) + require.Equal(t, tc.count, n) + require.True(t, isValid) + } +} + func TestTransportWebRTC_CanDial(t *testing.T) { tr, _ := getTransport(t) invalid := []string{ @@ -65,6 +117,21 @@ func TestTransportWebRTC_CanDial(t *testing.T) { } } +func TestTransportAddCertHasher(t *testing.T) { + tr, _ := getTransport(t) + addrs := []string{ + "/ip4/1.2.3.4/udp/1/webrtc-direct", + "/ip6/1::3/udp/2/webrtc-direct", + } + for _, a := range addrs { + addr, added := tr.AddCertHashes(ma.StringCast(a)) + require.True(t, added) + _, err := addr.ValueForProtocol(ma.P_CERTHASH) + require.NoError(t, err) + require.True(t, strings.HasPrefix(addr.String(), a)) + } +} + func TestTransportWebRTC_ListenFailsOnNonWebRTCMultiaddr(t *testing.T) { tr, _ := getTransport(t) testAddrs := []string{