From e09ef046cf1ae7126fafabdc56deee3acf031603 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 14 Aug 2023 21:02:50 +0530 Subject: [PATCH] store map of peers supporting DialProtocol --- p2p/host/blank/blank.go | 7 +- p2p/protocol/autonatv2/autonat.go | 87 +++++++++-------- p2p/protocol/autonatv2/autonat_test.go | 126 +++++++++++++++++-------- p2p/protocol/autonatv2/server_test.go | 15 +-- 4 files changed, 145 insertions(+), 90 deletions(-) diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 24304498b0..bf00ecd565 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -63,9 +63,10 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost { } bh := &BlankHost{ - n: n, - cmgr: cfg.cmgr, - mux: mstream.NewMultistreamMuxer[protocol.ID](), + n: n, + cmgr: cfg.cmgr, + mux: mstream.NewMultistreamMuxer[protocol.ID](), + eventbus: cfg.eventBus, } if bh.eventbus == nil { bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer())) diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 8c40899a36..a401cafb93 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -15,7 +15,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" - "golang.org/x/exp/rand" ) const ( @@ -45,6 +44,8 @@ type AutoNAT struct { wg sync.WaitGroup srv *Server cli *Client + mx sync.Mutex + peers map[peer.ID]struct{} allowAllAddrs bool // for testing } @@ -55,7 +56,11 @@ func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error) return nil, err } } - sub, err := h.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged)) + sub, err := h.EventBus().Subscribe([]interface{}{ + new(event.EvtLocalReachabilityChanged), + new(event.EvtPeerProtocolsUpdated), + new(event.EvtPeerConnectednessChanged), + }) if err != nil { return nil, fmt.Errorf("failed to subscribe to event.EvtLocalReachabilityChanged: %w", err) } @@ -69,6 +74,7 @@ func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error) srv: NewServer(h, dialer, s), cli: NewClient(h), allowAllAddrs: s.allowAllAddrs, + peers: make(map[peer.ID]struct{}), } an.cli.Register() @@ -84,28 +90,31 @@ func (an *AutoNAT) background() { an.srv.Disable() an.wg.Done() return - case evt := <-an.sub.Out(): - // Enable the server only if we're publicly reachable. - // - // Currently this event is sent by the AutoNAT v1 module. During the - // transition period from AutoNAT v1 to v2, there won't be enough v2 servers - // on the network and most clients will be unable to discover a peer which - // supports AutoNAT v2. So, we use v1 to determine reachability for the - // transition period. - // - // Once there are enough v2 servers on the network for nodes to determine - // their reachability using AutoNAT v2, we'll use Address Pipeline - // (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a - // future release) to determine reachability using v2 client and send this - // event if we are publicly reachable. - revt, ok := evt.(event.EvtLocalReachabilityChanged) - if !ok { - log.Errorf("Unexpected event %s of type %T", evt, evt) - } - if revt.Reachability == network.ReachabilityPrivate { - an.srv.Disable() - } else { - an.srv.Enable() + case e := <-an.sub.Out(): + switch evt := e.(type) { + case event.EvtLocalReachabilityChanged: + // Enable the server only if we're publicly reachable. + // + // Currently this event is sent by the AutoNAT v1 module. During the + // transition period from AutoNAT v1 to v2, there won't be enough v2 servers + // on the network and most clients will be unable to discover a peer which + // supports AutoNAT v2. So, we use v1 to determine reachability for the + // transition period. + // + // Once there are enough v2 servers on the network for nodes to determine + // their reachability using AutoNAT v2, we'll use Address Pipeline + // (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a + // future release) to determine reachability using v2 client and send this + // event from Address Pipeline, if we are publicly reachable. + if evt.Reachability == network.ReachabilityPrivate { + an.srv.Disable() + } else { + an.srv.Enable() + } + case event.EvtPeerProtocolsUpdated: + an.updatePeer(evt.Peer) + case event.EvtPeerConnectednessChanged: + an.updatePeer(evt.Peer) } } } @@ -140,21 +149,25 @@ func (an *AutoNAT) CheckReachability(ctx context.Context, highPriorityAddrs []ma } func (an *AutoNAT) validPeer() peer.ID { - peers := an.host.Peerstore().Peers() - idx := 0 - for i := 0; i < len(peers); i++ { - if proto, err := an.host.Peerstore().SupportsProtocols(peers[i], DialProtocol); len(proto) == 0 || err != nil { - continue - } - peers[idx] = peers[i] - idx++ + an.mx.Lock() + defer an.mx.Unlock() + for p := range an.peers { + return p } - if idx == 0 { - return "" + return "" +} + +func (an *AutoNAT) updatePeer(p peer.ID) { + an.mx.Lock() + defer an.mx.Unlock() + + _, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol) + connState := an.host.Network().Connectedness(p) + if err == nil && connState == network.Connected { + an.peers[p] = struct{}{} + } else { + delete(an.peers, p) } - peers = peers[:idx] - rand.Shuffle(len(peers), func(i, j int) { peers[i], peers[j] = peers[j], peers[i] }) - return peers[0] } type Result struct { diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 6581d352c7..725e1bc5c9 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "sync/atomic" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2" @@ -22,7 +24,8 @@ import ( func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT { t.Helper() - h := bhost.NewBlankHost(swarmt.GenSwarm(t)) + b := eventbus.NewBus() + h := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.EventBus(b)), bhost.WithEventBus(b)) if dialer == nil { dialer = bhost.NewBlankHost(swarmt.GenSwarm(t)) } @@ -30,6 +33,8 @@ func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT if err != nil { t.Error(err) } + an.srv.Enable() + an.cli.Register() return an } @@ -47,31 +52,28 @@ func parseAddrs(t *testing.T, msg *pbv2.Message) []ma.Multiaddr { return addrs } -func TestValidPeer(t *testing.T) { - an := newAutoNAT(t, nil) - require.Equal(t, an.validPeer(), peer.ID("")) - an.host.Peerstore().AddAddr("peer1", ma.StringCast("/ip4/127.0.0.1/tcp/1"), peerstore.PermanentAddrTTL) - an.host.Peerstore().AddAddr("peer2", ma.StringCast("/ip4/127.0.0.1/tcp/2"), peerstore.PermanentAddrTTL) - require.NoError(t, an.host.Peerstore().AddProtocols("peer1", DialProtocol)) - require.NoError(t, an.host.Peerstore().AddProtocols("peer2", DialProtocol)) - - var got1, got2 bool - for i := 0; i < 100; i++ { - p := an.validPeer() - switch p { - case peer.ID("peer1"): - got1 = true - case peer.ID("peer2"): - got2 = true - default: - t.Fatalf("invalid peer: %s", p) - } - if got1 && got2 { - break - } - } - require.True(t, got1) - require.True(t, got2) +func idAndConnect(t *testing.T, a, b host.Host) { + a.Peerstore().AddAddrs(b.ID(), b.Addrs(), peerstore.PermanentAddrTTL) + a.Peerstore().AddProtocols(b.ID(), DialProtocol) + + err := a.Connect(context.Background(), peer.AddrInfo{ID: b.ID()}) + require.NoError(t, err) +} + +// waitForPeer waits for a to process all peer events +func waitForPeer(t *testing.T, a *AutoNAT) { + t.Helper() + require.Eventually(t, func() bool { + a.mx.Lock() + defer a.mx.Unlock() + return len(a.peers) > 0 + }, 5*time.Second, 100*time.Millisecond) +} + +// identify provides server address and protocol to client +func identify(t *testing.T, cli *AutoNAT, srv *AutoNAT) { + idAndConnect(t, cli.host, srv.host) + waitForPeer(t, cli) } func TestAutoNATPrivateAddr(t *testing.T) { @@ -82,19 +84,24 @@ func TestAutoNATPrivateAddr(t *testing.T) { } func TestClientRequest(t *testing.T) { - an := newAutoNAT(t, nil) + an := newAutoNAT(t, nil, allowAll) addrs := an.host.Addrs() + var gotReq atomic.Bool p := bhost.NewBlankHost(swarmt.GenSwarm(t)) p.SetStreamHandler(DialProtocol, func(s network.Stream) { + gotReq.Store(true) r := pbio.NewDelimitedReader(s, maxMsgSize) var msg pbv2.Message - err := r.ReadMsg(&msg) - if err != nil { + if err := r.ReadMsg(&msg); err != nil { t.Error(err) + return + } + if msg.GetDialRequest() == nil { + t.Errorf("expected message to be of type DialRequest, got %T", msg.Msg) + return } - require.NotNil(t, msg.GetDialRequest()) addrsb := make([][]byte, len(addrs)) for i := 0; i < len(addrs); i++ { addrsb[i] = addrs[i].Bytes() @@ -105,11 +112,13 @@ func TestClientRequest(t *testing.T) { s.Reset() }) - an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.TempAddrTTL) - an.host.Peerstore().AddProtocols(p.ID(), DialProtocol) + idAndConnect(t, an.host, p) + waitForPeer(t, an) + res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:]) require.Nil(t, res) require.NotNil(t, err) + require.True(t, gotReq.Load()) } func TestClientServerError(t *testing.T) { @@ -117,8 +126,9 @@ func TestClientServerError(t *testing.T) { addrs := an.host.Addrs() p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.PermanentAddrTTL) - an.host.Peerstore().AddProtocols(p.ID(), DialProtocol) + idAndConnect(t, an.host, p) + waitForPeer(t, an) + done := make(chan bool) tests := []struct { handler func(network.Stream) @@ -163,8 +173,9 @@ func TestClientDataRequest(t *testing.T) { addrs := an.host.Addrs() p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.PermanentAddrTTL) - an.host.Peerstore().AddProtocols(p.ID(), DialProtocol) + idAndConnect(t, an.host, p) + waitForPeer(t, an) + done := make(chan bool) tests := []struct { handler func(network.Stream) @@ -234,9 +245,8 @@ func TestClientDialAttempts(t *testing.T) { addrs := an.host.Addrs() p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.PermanentAddrTTL) - an.host.Peerstore().AddProtocols(p.ID(), DialProtocol) - an.cli.Register() + idAndConnect(t, an.host, p) + waitForPeer(t, an) tests := []struct { handler func(network.Stream) @@ -419,3 +429,41 @@ func TestClientDialAttempts(t *testing.T) { }) } } + +func TestEventSubscription(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + c := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer c.Close() + + idAndConnect(t, an.host, b) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + idAndConnect(t, an.host, c) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers) == 2 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(b.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(c.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers) == 0 + }, 5*time.Second, 100*time.Millisecond) +} diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index b8055b723b..2ba6070671 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -6,7 +6,6 @@ import ( "time" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/test" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" @@ -15,12 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -// identify provides server address and protocol to client -func identify(cli *AutoNAT, srv *AutoNAT) { - cli.host.Peerstore().AddAddrs(srv.host.ID(), srv.host.Addrs(), peerstore.PermanentAddrTTL) - cli.host.Peerstore().AddProtocols(srv.host.ID(), DialProtocol) -} - func TestServerAllAddrsInvalid(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) an := newAutoNAT(t, dialer, allowAll) @@ -32,7 +25,7 @@ func TestServerAllAddrsInvalid(t *testing.T) { defer c.Close() defer c.host.Close() - identify(c, an) + identify(t, c, an) res, err := c.CheckReachability(context.Background(), c.host.Addrs(), nil) require.NoError(t, err) @@ -51,7 +44,7 @@ func TestServerPrivateRejected(t *testing.T) { defer c.Close() defer c.host.Close() - identify(c, an) + identify(t, c, an) res, err := c.CheckReachability(context.Background(), c.host.Addrs(), nil) require.NoError(t, err) @@ -79,7 +72,7 @@ func TestServerDataRequest(t *testing.T) { defer c.Close() defer c.host.Close() - identify(c, an) + identify(t, c, an) var quicAddr, tcpAddr ma.Multiaddr for _, a := range c.host.Addrs() { @@ -108,7 +101,7 @@ func TestServerDial(t *testing.T) { defer c.Close() defer c.host.Close() - identify(c, an) + identify(t, c, an) randAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") res, err := c.CheckReachability(context.Background(), []ma.Multiaddr{randAddr}, c.host.Addrs())