Skip to content

Commit

Permalink
Implement client side for TCP allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyLc committed Mar 29, 2023
1 parent 7abfa3b commit 2da1324
Show file tree
Hide file tree
Showing 14 changed files with 811 additions and 121 deletions.
155 changes: 117 additions & 38 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/pion/transport/v2/vnet"
"github.com/pion/turn/v2/internal/client"
"github.com/pion/turn/v2/internal/proto"
t "github.com/pion/turn/v2/internal/transport"
)

const (
Expand Down Expand Up @@ -42,6 +43,7 @@ type ClientConfig struct {
Realm string
Software string
RTO time.Duration
Dialer transport.Dialer
Conn net.PacketConn // Listening socket (net.PacketConn)
LoggerFactory logging.LoggerFactory
Net transport.Net
Expand All @@ -61,7 +63,8 @@ type Client struct {
software stun.Software // read-only
trMap *client.TransactionMap // thread-safe
rto time.Duration // read-only
relayedConn *client.UDPConn // protected by mutex ***
relayedConn client.RelayConn // protected by mutex ***
dialer transport.Dialer // read-only
allocTryLock client.TryLock // thread-safe
listenTryLock client.TryLock // thread-safe
net transport.Net // read-only
Expand All @@ -83,6 +86,10 @@ func NewClient(config *ClientConfig) (*Client, error) {
return nil, errNilConn
}

if config.Dialer == nil {
config.Dialer = &net.Dialer{}
}

var err error
if config.Net == nil {
config.Net, err = stdnet.NewNet() // defaults to native operation
Expand Down Expand Up @@ -121,6 +128,7 @@ func NewClient(config *ClientConfig) (*Client, error) {

c := &Client{
conn: config.Conn,
dialer: config.Dialer,
stunServ: stunServ,
turnServ: turnServ,
stunServStr: stunServStr,
Expand Down Expand Up @@ -238,42 +246,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) {
return c.SendBindingRequestTo(c.stunServ)
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.relayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}
func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) {
var relayed proto.RelayedAddress
var lifetime proto.Lifetime
var nonce stun.Nonce

msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: protocol},
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err := c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

res := trRes.Msg

// Anonymous allocate failed, trying to authenticate.
var nonce stun.Nonce
if err = nonce.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
if err = c.realm.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
c.realm = append([]byte(nil), c.realm...)
c.integrity = stun.NewLongTermIntegrity(
Expand All @@ -283,65 +283,119 @@ func (c *Client) Allocate() (net.PacketConn, error) {
msg, err = stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: protocol},
&c.username,
&c.realm,
&nonce,
&c.integrity,
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err = c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
res = trRes.Msg

if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113
}

// Getting relayed addresses from response.
var relayed proto.RelayedAddress
if err := relayed.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}

// Getting lifetime from response
if err := lifetime.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}
return relayed, lifetime, nonce, nil
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.getRelayedConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.(*client.UDPConn).LocalAddr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoUDP)
if err != nil {
return nil, err
}

relayedAddr := &net.UDPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

// Getting lifetime from response
var lifetime proto.Lifetime
if err := lifetime.GetFrom(res); err != nil {
relayedConn = client.NewUDPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
})
c.setRelayedConn(relayedConn)

return relayedConn.(*client.UDPConn), nil
}

// Allocate TCP
func (c *Client) AllocateTCP() (t.Relay, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.getRelayedConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.(*client.TCPConn).Addr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoTCP)
if err != nil {
return nil, err
}

relayedConn = client.NewUDPConn(&client.UDPConnConfig{
relayedAddr := &net.TCPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

relayedConn = client.NewTCPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
Dialer: c.dialer,
})

c.setRelayedUDPConn(relayedConn)
c.setRelayedConn(relayedConn)

return relayedConn, nil
return relayedConn.(*client.TCPConn), nil
}

// CreatePermission Issues a CreatePermission request for the supplied addresses
// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9
func (c *Client) CreatePermission(addrs ...net.Addr) error {
return c.relayedUDPConn().CreatePermissions(addrs...)
return c.getRelayedConn().CreatePermissions(addrs...)
}

// PerformTransaction performs STUN transaction
Expand Down Expand Up @@ -386,7 +440,7 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult
// OnDeallocated is called when de-allocation of relay address has been complete.
// (Called by UDPConn)
func (c *Client) OnDeallocated(relayedAddr net.Addr) {
c.setRelayedUDPConn(nil)
c.setRelayedConn(nil)
}

// HandleInbound handles data received.
Expand Down Expand Up @@ -445,7 +499,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

if msg.Type.Class == stun.ClassIndication {
if msg.Type.Method == stun.MethodData {
switch msg.Type.Method {
case stun.MethodData:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
Expand All @@ -462,13 +517,37 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {

c.log.Debugf("data indication received from %s", from.String())

relayedConn := c.relayedUDPConn()
relayedConn := c.getRelayedConn().(*client.UDPConn)
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleInbound(data, from)
case stun.MethodConnectionAttempt:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
}

from = &net.TCPAddr{
IP: peerAddr.IP,
Port: peerAddr.Port,
}

var cid proto.ConnectionID
if err := cid.GetFrom(msg); err != nil {
return err
}

c.log.Debugf("connection attempt from %s", from.String())

relayedConn := c.getRelayedConn().(*client.TCPConn)
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleConnectionAttempt(data, from, cid)
}
return nil
}
Expand Down Expand Up @@ -514,7 +593,7 @@ func (c *Client) handleChannelData(data []byte) error {
return err
}

relayedConn := c.relayedUDPConn()
relayedConn := c.getRelayedConn().(*client.UDPConn)
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
Expand Down Expand Up @@ -566,14 +645,14 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int) {
tr.StartRtxTimer(c.onRtxTimeout)
}

func (c *Client) setRelayedUDPConn(conn *client.UDPConn) {
func (c *Client) setRelayedConn(conn client.RelayConn) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.relayedConn = conn
}

func (c *Client) relayedUDPConn() *client.UDPConn {
func (c *Client) getRelayedConn() client.RelayConn {
c.mutex.RLock()
defer c.mutex.RUnlock()

Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,53 @@ func TestClientNonceExpiration(t *testing.T) {
assert.NoError(t, conn.Close())
assert.NoError(t, server.Close())
}

// Create a tcp-based allocation and verify allocation can be created
func TestTCPClient(t *testing.T) {
// Setup server
tcpListener, err := net.Listen("tcp4", "0.0.0.0:3478")
assert.NoError(t, err)

server, err := NewServer(ServerConfig{
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
return GenerateAuthKey(username, realm, "pass"), true
},
ListenerConfigs: []ListenerConfig{
{
Listener: tcpListener,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
},
},
Realm: "pion.ly",
})
assert.NoError(t, err)

// Setup clients
conn, err := net.Dial("tcp", "127.0.0.1:3478")
assert.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: NewSTUNConn(conn),
STUNServerAddr: "127.0.0.1:3478",
TURNServerAddr: "127.0.0.1:3478",
Username: "foo",
Password: "pass",
})
assert.NoError(t, err)
assert.NoError(t, client.Listen())

allocation, err := client.AllocateTCP()
assert.NoError(t, err)

// TODO: Implement server side handling of Connect and ConnectionBind
// _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
// assert.NoError(t, err)

// Shutdown
assert.NoError(t, allocation.Close())
assert.NoError(t, conn.Close())
assert.NoError(t, server.Close())
}
Loading

0 comments on commit 2da1324

Please sign in to comment.