Skip to content

Commit

Permalink
Fix: data race (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
Loyalsoldier authored May 2, 2021
1 parent 6c907b7 commit 2c58592
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 44 deletions.
15 changes: 13 additions & 2 deletions common/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package common
import (
"io"
"net"
"sync"

"github.com/p4gefau1t/trojan-go/log"
)

type RewindReader struct {
mu sync.Mutex
rawReader io.Reader
buf []byte
bufReadIdx int
Expand All @@ -17,13 +19,16 @@ type RewindReader struct {
}

func (r *RewindReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()

if r.rewound {
if len(r.buf) > r.bufReadIdx {
n := copy(p, r.buf[r.bufReadIdx:])
r.bufReadIdx += n
return n, nil
}
r.rewound = false //all buffering content has been read
r.rewound = false // all buffering content has been read
}
n, err := r.rawReader.Read(p)
if r.buffering {
Expand Down Expand Up @@ -59,19 +64,24 @@ func (r *RewindReader) Discard(n int) (int, error) {
}

func (r *RewindReader) Rewind() {
r.mu.Lock()
if r.bufferSize == 0 {
panic("no buffer")
}
r.rewound = true
r.bufReadIdx = 0
r.mu.Unlock()
}

func (r *RewindReader) StopBuffering() {
r.mu.Lock()
r.buffering = false
r.mu.Unlock()
}

func (r *RewindReader) SetBufferSize(size int) {
if size == 0 { //disable buffering
r.mu.Lock()
if size == 0 { // disable buffering
if !r.buffering {
panic("reader is disabled")
}
Expand All @@ -88,6 +98,7 @@ func (r *RewindReader) SetBufferSize(size int) {
r.bufferSize = size
r.buf = make([]byte, 0, size)
}
r.mu.Unlock()
}

type RewindConn struct {
Expand Down
86 changes: 44 additions & 42 deletions test/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,33 @@ import (

// CheckConn checks if two netConn were connected and work properly
func CheckConn(a net.Conn, b net.Conn) bool {
payload1 := [1024]byte{}
payload2 := [1024]byte{}
rand.Reader.Read(payload1[:])
rand.Reader.Read(payload2[:])
payload1 := make([]byte, 1024)
payload2 := make([]byte, 1024)

result1 := make([]byte, 1024)
result2 := make([]byte, 1024)

rand.Reader.Read(payload1)
rand.Reader.Read(payload2)

result1 := [1024]byte{}
result2 := [1024]byte{}
wg := sync.WaitGroup{}
wg.Add(2)

go func() {
a.Write(payload1[:])
a.Read(result2[:])
a.Write(payload1)
a.Read(result2)
wg.Done()
}()

go func() {
b.Read(result1[:])
b.Write(payload2[:])
b.Read(result1)
b.Write(payload2)
wg.Done()
}()

wg.Wait()
if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) {
return false
}
return true

return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2)
}

// CheckPacketOverConn checks if two PacketConn streaming over a connection work properly
Expand All @@ -45,55 +48,54 @@ func CheckPacketOverConn(a, b net.PacketConn) bool {
IP: net.ParseIP("127.0.0.1"),
Port: port,
}
payload1 := [1024]byte{}
payload2 := [1024]byte{}
rand.Reader.Read(payload1[:])
rand.Reader.Read(payload2[:])

result1 := [1024]byte{}
result2 := [1024]byte{}
payload1 := make([]byte, 1024)
payload2 := make([]byte, 1024)

result1 := make([]byte, 1024)
result2 := make([]byte, 1024)

common.Must2(a.WriteTo(payload1[:], addr))
_, addr1, err := b.ReadFrom(result1[:])
rand.Reader.Read(payload1)
rand.Reader.Read(payload2)

common.Must2(a.WriteTo(payload1, addr))
_, addr1, err := b.ReadFrom(result1)
common.Must(err)
if addr1.String() != addr.String() {
return false
}

common.Must2(a.WriteTo(payload2[:], addr))
_, addr2, err := b.ReadFrom(result2[:])
common.Must2(a.WriteTo(payload2, addr))
_, addr2, err := b.ReadFrom(result2)
common.Must(err)
if addr2.String() != addr.String() {
return false
}
if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) {
return false
}
return true

return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2)
}

func CheckPacket(a, b net.PacketConn) bool {
payload1 := [1024]byte{}
payload2 := [1024]byte{}
rand.Reader.Read(payload1[:])
rand.Reader.Read(payload2[:])
payload1 := make([]byte, 1024)
payload2 := make([]byte, 1024)

result1 := [1024]byte{}
result2 := [1024]byte{}
result1 := make([]byte, 1024)
result2 := make([]byte, 1024)

_, err := a.WriteTo(payload1[:], b.LocalAddr())
rand.Reader.Read(payload1)
rand.Reader.Read(payload2)

_, err := a.WriteTo(payload1, b.LocalAddr())
common.Must(err)
_, _, err = b.ReadFrom(result1[:])
_, _, err = b.ReadFrom(result1)
common.Must(err)

_, err = b.WriteTo(payload2[:], a.LocalAddr())
_, err = b.WriteTo(payload2, a.LocalAddr())
common.Must(err)
_, _, err = a.ReadFrom(result2[:])
_, _, err = a.ReadFrom(result2)
common.Must(err)
if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) {
return false
}
return true

return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2)
}

func GetTestAddr() string {
Expand Down
6 changes: 6 additions & 0 deletions tunnel/adapter/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package adapter
import (
"context"
"net"
"sync"

"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/config"
Expand All @@ -18,6 +19,7 @@ type Server struct {
udpListener net.PacketConn
socksConn chan tunnel.Conn
httpConn chan tunnel.Conn
socksLock sync.RWMutex
nextSocks bool
ctx context.Context
cancel context.CancelFunc
Expand Down Expand Up @@ -45,7 +47,9 @@ func (s *Server) acceptConnLoop() {
log.Error(common.NewError("failed to detect proxy protocol type").Base(err))
continue
}
s.socksLock.RLock()
if buf[0] == 5 && s.nextSocks {
s.socksLock.RUnlock()
log.Debug("socks5 connection")
s.socksConn <- &freedom.Conn{
Conn: rewindConn,
Expand All @@ -68,7 +72,9 @@ func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) {
return nil, common.NewError("adapter closed")
}
} else if _, ok := overlay.(*socks.Tunnel); ok {
s.socksLock.Lock()
s.nextSocks = true
s.socksLock.Unlock()
select {
case conn := <-s.socksConn:
return conn, nil
Expand Down
6 changes: 6 additions & 0 deletions tunnel/transport/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"os/exec"
"strconv"
"sync"
"time"

"github.com/p4gefau1t/trojan-go/common"
Expand All @@ -22,6 +23,7 @@ type Server struct {
cmd *exec.Cmd
connChan chan tunnel.Conn
wsChan chan tunnel.Conn
httpLock sync.RWMutex
nextHTTP bool
ctx context.Context
cancel context.CancelFunc
Expand Down Expand Up @@ -50,7 +52,9 @@ func (s *Server) acceptLoop() {

go func(tcpConn net.Conn) {
log.Info("tcp connection from", tcpConn.RemoteAddr())
s.httpLock.RLock()
if s.nextHTTP { // plaintext mode enabled
s.httpLock.RUnlock()
// we use real http header parser to mimic a real http server
rewindConn := common.NewRewindConn(tcpConn)
rewindConn.SetBufferSize(512)
Expand Down Expand Up @@ -84,7 +88,9 @@ func (s *Server) acceptLoop() {
func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) {
// TODO fix import cycle
if overlay != nil && (overlay.Name() == "WEBSOCKET" || overlay.Name() == "HTTP") {
s.httpLock.Lock()
s.nextHTTP = true
s.httpLock.Unlock()
select {
case conn := <-s.wsChan:
return conn, nil
Expand Down

0 comments on commit 2c58592

Please sign in to comment.