diff --git a/common/io_test.go b/common/io_test.go new file mode 100644 index 000000000..12325457a --- /dev/null +++ b/common/io_test.go @@ -0,0 +1,37 @@ +package common + +import ( + "bytes" + "crypto/rand" + "testing" + "v2ray.com/core/common" +) + +func TestBufferedReader(t *testing.T) { + payload := [1024]byte{} + rand.Reader.Read(payload[:]) + rawReader := bytes.NewBuffer(payload[:]) + r := RewindReader{ + rawReader: rawReader, + } + r.SetBufferSize(2048) + buf1 := make([]byte, 512) + buf2 := make([]byte, 512) + common.Must2(r.Read(buf1)) + r.Rewind() + common.Must2(r.Read(buf2)) + if !bytes.Equal(buf1, buf2) { + t.Fail() + } + buf3 := make([]byte, 512) + common.Must2(r.Read(buf3)) + if !bytes.Equal(buf3, payload[512:]) { + t.Fail() + } + r.Rewind() + buf4 := make([]byte, 1024) + common.Must2(r.Read(buf4)) + if !bytes.Equal(payload[:], buf4) { + t.Fail() + } +} diff --git a/tunnel/transport/server.go b/tunnel/transport/server.go index 224856f6d..f501c01da 100644 --- a/tunnel/transport/server.go +++ b/tunnel/transport/server.go @@ -52,16 +52,15 @@ func (s *Server) acceptLoop() { r := bufio.NewReader(rewindConn) httpReq, err := http.ReadRequest(r) rewindConn.Rewind() + rewindConn.StopBuffering() if err != nil { // this is not a http request, pass it to trojan protocol layer for further inspection - rewindConn.StopBuffering() s.connChan <- &Conn{ Conn: rewindConn, } } else { // this is a http request, pass it to websocket protocol layer log.Debug("plaintext http request: ", httpReq) - rewindConn.StopBuffering() s.wsChan <- &Conn{ Conn: rewindConn, } diff --git a/tunnel/trojan/server.go b/tunnel/trojan/server.go index 3f71ec92a..295a1b965 100644 --- a/tunnel/trojan/server.go +++ b/tunnel/trojan/server.go @@ -124,7 +124,6 @@ func (s *Server) acceptLoop() { go func(conn tunnel.Conn) { rewindConn := common.NewRewindConn(conn) rewindConn.SetBufferSize(128) - defer rewindConn.StopBuffering() inboundConn := &InboundConn{ Conn: rewindConn, diff --git a/tunnel/trojan/trojan_test.go b/tunnel/trojan/trojan_test.go index 3621b0267..2408756e6 100644 --- a/tunnel/trojan/trojan_test.go +++ b/tunnel/trojan/trojan_test.go @@ -12,6 +12,7 @@ import ( "github.com/p4gefau1t/trojan-go/test/util" "github.com/p4gefau1t/trojan-go/tunnel" "github.com/p4gefau1t/trojan-go/tunnel/transport" + "io" "net" "testing" ) @@ -119,7 +120,7 @@ func TestTrojan(t *testing.T) { sendBuf := util.GeneratePayload(1024) recvBuf := [1024]byte{} common.Must2(conn.Write(sendBuf)) - common.Must2(conn.Read(recvBuf[:])) + common.Must2(io.ReadFull(conn, recvBuf[:])) if !bytes.Equal(sendBuf, recvBuf[:]) { fmt.Println(sendBuf) fmt.Println(recvBuf[:])