From f2477b92f4831221274c63960b5d0a65d84180bf Mon Sep 17 00:00:00 2001 From: Peter Argue <89119817+peterargue@users.noreply.github.com> Date: Thu, 1 Feb 2024 12:24:36 -0800 Subject: [PATCH] Improve peer connection handling --- bitswap/network/ipfs_impl.go | 95 ++++++++++++++++++++++++++++--- bitswap/network/ipfs_impl_test.go | 63 ++++++++++++++++++++ 2 files changed, 149 insertions(+), 9 deletions(-) diff --git a/bitswap/network/ipfs_impl.go b/bitswap/network/ipfs_impl.go index a1446775c..9a8ad984d 100644 --- a/bitswap/network/ipfs_impl.go +++ b/bitswap/network/ipfs_impl.go @@ -14,6 +14,7 @@ import ( cid "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -87,6 +88,8 @@ type impl struct { // inbound messages from the network are forwarded to the receiver receivers []Receiver + + cancel context.CancelFunc } type streamMessageSender struct { @@ -354,8 +357,18 @@ func (bsnet *impl) Start(r ...Receiver) { bsnet.connectEvtMgr = newConnectEventManager(connectionListeners...) } for _, proto := range bsnet.supportedProtocols { + log.Debugf("setting up handler for protocol: %s", proto) bsnet.host.SetStreamHandler(proto, bsnet.handleNewStream) } + + // try to subscribe to libp2p events that indicate a change in connection state + // if this fails, continue as normal + err := bsnet.trySubscribePeerUpdates() + if err != nil { + log.Errorf("failed to subscribe to libp2p events: %s", err) + } + + // listen for disconnects and start processing the events bsnet.host.Network().Notify((*netNotifiee)(bsnet)) bsnet.connectEvtMgr.Start() } @@ -363,6 +376,77 @@ func (bsnet *impl) Start(r ...Receiver) { func (bsnet *impl) Stop() { bsnet.connectEvtMgr.Stop() bsnet.host.Network().StopNotify((*netNotifiee)(bsnet)) + bsnet.cancel() +} + +func (bsnet *impl) trySubscribePeerUpdates() error { + // first, subscribe to libp2p events that indicate a change in connection state + sub, err := bsnet.host.EventBus().Subscribe([]interface{}{ + &event.EvtPeerProtocolsUpdated{}, + &event.EvtPeerIdentificationCompleted{}, + }) + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + bsnet.cancel = cancel + + go bsnet.peerUpdatedSubscription(ctx, sub) + + // next, add any peers with existing connections that support bitswap protocols + for _, conn := range bsnet.host.Network().Conns() { + peerID := conn.RemotePeer() + if bsnet.peerSupportsBitswap(peerID) { + log.Debugf("connecting to existing peer: %s", peerID) + bsnet.connectEvtMgr.Connected(peerID) + } + } + + return nil +} + +func (bsnet *impl) peerUpdatedSubscription(ctx context.Context, sub event.Subscription) { + for { + select { + case <-ctx.Done(): + return + case evt := <-sub.Out(): + switch e := evt.(type) { + case event.EvtPeerProtocolsUpdated: + if bsnet.hasBitswapProtocol(e.Added) { + log.Debugf("connecting to peer with updated protocol list: %s", e.Peer) + bsnet.connectEvtMgr.Connected(e.Peer) + continue + } + + if bsnet.hasBitswapProtocol(e.Removed) && !bsnet.peerSupportsBitswap(e.Peer) { + log.Debugf("disconnecting from peer with updated protocol list: %s", e.Peer) + bsnet.connectEvtMgr.Disconnected(e.Peer) + } + case event.EvtPeerIdentificationCompleted: + if bsnet.peerSupportsBitswap(e.Peer) { + log.Debugf("connecting to peer with new identification: %s", e.Peer) + bsnet.connectEvtMgr.Connected(e.Peer) + } + } + } + } +} + +func (bsnet *impl) peerSupportsBitswap(peerID peer.ID) bool { + protocols, err := bsnet.host.Peerstore().SupportsProtocols(peerID, bsnet.supportedProtocols...) + return err == nil && len(protocols) > 0 +} + +func (bsnet *impl) hasBitswapProtocol(protos []protocol.ID) bool { + for _, p := range protos { + switch p { + case bsnet.protocolBitswap, bsnet.protocolBitswapOneOne, bsnet.protocolBitswapOneZero, bsnet.protocolBitswapNoVers: + return true + } + } + return false } func (bsnet *impl) ConnectTo(ctx context.Context, p peer.ID) error { @@ -450,23 +534,16 @@ func (nn *netNotifiee) impl() *impl { return (*impl)(nn) } -func (nn *netNotifiee) Connected(n network.Network, v network.Conn) { - // ignore transient connections - if v.Stat().Transient { - return - } - - nn.impl().connectEvtMgr.Connected(v.RemotePeer()) -} - func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { // Only record a "disconnect" when we actually disconnect. if n.Connectedness(v.RemotePeer()) == network.Connected { return } + log.Debugf("peer disconnected: %s", v.RemotePeer()) nn.impl().connectEvtMgr.Disconnected(v.RemotePeer()) } +func (nn *netNotifiee) Connected(n network.Network, v network.Conn) {} func (nn *netNotifiee) OpenedStream(n network.Network, s network.Stream) {} func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {} func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} diff --git a/bitswap/network/ipfs_impl_test.go b/bitswap/network/ipfs_impl_test.go index af76e20d6..306d08896 100644 --- a/bitswap/network/ipfs_impl_test.go +++ b/bitswap/network/ipfs_impl_test.go @@ -669,3 +669,66 @@ func TestNetworkCounters(t *testing.T) { testNetworkCounters(t, 10-n, n) } } + +func TestPeerDiscovery(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + mn := mocknet.New() + defer mn.Close() + + mr := mockrouting.NewServer() + streamNet, err := tn.StreamNet(ctx, mn, mr) + if err != nil { + t.Fatal("Unable to setup network") + } + + // start 2 disconnected nodes + p1 := tnet.RandIdentityOrFatal(t) + p2 := tnet.RandIdentityOrFatal(t) + + bsnet1 := streamNet.Adapter(p1) + bsnet2 := streamNet.Adapter(p2) + r1 := newReceiver() + r2 := newReceiver() + bsnet1.Start(r1) + t.Cleanup(bsnet1.Stop) + bsnet2.Start(r2) + t.Cleanup(bsnet2.Stop) + + err = mn.LinkAll() + if err != nil { + t.Fatal(err) + } + + // send request from node 1 to node 2 + blockGenerator := blocksutil.NewBlockGenerator() + block := blockGenerator.Next() + sent := bsmsg.New(false) + sent.AddBlock(block) + + err = bsnet1.SendMessage(ctx, p2.ID(), sent) + if err != nil { + t.Fatal(err) + } + + // node 2 should connect to node 1 + select { + case <-ctx.Done(): + t.Fatal("did not connect peer") + case <-r2.connectionEvent: + } + + // verify the message is received + select { + case <-ctx.Done(): + t.Fatal("did not receive message sent") + case <-r2.messageReceived: + } + + sender := r2.lastSender + if sender != p1.ID() { + t.Fatal("received message from wrong node") + } +}