From cee9aa110162d80b3fda2ba8414c1807076073bc Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 3 May 2024 21:29:53 +0530 Subject: [PATCH] Always send 1 event for a connection --- p2p/net/swarm/connectedness_event_emitter.go | 151 +++++++++++++++++++ p2p/net/swarm/dial_worker.go | 2 +- p2p/net/swarm/swarm.go | 87 ++++------- p2p/net/swarm/swarm_conn.go | 17 ++- p2p/net/swarm/swarm_listen.go | 3 +- 5 files changed, 194 insertions(+), 66 deletions(-) create mode 100644 p2p/net/swarm/connectedness_event_emitter.go diff --git a/p2p/net/swarm/connectedness_event_emitter.go b/p2p/net/swarm/connectedness_event_emitter.go new file mode 100644 index 0000000000..1470e59fcd --- /dev/null +++ b/p2p/net/swarm/connectedness_event_emitter.go @@ -0,0 +1,151 @@ +package swarm + +import ( + "context" + "errors" + "sync" + + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" +) + +// connectednessEventEmitter emits PeerConnectednessChanged events. +// We ensure that for any peer we connected to we always sent atleast 1 NotConnected Event after +// the peer disconnects. This is because peers can observe a connection before they are notified +// of the connection by a peer connectedness changed event. +type connectednessEventEmitter struct { + mx sync.RWMutex + // newConns is the channel that holds the peerIDs we recently connected to + newConns chan peer.ID + removeConnsMx sync.Mutex + // removeConns is a slice of peerIDs we have recently closed connections to + removeConns []peer.ID + // lastEvent is the last connectedness event sent for a particular peer. + lastEvent map[peer.ID]network.Connectedness + // connectedness is the function that gives the peers current connectedness state + connectedness func(peer.ID) network.Connectedness + // emitter is the PeerConnectednessChanged event emitter + emitter event.Emitter + wg sync.WaitGroup + removeConnNotif chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +func newConnectednessEventEmitter(connectedness func(peer.ID) network.Connectedness, emitter event.Emitter) *connectednessEventEmitter { + ctx, cancel := context.WithCancel(context.Background()) + c := &connectednessEventEmitter{ + newConns: make(chan peer.ID, 32), + lastEvent: make(map[peer.ID]network.Connectedness), + removeConnNotif: make(chan struct{}, 1), + connectedness: connectedness, + emitter: emitter, + ctx: ctx, + cancel: cancel, + } + c.wg.Add(1) + go c.runEmitter() + return c +} + +func (c *connectednessEventEmitter) AddConn(ctx context.Context, p peer.ID) error { + c.mx.RLock() + defer c.mx.RUnlock() + if c.ctx.Err() != nil { + return errors.New("emitter closed") + } + + select { + case c.newConns <- p: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-c.ctx.Done(): + return errors.New("emitter closed") + } +} + +func (c *connectednessEventEmitter) RemoveConn(p peer.ID) { + c.mx.RLock() + defer c.mx.RUnlock() + if c.ctx.Err() != nil { + return + } + + c.removeConnsMx.Lock() + // This queue is not unbounded since we block in the AddConn method + // So we are adding connections to the swarm only at a rate + // the subscriber for our peer connectedness changed events can consume them. + // If a lot of open connections are closed at once, increasing the disconnected + // event notification rate, the rate of adding connections to the swarm would + // proportionately reduce, which would eventually reduce the length of this slice. + c.removeConns = append(c.removeConns, p) + c.removeConnsMx.Unlock() + + select { + case c.removeConnNotif <- struct{}{}: + default: + } +} + +func (c *connectednessEventEmitter) Close() { + c.cancel() + c.wg.Wait() +} + +func (c *connectednessEventEmitter) runEmitter() { + defer c.wg.Done() + for { + select { + case p := <-c.newConns: + c.notifyPeer(p, true) + case <-c.removeConnNotif: + c.sendConnRemovedNotifications() + case <-c.ctx.Done(): + c.mx.Lock() // Wait for all pending AddConn & RemoveConn operations to complete + defer c.mx.Unlock() + for { + select { + case p := <-c.newConns: + c.notifyPeer(p, true) + case <-c.removeConnNotif: + c.sendConnRemovedNotifications() + default: + return + } + } + } + } +} + +func (c *connectednessEventEmitter) notifyPeer(p peer.ID, forceNotConnectedEvent bool) { + oldState := c.lastEvent[p] + c.lastEvent[p] = c.connectedness(p) + if c.lastEvent[p] == network.NotConnected { + delete(c.lastEvent, p) + } + if (forceNotConnectedEvent && c.lastEvent[p] == network.NotConnected) || c.lastEvent[p] != oldState { + c.emitter.Emit(event.EvtPeerConnectednessChanged{ + Peer: p, + Connectedness: c.lastEvent[p], + }) + } +} + +func (c *connectednessEventEmitter) sendConnRemovedNotifications() { + c.removeConnsMx.Lock() + defer c.removeConnsMx.Unlock() + for { + if len(c.removeConns) == 0 { + return + } + p := c.removeConns[0] + c.removeConns[0] = "" + c.removeConns = c.removeConns[1:] + + c.removeConnsMx.Unlock() + c.notifyPeer(p, false) + c.removeConnsMx.Lock() + } +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 0cac6e4fa3..2ebc4e1efd 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -340,7 +340,7 @@ loop: ad.expectedTCPUpgradeTime = time.Time{} if res.Conn != nil { // we got a connection, add it to the swarm - conn, err := w.s.addConn(res.Conn, network.DirOutbound) + conn, err := w.s.addConn(ad.ctx, res.Conn, network.DirOutbound) if err != nil { // oops no, we failed to add it to the swarm res.Conn.Close() diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 3242bf3076..a7c323453a 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -144,9 +144,7 @@ type Swarm struct { // down before continuing. refs sync.WaitGroup - emitter event.Emitter - connectednessEventCh chan struct{} - connectednessEmitterDone chan struct{} + emitter event.Emitter rcmgr network.ResourceManager @@ -158,8 +156,7 @@ type Swarm struct { conns struct { sync.RWMutex - m map[peer.ID][]*Conn - connectednessEvents chan peer.ID + m map[peer.ID][]*Conn } listeners struct { @@ -206,9 +203,10 @@ type Swarm struct { dialRanker network.DialRanker - udpBlackHoleConfig blackHoleConfig - ipv6BlackHoleConfig blackHoleConfig - bhd *blackHoleDetector + udpBlackHoleConfig blackHoleConfig + ipv6BlackHoleConfig blackHoleConfig + bhd *blackHoleDetector + connectednessEventEmitter *connectednessEventEmitter } // NewSwarm constructs a Swarm. @@ -219,17 +217,15 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts } ctx, cancel := context.WithCancel(context.Background()) s := &Swarm{ - local: local, - peers: peers, - emitter: emitter, - connectednessEventCh: make(chan struct{}, 1), - connectednessEmitterDone: make(chan struct{}), - ctx: ctx, - ctxCancel: cancel, - dialTimeout: defaultDialTimeout, - dialTimeoutLocal: defaultDialTimeoutLocal, - maResolver: madns.DefaultResolver, - dialRanker: DefaultDialRanker, + local: local, + peers: peers, + emitter: emitter, + ctx: ctx, + ctxCancel: cancel, + dialTimeout: defaultDialTimeout, + dialTimeoutLocal: defaultDialTimeoutLocal, + maResolver: madns.DefaultResolver, + dialRanker: DefaultDialRanker, // A black hole is a binary property. On a network if UDP dials are blocked or there is // no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials @@ -239,11 +235,11 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts } s.conns.m = make(map[peer.ID][]*Conn) - s.conns.connectednessEvents = make(chan peer.ID, 32) s.listeners.m = make(map[transport.Listener]struct{}) s.transports.m = make(map[int]transport.Transport) s.notifs.m = make(map[network.Notifiee]struct{}) s.directConnNotifs.m = make(map[peer.ID][]chan struct{}) + s.connectednessEventEmitter = newConnectednessEventEmitter(s.Connectedness, emitter) for _, opt := range opts { if err := opt(s); err != nil { @@ -260,7 +256,6 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.backf.init(s.ctx) s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer) - go s.connectednessEventEmitter() return s, nil } @@ -306,8 +301,7 @@ func (s *Swarm) close() { // Wait for everything to finish. s.refs.Wait() - close(s.conns.connectednessEvents) - <-s.connectednessEmitterDone + s.connectednessEventEmitter.Close() s.emitter.Close() // Now close out any transports (if necessary). Do this after closing @@ -338,7 +332,7 @@ func (s *Swarm) close() { wg.Wait() } -func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { +func (s *Swarm) addConn(ctx context.Context, tc transport.CapableConn, dir network.Direction) (*Conn, error) { var ( p = tc.RemotePeer() addr = tc.RemoteMultiaddr() @@ -403,12 +397,14 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, c.notifyLk.Lock() s.conns.Unlock() - // Block this goroutine till this request is enqueued. - // This ensures that there are only a finite number of goroutines that are waiting to send - // the connectedness event on the disconnection side in swarm.removeConn. - // This is so because the goroutine to enqueue disconnection event can only be started - // from either a subscriber or a notifier or after calling c.start - s.conns.connectednessEvents <- p + err := s.connectednessEventEmitter.AddConn(ctx, p) + if err != nil { + // Either the subscriber is busy or the swarm is closed + c.Close() + s.removeConn(c) + c.notifyLk.Unlock() + return nil, fmt.Errorf("failed to send connectedness event: %w", err) + } if !isLimited { // Notify goroutines waiting for a direct connection @@ -427,6 +423,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, s.notifyAll(func(f network.Notifiee) { f.Connected(s, c) }) + c.notified = true c.notifyLk.Unlock() c.start() @@ -787,36 +784,6 @@ func (s *Swarm) removeConn(c *Conn) { } } s.conns.Unlock() - // Do this in a separate go routine to not block the caller. - // This ensures that if a event subscriber closes the connection from the subscription goroutine - // this doesn't deadlock - s.refs.Add(1) - go func() { - defer s.refs.Done() - s.conns.connectednessEvents <- p - }() -} - -func (s *Swarm) connectednessEventEmitter() { - defer close(s.connectednessEmitterDone) - lastConnectednessEvents := make(map[peer.ID]network.Connectedness) - for p := range s.conns.connectednessEvents { - s.conns.Lock() - oldState := lastConnectednessEvents[p] - newState := s.connectednessUnlocked(p) - if newState != network.NotConnected { - lastConnectednessEvents[p] = newState - } else { - delete(lastConnectednessEvents, p) - } - s.conns.Unlock() - if newState != oldState { - s.emitter.Emit(event.EvtPeerConnectednessChanged{ - Peer: p, - Connectedness: newState, - }) - } - } } // String returns a string representation of Network. diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 17ae1dffae..5a767e95e0 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -31,6 +31,7 @@ type Conn struct { err error notifyLk sync.Mutex + notified bool streams struct { sync.Mutex @@ -73,6 +74,11 @@ func (c *Conn) doClose() { c.err = c.conn.Close() + // Send the connectedness event after closing the connection. + // This ensures that both remote connection close and local connection + // close events are sent after the underlying transport connection is closed. + c.swarm.connectednessEventEmitter.RemoveConn(c.RemotePeer()) + // This is just for cleaning up state. The connection has already been closed. // We *could* optimize this but it really isn't worth it. for s := range streams { @@ -85,10 +91,13 @@ func (c *Conn) doClose() { c.notifyLk.Lock() defer c.notifyLk.Unlock() - c.swarm.notifyAll(func(f network.Notifiee) { - f.Disconnected(c.swarm, c) - }) - c.swarm.refs.Done() // taken in Swarm.addConn + defer c.swarm.refs.Done() // taken in Swarm.addConn + if c.notified { + // Only notify for disconnection if we notified for connection + c.swarm.notifyAll(func(f network.Notifiee) { + f.Disconnected(c.swarm, c) + }) + } }() } diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go index 0905e84513..2376a7e379 100644 --- a/p2p/net/swarm/swarm_listen.go +++ b/p2p/net/swarm/swarm_listen.go @@ -1,6 +1,7 @@ package swarm import ( + "context" "errors" "fmt" "time" @@ -142,7 +143,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { s.refs.Add(1) go func() { defer s.refs.Done() - _, err := s.addConn(c, network.DirInbound) + _, err := s.addConn(context.Background(), c, network.DirInbound) switch err { case nil: case ErrSwarmClosed: