Skip to content

Commit

Permalink
add a wait before dialing different IPs
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Jun 21, 2024
1 parent adf26bd commit 9667cd8
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 24 deletions.
34 changes: 21 additions & 13 deletions p2p/protocol/autonatv2/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@ import "time"

// autoNATSettings is used to configure AutoNAT
type autoNATSettings struct {
allowPrivateAddrs bool
serverRPM int
serverPerPeerRPM int
serverDialDataRPM int
dataRequestPolicy dataRequestPolicyFunc
now func() time.Time
allowPrivateAddrs bool
serverRPM int
serverPerPeerRPM int
serverDialDataRPM int
dataRequestPolicy dataRequestPolicyFunc
now func() time.Time
amplificatonAttackPreventionDialWait time.Duration
}

func defaultSettings() *autoNATSettings {
return &autoNATSettings{
allowPrivateAddrs: false,
// TODO: confirm rate limiting defaults
serverRPM: 20,
serverPerPeerRPM: 2,
serverDialDataRPM: 5,
dataRequestPolicy: amplificationAttackPrevention,
now: time.Now,
allowPrivateAddrs: false,
serverRPM: 60, // 1 every second
serverPerPeerRPM: 12, // 1 every 5 seconds
serverDialDataRPM: 12, // 1 every 5 seconds
dataRequestPolicy: amplificationAttackPrevention,
amplificatonAttackPreventionDialWait: 3 * time.Second,
now: time.Now,
}
}

Expand All @@ -46,3 +47,10 @@ func allowPrivateAddrs(s *autoNATSettings) error {
s.allowPrivateAddrs = true
return nil
}

func withAmplificationAttackPreventionDialWait(d time.Duration) AutoNATOption {
return func(s *autoNATSettings) error {
s.amplificatonAttackPreventionDialWait = d
return nil
}
}
35 changes: 26 additions & 9 deletions p2p/protocol/autonatv2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import (
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
"github.com/libp2p/go-msgio/pbio"

"math/rand"

ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"golang.org/x/exp/rand"
)

type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool
Expand All @@ -32,7 +33,8 @@ type server struct {

// dialDataRequestPolicy is used to determine whether dialing the address requires receiving
// dial data. It is set to amplification attack prevention by default.
dialDataRequestPolicy dataRequestPolicyFunc
dialDataRequestPolicy dataRequestPolicyFunc
amplificatonAttackPreventionDialWait time.Duration

// for tests
now func() time.Time
Expand All @@ -41,10 +43,11 @@ type server struct {

func newServer(host, dialer host.Host, s *autoNATSettings) *server {
return &server{
dialerHost: dialer,
host: host,
dialDataRequestPolicy: s.dataRequestPolicy,
allowPrivateAddrs: s.allowPrivateAddrs,
dialerHost: dialer,
host: host,
dialDataRequestPolicy: s.dataRequestPolicy,
amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
allowPrivateAddrs: s.allowPrivateAddrs,
limiter: &rateLimiter{
RPM: s.serverRPM,
PerPeerRPM: s.serverPerPeerRPM,
Expand Down Expand Up @@ -81,6 +84,9 @@ func (as *server) handleDialRequest(s network.Stream) {
}
defer s.Scope().ReleaseMemory(maxMsgSize)

deadline := as.now().Add(streamTimeout)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
s.SetDeadline(as.now().Add(streamTimeout))
defer s.Close()

Expand Down Expand Up @@ -183,9 +189,20 @@ func (as *server) handleDialRequest(s network.Stream) {
log.Debugf("%s refused dial data request: %s", p, err)
return
}
// wait for a bit to prevent thundering herd style attacks on a victim
waitTime := time.Duration(rand.Intn(int(as.amplificatonAttackPreventionDialWait) + 1)) // the range is [0, n)
t := time.NewTimer(waitTime)
defer t.Stop()
select {
case <-ctx.Done():
s.Reset()
log.Debugf("rejecting request without dialing: %s %p ", p, ctx.Err())
return
case <-t.C:
}
}

dialStatus := as.dialBack(s.Conn().RemotePeer(), dialAddr, nonce)
dialStatus := as.dialBack(ctx, s.Conn().RemotePeer(), dialAddr, nonce)
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Expand Down Expand Up @@ -252,8 +269,8 @@ func readDialData(numBytes int, r io.Reader) error {
return nil
}

func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus {
ctx, cancel := context.WithTimeout(context.Background(), dialBackDialTimeout)
func (as *server) dialBack(ctx context.Context, p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus {
ctx, cancel := context.WithTimeout(ctx, dialBackDialTimeout)
ctx = network.WithForceDirectDial(ctx, "autonatv2")
as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL)
defer func() {
Expand Down
53 changes: 51 additions & 2 deletions p2p/protocol/autonatv2/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func newTestRequests(addrs []ma.Multiaddr, sendDialData bool) (reqs []Request) {
}

func TestServerInvalidAddrsRejected(t *testing.T) {
c := newAutoNAT(t, nil, allowPrivateAddrs)
c := newAutoNAT(t, nil, allowPrivateAddrs, withAmplificationAttackPreventionDialWait(0))
defer c.Close()
defer c.host.Close()

Expand All @@ -46,7 +46,6 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
idAndWait(t, c, an)

res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
fmt.Println(res, err)
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
})
Expand Down Expand Up @@ -151,6 +150,7 @@ func TestServerDataRequest(t *testing.T) {
return false
}),
WithServerRateLimit(10, 10, 10),
withAmplificationAttackPreventionDialWait(0),
)
defer an.Close()
defer an.host.Close()
Expand Down Expand Up @@ -187,6 +187,55 @@ func TestServerDataRequest(t *testing.T) {
_, err = c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}})
require.Error(t, err)
}
func TestServerDataRequestJitter(t *testing.T) {
// server will skip all tcp addresses
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
// ask for dial data for quic address
an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy(
func(s network.Stream, dialAddr ma.Multiaddr) bool {
if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil {
return true
}
return false
}),
WithServerRateLimit(10, 10, 10),
withAmplificationAttackPreventionDialWait(5*time.Second),
)
defer an.Close()
defer an.host.Close()

c := newAutoNAT(t, nil, allowPrivateAddrs)
defer c.Close()
defer c.host.Close()

idAndWait(t, c, an)

var quicAddr, tcpAddr ma.Multiaddr
for _, a := range c.host.Addrs() {
if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil {
quicAddr = a
} else if _, err := a.ValueForProtocol(ma.P_TCP); err == nil {
tcpAddr = a
}
}

for i := 0; i < 10; i++ {
st := time.Now()
res, err := c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}})
took := time.Since(st)
require.NoError(t, err)

require.Equal(t, Result{
Addr: quicAddr,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)
if took > 500*time.Millisecond {
return
}
}
t.Fatalf("expected server to delay at least 1 dial")
}

func TestServerDial(t *testing.T) {
an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowPrivateAddrs)
Expand Down

0 comments on commit 9667cd8

Please sign in to comment.