diff --git a/comm.go b/comm.go index 14e8c773..e63a8559 100644 --- a/comm.go +++ b/comm.go @@ -6,14 +6,14 @@ import ( "io" "time" + "github.com/gogo/protobuf/proto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - - pb "github.com/libp2p/go-libp2p-pubsub/pb" - + "github.com/libp2p/go-msgio" "github.com/libp2p/go-msgio/protoio" - "github.com/gogo/protobuf/proto" + pb "github.com/libp2p/go-libp2p-pubsub/pb" ) // get the initial RPC containing all of our subscriptions to send to new peers @@ -60,11 +60,11 @@ func (p *PubSub) handleNewStream(s network.Stream) { p.inboundStreamsMx.Unlock() }() - r := protoio.NewDelimitedReader(s, p.maxMessageSize) + r := msgio.NewVarintReaderSize(s, p.maxMessageSize) for { - rpc := new(RPC) - err := r.ReadMsg(&rpc.RPC) + msgbytes, err := r.ReadMsg() if err != nil { + r.ReleaseMsg(msgbytes) if err != io.EOF { s.Reset() log.Debugf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err) @@ -77,6 +77,15 @@ func (p *PubSub) handleNewStream(s network.Stream) { return } + rpc := new(RPC) + err = rpc.Unmarshal(msgbytes) + r.ReleaseMsg(msgbytes) + if err != nil { + s.Reset() + log.Warnf("bogus rpc from %s: %s", s.Conn().RemotePeer(), err) + return + } + rpc.from = peer select { case p.incoming <- rpc: @@ -131,12 +140,11 @@ func (p *PubSub) handleNewPeerWithBackoff(ctx context.Context, pid peer.ID, back } } -func (p *PubSub) handlePeerEOF(ctx context.Context, s network.Stream) { +func (p *PubSub) handlePeerEOF(_ context.Context, s network.Stream) { pid := s.Conn().RemotePeer() - r := protoio.NewDelimitedReader(s, p.maxMessageSize) - rpc := new(RPC) + r := msgio.NewVarintReaderSize(s, p.maxMessageSize) for { - err := r.ReadMsg(&rpc.RPC) + _, err := r.ReadMsg() if err != nil { p.notifyPeerDead(pid) return