diff --git a/README.md b/README.md index c3cf3a2..a5284bf 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ `./cmd/*.md` ## @application 应用层 - [x] http [docs](./cmd/http.md) -- [ ] websocket [docs](./cmd/websocket.md) +- [x] websocket [docs](./cmd/websocket.md) ## @transport 传输层 - [x] tcp [docs](./cmd/tcp.md) diff --git a/cmd/application/websocket/websocketserver.go b/cmd/application/websocket/websocketserver.go index 1500951..c14e752 100644 --- a/cmd/application/websocket/websocketserver.go +++ b/cmd/application/websocket/websocketserver.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "log" "github.com/brewlin/net-protocol/pkg/logging" @@ -13,20 +14,23 @@ func init() { } func main() { serv := http.NewHTTP("tap1", "192.168.1.0/24", "192.168.1.1", "9502") - serv.HandleFunc("/websocket", echo) + serv.HandleFunc("/ws", echo) serv.HandleFunc("/", func(request *http.Request, response *http.Response) { response.End("hello") }) + fmt.Println("@main: server is start ip:192.168.1.1 port:9502 ") serv.ListenAndServ() } //websocket处理器 func echo(r *http.Request, w *http.Response) { - //协议升级 + fmt.Println("got http request ; start to upgrade websocket protocol....") + //协议升级 c *websocket.Conn c, err := websocket.Upgrade(r, w) if err != nil { - log.Print("Upgrade error:", err) + //升级协议失败,直接return 交由http处理响应 + fmt.Println("Upgrade error:", err) return } defer c.Close() @@ -37,7 +41,8 @@ func echo(r *http.Request, w *http.Response) { log.Println("read:", err) break } - log.Printf("recv:%s", message) - c.SendData(message) + fmt.Println("recv client msg:", string(message)) + // c.SendData(message ) + c.SendData([]byte("hello")) } } diff --git a/protocol/application/http/connection.go b/protocol/application/http/connection.go index 4257c8e..5a03e8b 100644 --- a/protocol/application/http/connection.go +++ b/protocol/application/http/connection.go @@ -1,8 +1,9 @@ package http import ( - "fmt" + "errors" "log" + "sync" "github.com/brewlin/net-protocol/pkg/buffer" "github.com/brewlin/net-protocol/pkg/waiter" @@ -29,6 +30,10 @@ type Connection struct { // 请求文件的真实路径 real_path string + //接受队列缓存区 + buf buffer.View + bufmu sync.RWMutex + q *waiter.Queue waitEntry waiter.Entry notifyC chan struct{} @@ -64,7 +69,6 @@ func newCon(e tcpip.Endpoint, q *waiter.Queue) *Connection { func (con *Connection) handler() { <-con.notifyC log.Println("@应用层 http: waiting new event trigger ...") - fmt.Println("@应用层 http: waiting new event trigger ...") for { v, _, err := con.socket.Read(con.addr) if err != nil { @@ -76,8 +80,8 @@ func (con *Connection) handler() { } con.recv_buf += string(v) } - fmt.Println("http协议原始数据:") - fmt.Println(con.recv_buf) + log.Println("http协议原始数据:") + log.Println(con.recv_buf) con.request.parse(con) //dispatch the route request defaultMux.dispatch(con) @@ -92,22 +96,63 @@ func (c *Connection) set_status_code(code int) { } //Write write -func (c *Connection) Write(buf []byte) *tcpip.Error { +func (c *Connection) Write(buf []byte) error { v := buffer.View(buf) - _, _, err := c.socket.Write(tcpip.SlicePayload(v), + c.socket.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{To: c.addr}) - return err + return nil } //Read data -func (c *Connection) Read(p []byte) (int, error) { - buf, _, err := c.socket.Read(c.addr) - if err != nil { - return 0, err +func (c *Connection) Read() ([]byte, error) { + + var buf []byte + var err error + for { + v, _, e := c.socket.Read(c.addr) + if e != nil { + err = e + break + } + buf = append(buf, v...) } - n := copy(p, buf) - return n, nil + if buf == nil { + return nil, err + } + return buf, nil + +} +//Readn 读取固定字节的数据 +func (c *Connection) Readn(p []byte) (int, error) { + c.bufmu.Lock() + defer c.bufmu.Unlock() + //获取足够长度的字节 + if len(p) > len(c.buf) { + + for { + if len(p) <= len(c.buf) { + break + } + buf, _, err := c.socket.Read(c.addr) + if err != nil { + if err == tcpip.ErrWouldBlock { + //阻塞等待数据 + <-c.notifyC + continue + } + return 0, err + } + c.buf = append(c.buf, buf...) + } + } + if len(p) > len(c.buf) { + return 0, errors.New("package len is smaller than p need") + } + + n := copy(p, c.buf) + c.buf = c.buf[len(p):] + return n, nil } //关闭连接 diff --git a/protocol/application/http/request.go b/protocol/application/http/request.go index ca7a20b..42b5719 100644 --- a/protocol/application/http/request.go +++ b/protocol/application/http/request.go @@ -1,7 +1,7 @@ package http import ( - "fmt" + "log" "strings" ) @@ -36,7 +36,7 @@ func (req *Request) parse(con *Connection) { buf := con.recv_buf req.method_raw, buf = match_until(buf, " ") - fmt.Println("@application http: header parse method_raw:", req.method_raw) + log.Println("@application http: header parse method_raw:", req.method_raw) if req.method_raw == "" { con.status_code = 400 @@ -45,7 +45,7 @@ func (req *Request) parse(con *Connection) { // 获得HTTP方法 req.method = get_method(req.method_raw) - fmt.Println("@application http: header parse method:", req.method) + log.Println("@application http: header parse method:", req.method) if req.method == HTTP_METHOD_NOT_SUPPORTED { con.set_status_code(501) @@ -56,7 +56,7 @@ func (req *Request) parse(con *Connection) { // 获得URI req.uri, buf = match_until(buf, " ") - fmt.Println("@application http: header parse uri:", req.uri) + log.Println("@application http: header parse uri:", req.uri) if req.uri == "" { con.status_code = 400 @@ -80,7 +80,7 @@ func (req *Request) parse(con *Connection) { // 获得HTTP版本 req.version_raw, buf = match_until(buf, "\r\n") - fmt.Println("@application http: header parse version_raw:", req.version_raw) + log.Println("@application http: header parse version_raw:", req.version_raw) if req.version_raw == "" { con.status_code = 400 @@ -95,8 +95,8 @@ func (req *Request) parse(con *Connection) { } else { con.set_status_code(400) } - fmt.Println("@application http: header parse version:", req.version) - fmt.Println("@application http: header parse status_code:", con.status_code) + log.Println("@application http: header parse version:", req.version) + log.Println("@application http: header parse status_code:", con.status_code) if con.status_code > 0 { return } diff --git a/protocol/application/http/response.go b/protocol/application/http/response.go index 6ab9dea..d7ddbc5 100644 --- a/protocol/application/http/response.go +++ b/protocol/application/http/response.go @@ -1,7 +1,7 @@ package http import ( - "fmt" + "log" "strconv" "github.com/brewlin/net-protocol/pkg/buffer" @@ -133,8 +133,8 @@ func (r *Response) build_and_send_response() { } buf += "\r\n" buf += r.entity_body - fmt.Println("@application http:response send 构建http响应包体") - fmt.Println(buf) + log.Println("@application http:response send 构建http响应包体") + log.Println(buf) // 将字符串缓存发送到客户端 r.send_all(buf) } diff --git a/protocol/application/http/server.go b/protocol/application/http/server.go index 7e18585..2a9e5e9 100644 --- a/protocol/application/http/server.go +++ b/protocol/application/http/server.go @@ -2,7 +2,6 @@ package http import ( "flag" - "fmt" "log" "net" "strconv" @@ -144,11 +143,10 @@ func (s *Server) ListenAndServ() { if err != nil { if err == tcpip.ErrWouldBlock { log.Println("@application http:", " now waiting to new client connection ...") - fmt.Println("@application http:", " now waiting to new client connection ...") <-notifyCh continue } - fmt.Println("@application http: Accept() failed: ", err) + log.Println("@application http: Accept() failed: ", err) panic(err) } @@ -157,9 +155,9 @@ func (s *Server) ListenAndServ() { } func (s *Server) dispatch(e tcpip.Endpoint, wq *waiter.Queue) { - fmt.Println("@application http: dispatch got new request") + log.Println("@application http: dispatch got new request") con := newCon(e, wq) con.handler() - fmt.Println("@application http: dispatch close this request") + log.Println("@application http: dispatch close this request") con.Close() } diff --git a/protocol/application/http/server_patttern.go b/protocol/application/http/server_patttern.go index 1046edb..ad20777 100644 --- a/protocol/application/http/server_patttern.go +++ b/protocol/application/http/server_patttern.go @@ -14,14 +14,6 @@ type muxEntry struct { pattern string } -//NewMuxEntry entry -func NewMuxEntry(pattern string, handler func(*Request, *Response)) muxEntry { - var entry muxEntry - entry.h = handler - entry.pattern = pattern - return entry -} - var defaultMux ServeMux //handle diff --git a/protocol/application/websocket/conn.go b/protocol/application/websocket/conn.go index 6a6e7e8..673a8c8 100644 --- a/protocol/application/websocket/conn.go +++ b/protocol/application/websocket/conn.go @@ -3,7 +3,6 @@ package websocket import ( "encoding/binary" "errors" - "fmt" "log" "github.com/brewlin/net-protocol/protocol/application/http" @@ -67,7 +66,7 @@ func (c *Conn) SendData(data []byte) { * => 1 0 0 0 0 0 0 1 */ c.writeBuf[0] = byte(TextMessage) | finalBit - fmt.Printf("1 bit:%b\n", c.writeBuf[0]) + log.Printf("1 bit:%b\n", c.writeBuf[0]) //数据帧第二个字节,服务器发送的数据不需要进行掩码处理 switch { @@ -88,7 +87,7 @@ func (c *Conn) SendData(data []byte) { //c.writeBuf[1] = byte(0x00) | byte(length) c.writeBuf[1] = byte(length) } - fmt.Printf("2 bit:%b\n", c.writeBuf[1]) + log.Printf("2 bit:%b\n", c.writeBuf[1]) copy(c.writeBuf[payloadStart:], data[:]) c.conn.Write(c.writeBuf[:payloadStart+length]) @@ -98,12 +97,12 @@ func (c *Conn) SendData(data []byte) { func (c *Conn) ReadData() (data []byte, err error) { var b [8]byte //读取数据帧的前两个字节 - if _, err := c.conn.Read(b[:2]); err != nil { + if _, err := c.conn.Readn(b[:2]); err != nil { return nil, err } //开始解析第一个字节 是否还有后续数据帧 final := b[0]&finalBit != 0 - fmt.Printf("read data 1 bit :%b\n", b[0]) + log.Printf("read data 1 bit :%b\n", b[0]) //不支持数据分片 if !final { log.Println("Recived fragemented frame,not support") @@ -138,13 +137,13 @@ func (c *Conn) ReadData() (data []byte, err error) { //根据payload length 判断数据的真实长度 switch payloadLen { case 126: //扩展2字节 - if _, err := c.conn.Read(b[:2]); err != nil { + if _, err := c.conn.Readn(b[:2]); err != nil { return nil, err } //获取扩展二字节的真实数据长度 dataLen = int64(binary.BigEndian.Uint16(b[:2])) case 127: - if _, err := c.conn.Read(b[:8]); err != nil { + if _, err := c.conn.Readn(b[:8]); err != nil { return nil, err } dataLen = int64(binary.BigEndian.Uint64(b[:8])) @@ -154,13 +153,13 @@ func (c *Conn) ReadData() (data []byte, err error) { //读取mask key if mask { //如果需要掩码处理的话 需要取出key //maskKey 是 4 字节 32位 - if _, err := c.conn.Read(c.maskKey[:]); err != nil { + if _, err := c.conn.Readn(c.maskKey[:]); err != nil { return nil, err } } //读取数据内容 p := make([]byte, dataLen) - if _, err := c.conn.Read(p); err != nil { + if _, err := c.conn.Readn(p); err != nil { return nil, err } if mask { diff --git a/protocol/application/websocket/upgrade.go b/protocol/application/websocket/upgrade.go index 31fba6f..809ee6d 100644 --- a/protocol/application/websocket/upgrade.go +++ b/protocol/application/websocket/upgrade.go @@ -2,6 +2,7 @@ package websocket import ( "errors" + "fmt" "log" "github.com/brewlin/net-protocol/protocol/application/http" @@ -14,13 +15,13 @@ func Upgrade(r *http.Request, w *http.Response) (c *Conn, err error) { return nil, errors.New("websocket:method not GET") } //检查 Sec-WebSocket-Version 版本 - if values := r.GetHeader("Sec-Websocket-Version"); values == "" || values != "13" { + if values := r.GetHeader("Sec-WebSocket-Version"); values == "" || values != "13" { w.Error(http.StatusBadRequest) return nil, errors.New("websocket:version != 13") } //检查Connection 和 Upgrade - if values := r.GetHeader("Connection"); values != "upgrade" { + if values := r.GetHeader("Connection"); !tokenListContainsValue(values, "upgrade") { w.Error(http.StatusBadRequest) return nil, errors.New("websocket:could not find connection header with token 'upgrade'") } @@ -30,7 +31,7 @@ func Upgrade(r *http.Request, w *http.Response) (c *Conn, err error) { } //计算Sec-Websocket-Accept的值 - challengeKey := r.GetHeader("Sec-Websocket-Key") + challengeKey := r.GetHeader("Sec-WebSocket-Key") if challengeKey == "" { w.Error(http.StatusBadRequest) return nil, errors.New("websocket:key missing or blank") @@ -45,6 +46,8 @@ func Upgrade(r *http.Request, w *http.Response) (c *Conn, err error) { p = append(p, "\r\n\r\n"...) //返回repson 但不关闭连接 if err = con.Write(p); err != nil { + fmt.Println(err == nil) + fmt.Println("write p err", err) return nil, err } //升级为websocket diff --git a/protocol/application/websocket/utils.go b/protocol/application/websocket/utils.go index 7f73e9f..cabaf6f 100644 --- a/protocol/application/websocket/utils.go +++ b/protocol/application/websocket/utils.go @@ -1,39 +1,36 @@ package websocket import ( - "crypto/sha1" - "encoding/base64" - "strings" - "net/http" + "crypto/sha1" + "encoding/base64" + "strings" ) - var KeyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + //握手阶段使用 加密key返回 进行握手 -func computeAcceptKey(challengeKey string)string{ - h := sha1.New() - h.Write([]byte(challengeKey)) - h.Write(KeyGUID) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(KeyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) } //解码 -func maskBytes(key [4]byte,b []byte){ - pos := 0 - for i := range b { - b[i] ^= key[pos & 3] - pos ++ - } +func maskBytes(key [4]byte, b []byte) { + pos := 0 + for i := range b { + b[i] ^= key[pos&3] + pos++ + } } // 检查http 头部字段中是否包含指定的值 -func tokenListContainsValue(header http.Header, name string, value string)bool{ - for _,v := range header[name] { - for _, s := range strings.Split(v,","){ - if strings.EqualFold(value,strings.TrimSpace(s)) { - return true - } - } - } - return false +func tokenListContainsValue(h string, value string) bool { + for _, s := range strings.Split(h, ",") { + if strings.EqualFold(value, strings.TrimSpace(s)) { + return true + } + } + return false } diff --git a/protocol/transport/tcp/endpoint.go b/protocol/transport/tcp/endpoint.go index c1991b2..e360598 100644 --- a/protocol/transport/tcp/endpoint.go +++ b/protocol/transport/tcp/endpoint.go @@ -470,6 +470,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, } // 从tcp的接收队列中读取数据,并从接收队列中删除已读数据 +// tcp 队列是切片 func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { if e.rcvClosed || e.state != stateConnected {