-
Notifications
You must be signed in to change notification settings - Fork 0
/
pgssl.go
132 lines (112 loc) · 3.64 KB
/
pgssl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package main
import (
	"crypto/subtle"
	"crypto/tls"
	"fmt"
	"net"

	"github.com/jackc/pgproto3/v2"
)
// PgSSL proxies plaintext Postgres client connections to a Postgres server,
// upgrading the server-side link to TLS (see HandleConn).
type PgSSL struct {
	// pgAddr is the TCP address of the Postgres server, passed to net.Dial("tcp", ...).
	pgAddr string
	// clientCert, if non-nil, is offered to the server as the TLS client
	// certificate during the handshake.
	clientCert *tls.Certificate
	// connectionPassword, if non-empty, must be supplied by clients as a
	// cleartext password before a server connection is opened.
	connectionPassword string
}
// HandleConn proxies one plaintext client connection to the Postgres server
// at p.pgAddr, carrying the server side of the link over TLS.
//
// Steps:
//  1. Pose as a Postgres backend and read the client's startup message. An
//     SSLRequest is declined with 'N'; the client may then continue
//     unencrypted with a StartupMessage, as the protocol allows.
//  2. If p.connectionPassword is set, request a cleartext password from the
//     client and check it.
//  3. Dial p.pgAddr, negotiate SSL, complete the TLS handshake (offering
//     p.clientCert when set) and forward the client's startup message.
//  4. Hand both connections to Pipe and report any error it returns.
//
// clientConn is always closed before HandleConn returns.
func (p *PgSSL) HandleConn(clientConn net.Conn) error {
	defer clientConn.Close()

	// We pose as a backend for the client connection.
	backend := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn)

	startup, err := p.readClientStartup(backend, clientConn)
	if err != nil {
		return err
	}

	if p.connectionPassword != "" {
		if err := p.authenticateClient(backend); err != nil {
			return err
		}
		// No need to send AuthenticationOk: the postgres server sends it.
	}

	pgTLSConn, err := p.dialServerTLS()
	if err != nil {
		return err
	}
	// Closing the TLS conn also closes the underlying TCP connection.
	defer pgTLSConn.Close()

	// We pose as a frontend for the (now encrypted) postgres connection.
	frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(pgTLSConn), pgTLSConn)
	// Best-effort goodbye to the server; the connection may already be
	// closed by the time this runs, so the error is deliberately ignored.
	defer frontend.Send(&pgproto3.Terminate{})

	if err := frontend.Send(startup); err != nil {
		return fmt.Errorf("forwarding startup message to postgres: %w", err)
	}

	pgConnErr, clientConnErr := Pipe(clientConn, pgTLSConn)
	if pgConnErr != nil {
		return fmt.Errorf("postgres connection error: %w", pgConnErr)
	}
	if clientConnErr != nil {
		return fmt.Errorf("client connection error: %w", clientConnErr)
	}
	return nil
}

// readClientStartup reads the client's StartupMessage. A single leading
// SSLRequest is answered with 'N' (SSL not supported) and the next startup
// message is read from the same connection.
func (p *PgSSL) readClientStartup(backend *pgproto3.Backend, clientConn net.Conn) (pgproto3.FrontendMessage, error) {
	msg, err := backend.ReceiveStartupMessage()
	if err != nil {
		return nil, fmt.Errorf("error receiving startup message: %w", err)
	}

	if _, isSSL := msg.(*pgproto3.SSLRequest); isSSL {
		if _, err := clientConn.Write([]byte{'N'}); err != nil {
			return nil, fmt.Errorf("error while sending SSL-decline to client: %w", err)
		}
		// After 'N' a client may either hang up or continue unencrypted
		// (libpq's default sslmode=prefer continues). Previously the proxy
		// closed the connection here, breaking such clients.
		msg, err = backend.ReceiveStartupMessage()
		if err != nil {
			return nil, fmt.Errorf("error receiving startup message after SSL decline: %w", err)
		}
	}

	switch msg.(type) {
	case *pgproto3.StartupMessage:
		return msg, nil
	case *pgproto3.SSLRequest:
		return nil, fmt.Errorf("client sent a second SSLRequest")
	default:
		return nil, fmt.Errorf("unknown startup message: %#v", msg)
	}
}

