Skip to content

Commit

Permalink
Adds RPC inspector (#8)
Browse files Browse the repository at this point in the history
* perf: use msgio pooled buffers for received msgs (libp2p#500)

* perf: use pooled buffers for message writes (libp2p#507)

* improve handling of dead peers (libp2p#508)

* chore: ignore signing keys during WithLocalPublication publishing (libp2p#497)

* adds app specific rpc handler

Co-authored-by: Hlib Kanunnikov <[email protected]>
Co-authored-by: Viacheslav <[email protected]>
  • Loading branch information
3 people authored Nov 24, 2022
1 parent 1c99052 commit 216c157
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 27 deletions.
59 changes: 35 additions & 24 deletions comm.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package pubsub

import (
"bufio"
"context"
"encoding/binary"
"io"
"time"

"github.com/gogo/protobuf/proto"
pool "github.com/libp2p/go-buffer-pool"
"github.com/multiformats/go-varint"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-msgio"

pb "github.com/libp2p/go-libp2p-pubsub/pb"

"github.com/libp2p/go-msgio/protoio"

"github.com/gogo/protobuf/proto"
)

// get the initial RPC containing all of our subscriptions to send to new peers
Expand Down Expand Up @@ -60,11 +61,11 @@ func (p *PubSub) handleNewStream(s network.Stream) {
p.inboundStreamsMx.Unlock()
}()

r := protoio.NewDelimitedReader(s, p.maxMessageSize)
r := msgio.NewVarintReaderSize(s, p.maxMessageSize)
for {
rpc := new(RPC)
err := r.ReadMsg(&rpc.RPC)
msgbytes, err := r.ReadMsg()
if err != nil {
r.ReleaseMsg(msgbytes)
if err != io.EOF {
s.Reset()
log.Debugf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err)
Expand All @@ -77,6 +78,15 @@ func (p *PubSub) handleNewStream(s network.Stream) {
return
}

rpc := new(RPC)
err = rpc.Unmarshal(msgbytes)
r.ReleaseMsg(msgbytes)
if err != nil {
s.Reset()
log.Warnf("bogus rpc from %s: %s", s.Conn().RemotePeer(), err)
return
}

rpc.from = peer
select {
case p.incoming <- rpc:
Expand Down Expand Up @@ -115,7 +125,7 @@ func (p *PubSub) handleNewPeer(ctx context.Context, pid peer.ID, outgoing <-chan
}

go p.handleSendingMessages(ctx, s, outgoing)
go p.handlePeerEOF(ctx, s)
go p.handlePeerDead(s)
select {
case p.newPeerStream <- s:
case <-ctx.Done():
Expand All @@ -131,32 +141,33 @@ func (p *PubSub) handleNewPeerWithBackoff(ctx context.Context, pid peer.ID, back
}
}

func (p *PubSub) handlePeerEOF(ctx context.Context, s network.Stream) {
func (p *PubSub) handlePeerDead(s network.Stream) {
pid := s.Conn().RemotePeer()
r := protoio.NewDelimitedReader(s, p.maxMessageSize)
rpc := new(RPC)
for {
err := r.ReadMsg(&rpc.RPC)
if err != nil {
p.notifyPeerDead(pid)
return
}

_, err := s.Read([]byte{0})
if err == nil {
log.Debugf("unexpected message from %s", pid)
}

s.Reset()
p.notifyPeerDead(pid)
}

func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing <-chan *RPC) {
bufw := bufio.NewWriter(s)
wc := protoio.NewDelimitedWriter(bufw)
writeRpc := func(rpc *RPC) error {
size := uint64(rpc.Size())

buf := pool.Get(varint.UvarintSize(size) + int(size))
defer pool.Put(buf)

writeMsg := func(msg proto.Message) error {
err := wc.WriteMsg(msg)
n := binary.PutUvarint(buf, size)
_, err := rpc.MarshalTo(buf[n:])
if err != nil {
return err
}

return bufw.Flush()
_, err = s.Write(buf)
return err
}

defer s.Close()
Expand All @@ -167,7 +178,7 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
return
}

err := writeMsg(&rpc.RPC)
err := writeRpc(rpc)
if err != nil {
s.Reset()
log.Debugf("writing message to %s: %s", s.Conn().RemotePeer(), err)
Expand Down
28 changes: 26 additions & 2 deletions gossipsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ func NewGossipSubWithRouter(ctx context.Context, h host.Host, rt PubSubRouter, o
}

// DefaultGossipSubRouter returns a new GossipSubRouter with default parameters.
func DefaultGossipSubRouter(h host.Host) *GossipSubRouter {
func DefaultGossipSubRouter(h host.Host, opts ...func(*GossipSubRouter)) *GossipSubRouter {
params := DefaultGossipSubParams()
return &GossipSubRouter{
rt := &GossipSubRouter{
peers: make(map[peer.ID]protocol.ID),
mesh: make(map[string]map[peer.ID]struct{}),
fanout: make(map[string]map[peer.ID]struct{}),
Expand All @@ -237,6 +237,18 @@ func DefaultGossipSubRouter(h host.Host) *GossipSubRouter {
tagTracer: newTagTracer(h.ConnManager()),
params: params,
}

for _, opt := range opts {
opt(rt)
}

return rt
}

func WithAppSpecificRpcInspector(inspector func(peer.ID, *RPC) bool) func(*GossipSubRouter) {
return func(rt *GossipSubRouter) {
rt.appSpecificRpcInspector = inspector
}
}

// DefaultGossipSubParams returns the default gossip sub parameters
Expand Down Expand Up @@ -474,6 +486,11 @@ type GossipSubRouter struct {
// number of heartbeats since the beginning of time; this allows us to amortize some resource
// clean up -- eg backoff clean up.
heartbeatTicks uint64

// appSpecificRpcInspector is an auxiliary that may be set by the application to inspect incoming RPCs prior to
// processing them. The inspector is invoked on an accepted RPC right prior to handling it.
// The return value of the inspector function is a boolean indicating whether the RPC should be processed or not.
appSpecificRpcInspector func(peer.ID, *RPC) bool
}

type connectInfo struct {
Expand Down Expand Up @@ -614,6 +631,13 @@ func (gs *GossipSubRouter) HandleRPC(rpc *RPC) {
return
}

if gs.appSpecificRpcInspector != nil {
// check if the RPC is allowed by the external inspector
if accept := gs.appSpecificRpcInspector(rpc.from, rpc); !accept {
return // reject the RPC
}
}

iwant := gs.handleIHave(rpc.from, ctl)
ihave := gs.handleIWant(rpc.from, ctl)
prune := gs.handleGraft(rpc.from, ctl)
Expand Down
2 changes: 1 addition & 1 deletion topic.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error
}
}

if pub.customKey != nil {
if pub.customKey != nil && !pub.local {
key, pid = pub.customKey()
if key == nil {
return ErrNilSignKey
Expand Down

0 comments on commit 216c157

Please sign in to comment.