forked from SenseUnit/dumbproxy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
handler.go
135 lines (121 loc) · 3.67 KB
/
handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package main
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
)
// ProxyHandler is the http.Handler implementing the proxy: it serves CONNECT
// tunnels and forwards plain HTTP requests.
type ProxyHandler struct {
	// timeout bounds the upstream dial performed for CONNECT requests.
	timeout time.Duration
	// auth validates every request before it is proxied.
	auth Auth
	// logger receives request, error and loopback-detection messages.
	logger *CondLogger
	// httptransport performs the upstream round trip for non-CONNECT requests.
	httptransport http.RoundTripper
	// outbound maps the local address of each open upstream tunnel connection
	// to the remote address of the client that requested it. It is consulted
	// to detect requests that arrive from one of the proxy's own tunnels.
	outbound map[string]string
	// outboundMux guards outbound.
	outboundMux sync.RWMutex
}
// NewProxyHandler builds a ProxyHandler that uses the given dial timeout,
// authenticator and logger. The handler starts with a fresh zero-value
// http.Transport for plain HTTP requests and an empty outbound-connection
// table.
func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler {
	handler := new(ProxyHandler)
	handler.timeout = timeout
	handler.auth = auth
	handler.logger = logger
	handler.httptransport = &http.Transport{}
	handler.outbound = make(map[string]string)
	return handler
}
// HandleTunnel serves a CONNECT request: it dials the target named by
// req.RequestURI and relays bytes between the client and that connection.
//
// The dial is bounded by s.timeout. While the tunnel is open, the local
// address of the upstream connection is recorded in s.outbound (mapped to the
// client's remote address) so that ServeHTTP can detect requests looping back
// through the proxy's own tunnels; the entry is removed when the tunnel ends.
//
// HTTP/1.x clients are served by hijacking the connection, HTTP/2 clients by
// streaming over the request body and response writer. Other protocol
// versions are rejected with 400 Bad Request. A dial failure yields 502 Bad
// Gateway and a hijack failure 500 Internal Server Error.
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
	// The timeout applies to the dial only. Cancel as soon as the dial
	// returns so the context's timer is released immediately instead of
	// lingering until it fires; an established connection is not affected
	// by the context afterwards.
	ctx, cancel := context.WithTimeout(req.Context(), s.timeout)
	dialer := net.Dialer{}
	conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI)
	cancel()
	if err != nil {
		s.logger.Error("Can't satisfy CONNECT request: %v", err)
		http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
		return
	}

	// Register the outbound connection for loopback detection.
	localAddr := conn.LocalAddr().String()
	s.outboundMux.Lock()
	s.outbound[localAddr] = req.RemoteAddr
	s.outboundMux.Unlock()
	defer func() {
		conn.Close()
		s.outboundMux.Lock()
		delete(s.outbound, localAddr)
		s.outboundMux.Unlock()
	}()

	switch req.ProtoMajor {
	case 0, 1:
		// Upgrade client connection
		localconn, _, err := hijack(wr)
		if err != nil {
			s.logger.Error("Can't hijack client connection: %v", err)
			http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
			return
		}
		defer localconn.Close()
		// Inform client connection is built
		if _, err := fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor); err != nil {
			// The client connection is already unusable; relaying is pointless.
			s.logger.Error("Can't send CONNECT response to client: %v", err)
			return
		}
		proxy(req.Context(), localconn, conn)
	case 2:
		// A nil Date value suppresses the automatic Date header.
		wr.Header()["Date"] = nil
		wr.WriteHeader(http.StatusOK)
		flush(wr)
		proxyh2(req.Context(), req.Body, wr, conn)
	default:
		s.logger.Error("Unsupported protocol version: %s", req.Proto)
		http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
	}
}
// HandleRequest forwards a non-CONNECT request through s.httptransport and
// streams the upstream response back to the client. Hop-by-hop headers are
// removed from the response before it is copied. If the round trip fails the
// client receives 500 Internal Server Error.
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
	if req.ProtoMajor == 2 {
		// We can't access :scheme pseudo-header, so assume http
		req.URL.Scheme = "http"
		req.URL.Host = req.Host
	}
	// Clear the server-side request line before reusing req as an outgoing request.
	req.RequestURI = ""

	upstream, fetchErr := s.httptransport.RoundTrip(req)
	if fetchErr != nil {
		s.logger.Error("HTTP fetch error: %v", fetchErr)
		http.Error(wr, "Server Error", http.StatusInternalServerError)
		return
	}
	defer upstream.Body.Close()

	s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, upstream.Status)

	// Relay headers first, then status, then the body.
	headers := upstream.Header
	delHopHeaders(headers)
	copyHeader(wr.Header(), headers)
	wr.WriteHeader(upstream.StatusCode)
	flush(wr)
	copyBody(wr, upstream.Body)
}
// isLoopback reports whether req.RemoteAddr is currently registered as the
// local address of one of the proxy's own outbound tunnel connections. When
// it is, the first result is the remote address of the client that opened
// that tunnel.
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
	s.outboundMux.RLock()
	defer s.outboundMux.RUnlock()

	origin, known := s.outbound[req.RemoteAddr]
	return origin, known
}
// ServeHTTP is the proxy entry point. It processes each request in order:
//
//  1. Rejects requests whose remote address is one of the proxy's own
//     outbound tunnel addresses (a loopback tunnel) with 400 Bad Request.
//  2. Rejects malformed proxy requests with 400 Bad Request: for HTTP/1.x and
//     older the URL must carry a host, and a scheme too unless the method is
//     CONNECT; for HTTP/2 the Host must be non-empty.
//  3. Authenticates the request via s.auth and logs it.
//  4. Strips hop-by-hop headers and dispatches to HandleTunnel for CONNECT
//     or HandleRequest for everything else.
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
	if originator, isLoopback := s.isLoopback(req); isLoopback {
		s.logger.Critical("Loopback tunnel detected: %s is an outbound "+
			"address for another request from %s", req.RemoteAddr, originator)
		http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
		return
	}

	// Case-insensitive match without allocating an upper-cased copy.
	isConnect := strings.EqualFold(req.Method, "CONNECT")

	// Validity rules depend on the protocol version; versions above 2 are
	// not checked here.
	var malformed bool
	switch {
	case req.ProtoMajor < 2:
		malformed = req.URL.Host == "" || (req.URL.Scheme == "" && !isConnect)
	case req.ProtoMajor == 2:
		malformed = req.Host == ""
	}
	if malformed {
		http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
		return
	}

	username, ok := s.auth.Validate(wr, req)
	// Logged before checking ok so that rejected attempts are recorded too.
	s.logger.Info("Request: %v %q %v %v %v", req.RemoteAddr, username, req.Proto, req.Method, req.URL)
	if !ok {
		// NOTE(review): presumably Validate has already written the rejection
		// response to wr — verify against the Auth implementations.
		return
	}

	delHopHeaders(req.Header)
	if isConnect {
		s.HandleTunnel(wr, req)
		return
	}
	s.HandleRequest(wr, req)
}