// authenticateClient asks the client for a cleartext password and compares
// it with p.connectionPassword, sending a FATAL ErrorResponse on mismatch.
func (p *PgSSL) authenticateClient(backend *pgproto3.Backend) error {
	if err := backend.Send(&pgproto3.AuthenticationCleartextPassword{}); err != nil {
		return fmt.Errorf("error requesting password from client: %w", err)
	}

	msg, err := backend.Receive()
	if err != nil {
		return fmt.Errorf("error receiving password message: %w", err)
	}
	// Checked assertion: any other message (e.g. Terminate) used to panic.
	pwMsg, ok := msg.(*pgproto3.PasswordMessage)
	if !ok {
		return fmt.Errorf("expected password message, got %T", msg)
	}

	// Constant-time comparison so response timing does not reveal how much
	// of the password matched (only the length can still leak).
	if subtle.ConstantTimeCompare([]byte(pwMsg.Password), []byte(p.connectionPassword)) != 1 {
		// 28P01 is the SQLSTATE Postgres uses for invalid_password. Sending
		// is best-effort: the connection is closed right after regardless.
		_ = backend.Send(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "28P01", Message: "wrong password"})
		return fmt.Errorf("wrong password")
	}
	return nil
}

// dialServerTLS connects to p.pgAddr, asks the server to switch to SSL and
// completes the TLS handshake, offering p.clientCert when it is set. On
// success the caller owns the returned connection; on error nothing is left
// open.
func (p *PgSSL) dialServerTLS() (*tls.Conn, error) {
	pgConn, err := net.Dial("tcp", p.pgAddr)
	if err != nil {
		return nil, fmt.Errorf("dialing postgres at %s: %w", p.pgAddr, err)
	}
	fail := func(err error) (*tls.Conn, error) {
		pgConn.Close()
		return nil, err
	}

	// The plaintext frontend is only used to send the SSLRequest; nothing is
	// read through its ChunkReader, so the raw reply below is not buffered away.
	plain := pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn), pgConn)
	if err := plain.Send(&pgproto3.SSLRequest{}); err != nil {
		return fail(fmt.Errorf("sending SSLRequest to postgres: %w", err))
	}

	// The server responds with a single byte, 'S' or 'N', meaning willing or
	// unwilling to perform SSL. Extra bytes available at this point likely
	// indicate a man-in-the-middle buffer-stuffing attack (CVE-2021-23222),
	// so read into a 2-byte buffer and require that exactly one byte arrived.
	buf := make([]byte, 2)
	n, err := pgConn.Read(buf)
	if err != nil {
		return fail(fmt.Errorf("reading SSLRequest response: %w", err))
	}
	if n != 1 {
		return fail(fmt.Errorf("expected exactly 1 byte in response to SSLrequest, got %d", n))
	}
	switch buf[0] {
	case 'S':
		// Server agreed; continue with the TLS upgrade.
	case 'N':
		return fail(fmt.Errorf("server declined SSL communication"))
	default:
		return fail(fmt.Errorf("unexpected response to SSLrequest: %v", buf[0]))
	}

	cfg := &tls.Config{
		// NOTE(review): the server certificate is NOT verified, so the link is
		// encrypted but not authenticated against a man-in-the-middle.
		// Consider RootCAs/ServerName verification.
		InsecureSkipVerify: true,
	}
	if p.clientCert != nil {
		cfg.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
			return p.clientCert, nil
		}
	}

	pgTLSConn := tls.Client(pgConn, cfg)
	if err := pgTLSConn.Handshake(); err != nil {
		return fail(fmt.Errorf("handshake error: %w", err))
	}
	return pgTLSConn, nil
}