Skip to content

Commit

Permalink
change result type to raw struct
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 20, 2023
1 parent afb3958 commit 6ec7194
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 78 deletions.
11 changes: 6 additions & 5 deletions p2p/protocol/autonatv2/autonat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (

var (
ErrNoValidPeers = errors.New("no valid peers for autonat v2")
ErrDialRefused = errors.New("dial refused")
)

var (
Expand Down Expand Up @@ -146,28 +147,28 @@ type Result struct {

// CheckReachability makes a single dial request for checking reachability. For highPriorityAddrs dial charge is paid
// if the server asks for it. For lowPriorityAddrs dial charge is rejected.
func (an *AutoNAT) CheckReachability(ctx context.Context, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr) (*Result, error) {
func (an *AutoNAT) CheckReachability(ctx context.Context, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr) (Result, error) {
if !an.allowAllAddrs {
for _, a := range highPriorityAddrs {
if !manet.IsPublicAddr(a) {
return nil, fmt.Errorf("private address cannot be verified by autonatv2: %s", a)
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", a)
}
}
for _, a := range lowPriorityAddrs {
if !manet.IsPublicAddr(a) {
return nil, fmt.Errorf("private address cannot be verified by autonatv2: %s", a)
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", a)
}
}
}
p := an.peers.GetRand()
if p == "" {
return nil, ErrNoValidPeers
return Result{}, ErrNoValidPeers
}

res, err := an.cli.CheckReachability(ctx, p, highPriorityAddrs, lowPriorityAddrs)
if err != nil {
log.Debugf("reachability check with %s failed, err: %s", p, err)
return nil, fmt.Errorf("reachability check with %s failed: %w", p, err)
return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err)
}
log.Debugf("reachability check with %s successful", p)
return res, nil
Expand Down
10 changes: 5 additions & 5 deletions p2p/protocol/autonatv2/autonat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func identify(t *testing.T, cli *AutoNAT, srv *AutoNAT) {
func TestAutoNATPrivateAddr(t *testing.T) {
an := newAutoNAT(t, nil)
res, err := an.CheckReachability(context.Background(), []ma.Multiaddr{ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}, nil)
require.Nil(t, res)
require.Equal(t, res, Result{})
require.NotNil(t, err)
}

Expand Down Expand Up @@ -116,7 +116,7 @@ func TestClientRequest(t *testing.T) {
waitForPeer(t, an)

res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
require.Nil(t, res)
require.Equal(t, res, Result{})
require.NotNil(t, err)
require.True(t, gotReq.Load())
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func TestClientServerError(t *testing.T) {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
p.SetStreamHandler(DialProtocol, tc.handler)
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
require.Nil(t, res)
require.Equal(t, res, Result{})
require.NotNil(t, err)
<-done
})
Expand Down Expand Up @@ -252,7 +252,7 @@ func TestClientDataRequest(t *testing.T) {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
p.SetStreamHandler(DialProtocol, tc.handler)
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
require.Nil(t, res)
require.Equal(t, res, Result{})
require.NotNil(t, err)
<-done
})
Expand Down Expand Up @@ -475,7 +475,7 @@ func TestClientDialBacks(t *testing.T) {
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
if !tc.success {
if tc.isError {
require.Nil(t, res)
require.Equal(t, res, Result{})
require.Error(t, err)
} else {
require.NoError(t, err)
Expand Down
43 changes: 22 additions & 21 deletions p2p/protocol/autonatv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,23 @@ func NewClient(h host.Host) *Client {

// CheckReachability verifies address reachability with a AutoNAT v2 server p. It'll provide dial data for dialing high
// priority addresses and not for low priority addresses.
func (ac *Client) CheckReachability(ctx context.Context, p peer.ID, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr) (*Result, error) {
func (ac *Client) CheckReachability(ctx context.Context, p peer.ID, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr) (Result, error) {
ctx, cancel := context.WithTimeout(ctx, streamTimeout)
defer cancel()

s, err := ac.host.NewStream(ctx, p, DialProtocol)
if err != nil {
return nil, fmt.Errorf("open %s stream: %w", DialProtocol, err)
return Result{}, fmt.Errorf("open %s stream: %w", DialProtocol, err)
}

if err := s.Scope().SetService(ServiceName); err != nil {
s.Reset()
return nil, fmt.Errorf("attach stream %s to service %s: %w", DialProtocol, ServiceName, err)
return Result{}, fmt.Errorf("attach stream %s to service %s: %w", DialProtocol, ServiceName, err)
}

if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
s.Reset()
return nil, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err)
return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err)
}
defer s.Scope().ReleaseMemory(maxMsgSize)

Expand All @@ -74,13 +74,13 @@ func (ac *Client) CheckReachability(ctx context.Context, p peer.ID, highPriority
w := pbio.NewDelimitedWriter(s)
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
return nil, fmt.Errorf("dial request write: %w", err)
return Result{}, fmt.Errorf("dial request write: %w", err)
}

r := pbio.NewDelimitedReader(s, maxMsgSize)
if err := r.ReadMsg(&msg); err != nil {
s.Reset()
return nil, fmt.Errorf("dial msg read: %w", err)
return Result{}, fmt.Errorf("dial msg read: %w", err)
}

switch {
Expand All @@ -89,36 +89,40 @@ func (ac *Client) CheckReachability(ctx context.Context, p peer.ID, highPriority
case msg.GetDialDataRequest() != nil:
if int(msg.GetDialDataRequest().AddrIdx) >= len(highPriorityAddrs) {
s.Reset()
return nil, fmt.Errorf("dial data requested for low priority address")
return Result{}, fmt.Errorf("dial data requested for low priority address")
}
if msg.GetDialDataRequest().NumBytes > maxHandshakeSizeBytes {
s.Reset()
return nil, fmt.Errorf("dial data requested too high: %d", msg.GetDialDataRequest().NumBytes)
return Result{}, fmt.Errorf("dial data requested too high: %d", msg.GetDialDataRequest().NumBytes)
}
if err := ac.sendDialData(msg.GetDialDataRequest(), w, &msg); err != nil {
s.Reset()
return nil, fmt.Errorf("dial data send: %w", err)
return Result{}, fmt.Errorf("dial data send: %w", err)
}
if err := r.ReadMsg(&msg); err != nil {
s.Reset()
return nil, fmt.Errorf("dial response read: %w", err)
return Result{}, fmt.Errorf("dial response read: %w", err)
}
if msg.GetDialResponse() == nil {
s.Reset()
return nil, fmt.Errorf("invalid response type: %T", msg.Msg)
return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg)
}
default:
s.Reset()
return nil, fmt.Errorf("invalid msg type: %T", msg.Msg)
return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg)
}

resp := msg.GetDialResponse()
if resp.GetStatus() != pbv2.DialResponse_ResponseStatus_OK {
return nil, fmt.Errorf("dial request failed: status %d %s", resp.GetStatus(),
// server couldn't dial any requested address
if resp.GetStatus() == pbv2.DialResponse_E_DIAL_REFUSED {
return Result{}, fmt.Errorf("dial request: %w", ErrDialRefused)
}
return Result{}, fmt.Errorf("dial request: status %d %s", resp.GetStatus(),
pbv2.DialStatus_name[int32(resp.GetStatus())])
}
if resp.GetDialStatus() == pbv2.DialStatus_E_INTERNAL_ERROR {
return nil, fmt.Errorf("dial request failed: received invalid dial status 0")
if resp.GetDialStatus() == pbv2.DialStatus_UNUSED {
return Result{}, fmt.Errorf("dial request failed: received invalid dial status 0")
}

var dialBackAddr ma.Multiaddr
Expand All @@ -135,12 +139,9 @@ func (ac *Client) CheckReachability(ctx context.Context, p peer.ID, highPriority
return ac.newResults(resp, highPriorityAddrs, lowPriorityAddrs, dialBackAddr)
}

func (ac *Client) newResults(resp *pbv2.DialResponse, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr, dialBackAddr ma.Multiaddr) (*Result, error) {
if resp.DialStatus == pbv2.DialStatus_E_DIAL_REFUSED {
return &Result{Idx: -1, Reachability: network.ReachabilityUnknown, Status: pbv2.DialStatus_E_DIAL_REFUSED}, nil
}
func (ac *Client) newResults(resp *pbv2.DialResponse, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr, dialBackAddr ma.Multiaddr) (Result, error) {
if int(resp.AddrIdx) >= len(highPriorityAddrs)+len(lowPriorityAddrs) {
return nil, fmt.Errorf("addrIdx out of range: %d 0-%d", resp.AddrIdx, len(highPriorityAddrs)+len(lowPriorityAddrs)-1)
return Result{}, fmt.Errorf("addrIdx out of range: %d 0-%d", resp.AddrIdx, len(highPriorityAddrs)+len(lowPriorityAddrs)-1)
}

idx := int(resp.AddrIdx)
Expand All @@ -163,7 +164,7 @@ func (ac *Client) newResults(resp *pbv2.DialResponse, highPriorityAddrs []ma.Mul
case pbv2.DialStatus_E_DIAL_ERROR:
rch = network.ReachabilityPrivate
}
return &Result{
return Result{
Idx: idx,
Addr: addr,
Reachability: rch,
Expand Down
58 changes: 29 additions & 29 deletions p2p/protocol/autonatv2/pbv2/autonat.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions p2p/protocol/autonatv2/pbv2/autonat.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ message DialDataRequest {


enum DialStatus {
E_INTERNAL_ERROR = 0; // Default value to force servers to explicitly set the status
E_DIAL_REFUSED = 100;
E_DIAL_ERROR = 101;
E_DIAL_BACK_ERROR = 102;
UNUSED = 0;
E_DIAL_ERROR = 100;
E_DIAL_BACK_ERROR = 101;
OK = 200;
}

Expand All @@ -36,6 +35,7 @@ message DialResponse {
enum ResponseStatus {
E_INTERNAL_ERROR = 0;
E_REQUEST_REJECTED = 100;
E_DIAL_REFUSED = 101;
ResponseStatus_OK = 200;
}

Expand Down
5 changes: 1 addition & 4 deletions p2p/protocol/autonatv2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,7 @@ func (as *Server) handleDialRequest(s network.Stream) {
msg = pbv2.Message{
Msg: &pbv2.Message_DialResponse{
DialResponse: &pbv2.DialResponse{
Status: pbv2.DialResponse_ResponseStatus_OK,
DialStatus: pbv2.DialStatus_E_DIAL_REFUSED,
// send an invalid index to prevent accidental misuse
AddrIdx: uint32(len(msg.GetDialRequest().Addrs)),
Status: pbv2.DialResponse_E_DIAL_REFUSED,
},
},
}
Expand Down
Loading

0 comments on commit 6ec7194

Please sign in to comment.