From e684676eefcbcfa843a39686b931d9327237cf7f Mon Sep 17 00:00:00 2001 From: Gaukas Wang Date: Wed, 10 Apr 2024 10:55:29 -0600 Subject: [PATCH] fix: `(*UConn).Read()` and Secure Renegotiation (#292) * fix: UConn Read does not trigger correct Handshake Copy `(*Conn).Read` to `(*UConn).Read` and force it use `(*UConn).Handshake`. Same for `handleRenegotiation` and `handlePostHandshakeMessage`. Signed-off-by: Gaukas Wang * update: use VerifyData in RenegotiationInfoExt This make sure the renegotiation would work in certain scenarios instead of no scenarios. Signed-off-by: Gaukas Wang --------- Signed-off-by: Gaukas Wang --- u_conn.go | 131 ++++++++++++++++++++++++++++++++++++++++++++ u_tls_extensions.go | 23 +++++--- 2 files changed, 146 insertions(+), 8 deletions(-) diff --git a/u_conn.go b/u_conn.go index d7e569a6..6f531c5c 100644 --- a/u_conn.go +++ b/u_conn.go @@ -889,3 +889,134 @@ type utlsConnExtraFields struct { sessionController *sessionController } + +// Read reads data from the connection. +// +// As Read calls [Conn.Handshake], in order to prevent indefinite blocking a deadline +// must be set for both Read and [Conn.Write] before Read is called when the handshake +// has not yet completed. See [Conn.SetDeadline], [Conn.SetReadDeadline], and +// [Conn.SetWriteDeadline]. +func (c *UConn) Read(b []byte) (int, error) { + if err := c.Handshake(); err != nil { + return 0, err + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + c.in.Lock() + defer c.in.Unlock() + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// handleRenegotiation processes a HelloRequest handshake message. +func (c *UConn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake(nil) + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + c.isHandshakeComplete.Store(false) + + // [uTLS section begins] + if err = c.BuildHandshakeState(); err != nil { + return err + } + // [uTLS section ends] + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *UConn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake(nil) + if err != nil { + return err + } + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + } + // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest + // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an + // unexpected_message alert here doesn't provide it with enough information to distinguish + // this condition from other unexpected messages. This is probably fine. + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) +} diff --git a/u_tls_extensions.go b/u_tls_extensions.go index aca175d1..f68e7b6e 100644 --- a/u_tls_extensions.go +++ b/u_tls_extensions.go @@ -1563,11 +1563,11 @@ type RenegotiationInfoExtension struct { // If this is the initial handshake for a connection, then the // "renegotiated_connection" field is of zero length in both the // ClientHello and the ServerHello. - // RenegotiatedConnection []byte + RenegotiatedConnection []byte } func (e *RenegotiationInfoExtension) Len() int { - return 5 // + len(e.RenegotiatedConnection) + return 5 + len(e.RenegotiatedConnection) } func (e *RenegotiationInfoExtension) Read(b []byte) (int, error) { @@ -1575,15 +1575,15 @@ func (e *RenegotiationInfoExtension) Read(b []byte) (int, error) { return 0, io.ErrShortBuffer } - // dataLen := len(e.RenegotiatedConnection) - extBodyLen := 1 // + len(dataLen) + dataLen := len(e.RenegotiatedConnection) + extBodyLen := 1 + dataLen b[0] = byte(extensionRenegotiationInfo >> 8) b[1] = byte(extensionRenegotiationInfo & 0xff) b[2] = byte(extBodyLen >> 8) b[3] = byte(extBodyLen) - // b[4] = byte(dataLen) - // copy(b[5:], e.RenegotiatedConnection) + b[4] = byte(dataLen) + copy(b[5:], e.RenegotiatedConnection) return e.Len(), io.EOF } @@ -1593,7 +1593,7 @@ func (e *RenegotiationInfoExtension) UnmarshalJSON(_ []byte) error { return nil } -func (e *RenegotiationInfoExtension) Write(_ []byte) (int, error) { +func (e *RenegotiationInfoExtension) Write(b []byte) (int, error) { e.Renegotiation = RenegotiateOnceAsClient // none empty or other modes are unsupported // extData := cryptobyte.String(b) // var renegotiatedConnection cryptobyte.String @@ -1602,7 +1602,10 @@ func (e *RenegotiationInfoExtension) Write(_ []byte) (int, error) { // } // e.RenegotiatedConnection = make([]byte, len(renegotiatedConnection)) // copy(e.RenegotiatedConnection, renegotiatedConnection) - return 0, nil + + // we don't really want to parse it at all. + + return len(b), nil } func (e *RenegotiationInfoExtension) writeToUConn(uc *UConn) error { @@ -1612,6 +1615,10 @@ func (e *RenegotiationInfoExtension) writeToUConn(uc *UConn) error { fallthrough case RenegotiateFreelyAsClient: uc.HandshakeState.Hello.SecureRenegotiationSupported = true + // TODO: don't do backward propagation here + if uc.handshakes > 0 { + e.RenegotiatedConnection = uc.clientFinished[:] + } case RenegotiateNever: default: }