Skip to content

Commit

Permalink
Validate improper TLS RSA certificates early-on
Browse files Browse the repository at this point in the history
This commit makes validation of TLS certificates with either too big RSA keys, or the wrong exponent, fail as soon as the remote node presents its TLS certificate, in contrast to after the TLS handshake.

Signed-off-by: Yacov Manevich <[email protected]>
  • Loading branch information
yacovm committed Oct 24, 2024
1 parent dcb2f70 commit 7d08cc6
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 12 deletions.
25 changes: 25 additions & 0 deletions network/peer/tls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
package peer

import (
"crypto/rsa"
"crypto/tls"
"errors"
"io"

"github.com/ava-labs/avalanchego/staking"
)

// TLSConfig returns the TLS config that will allow secure connections to other
Expand All @@ -26,5 +30,26 @@ func TLSConfig(cert tls.Certificate, keyLogWriter io.Writer) *tls.Config {
InsecureSkipVerify: true, //#nosec G402
MinVersion: tls.VersionTLS13,
KeyLogWriter: keyLogWriter,
VerifyConnection: ValidateRSACertificate,
}
}

// ValidateRSACertificate validates TLS certificates
// with RSA public keys in the leaf of the certificate chain of the given connection state.
func ValidateRSACertificate(cs tls.ConnectionState) error {
if len(cs.PeerCertificates) == 0 {
return errors.New("no certificates sent by peer")
}

pk := cs.PeerCertificates[0].PublicKey
if pk == nil {
return errors.New("no public key sent by peer")
}

switch rsaKey := pk.(type) {
case *rsa.PublicKey:
return staking.ValidateRSAPublicKeyIsWellFormed(rsaKey)
default:
return nil
}
}
68 changes: 68 additions & 0 deletions network/peer/tls_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package peer_test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/network/peer"
)

func TestValidateRSACertificate(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

x509Cert := makeRSACertAndKey(t, key)

x509CertWithNoPK := makeRSACertAndKey(t, key)
x509CertWithNoPK.cert.PublicKey = nil

ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

basicCert := basicCert()
certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &ecKey.PublicKey, ecKey)
require.NoError(t, err)

ecCert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)

for _, testCase := range []struct {
description string
input tls.ConnectionState
expectedErr error
}{
{
description: "Valid TLS cert",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509Cert.cert}},
},
{
description: "No TLS certs given",
input: tls.ConnectionState{},
expectedErr: errors.New("no certificates sent by peer"),
},
{
description: "No public key in the cert given",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509CertWithNoPK.cert}},
expectedErr: errors.New("no public key sent by peer"),
},
{
description: "EC cert",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}},
},
} {
t.Run(testCase.description, func(t *testing.T) {
require.Equal(t, testCase.expectedErr, peer.ValidateRSACertificate(testCase.input))
})
}
}
216 changes: 216 additions & 0 deletions network/peer/upgrader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package peer_test

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"net"
"sync"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/network/peer"
"github.com/ava-labs/avalanchego/staking"
)

func TestBlockClientsWithIncorrectRSAKeys(t *testing.T) {
privKey2048, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

privKey4096, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)

privKey8192, err := rsa.GenerateKey(rand.Reader, 8192)
require.NoError(t, err)

clientCert2048 := makeTLSCert(t, privKey2048)
clientCert4096 := makeTLSCert(t, privKey4096)
clientCert8192 := makeTLSCert(t, privKey8192)
clientCertBad := makeTLSCert(t, nonStandardRSAKey(t))

for _, testCase := range []struct {
description string
clientTLSCert tls.Certificate
shouldSucceed bool
expectedErr error
}{
{
description: "Proper key size and private key - 2048",
clientTLSCert: clientCert2048,
shouldSucceed: true,
},
{
description: "Proper key size and private key - 4096",
clientTLSCert: clientCert4096,
shouldSucceed: true,
},
{
description: "Too big key",
clientTLSCert: clientCert8192,
expectedErr: staking.ErrUnsupportedRSAModulusBitLen,
},
{
description: "Improper public exponent",
clientTLSCert: clientCertBad,
expectedErr: staking.ErrUnsupportedRSAPublicExponent,
},
} {
t.Run(testCase.description, func(t *testing.T) {
t.Parallel()
serverKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

serverCert := makeTLSCert(t, serverKey)

config := peer.TLSConfig(serverCert, nil)

c := prometheus.NewCounter(prometheus.CounterOpts{})

// Initialize upgrader with a mock that fails when it's incremented.
failOnIncrementCounter := &mockPrometheusCounter{
Counter: c,
t: t,
onIncrement: func() {
require.FailNow(t, "should not have invoked")
},
}
upgrader := peer.NewTLSServerUpgrader(config, failOnIncrementCounter)

clientConfig := tls.Config{
ClientAuth: tls.RequireAnyClientCert,
InsecureSkipVerify: true, //#nosec G402
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{testCase.clientTLSCert},
}

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()
conn, err := listener.Accept()
require.NoError(t, err)

_, _, _, err = upgrader.Upgrade(conn)

if testCase.shouldSucceed {
require.NoError(t, err)
} else {
require.ErrorIs(t, err, testCase.expectedErr)
}
}()

conn, err := tls.Dial("tcp", listener.Addr().String(), &clientConfig)
require.NoError(t, err)

require.NoError(t, conn.Handshake())

wg.Wait()
})
}
}

func nonStandardRSAKey(t *testing.T) *rsa.PrivateKey {
for {
sk, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

sk.Precomputed = rsa.PrecomputedValues{}

// We want a non-standard E, so let's use E = 257 and derive D again.
e := 257
sk.PublicKey.E = e
sk.E = e

p := sk.Primes[0]
q := sk.Primes[1]

pminus1 := new(big.Int).Sub(p, big.NewInt(1))
qminus1 := new(big.Int).Sub(q, big.NewInt(1))

phiN := big.NewInt(0).Mul(pminus1, qminus1)

sk.D = big.NewInt(0).ModInverse(big.NewInt(int64(e)), phiN)

if sk.D == nil {
// If we ended up picking a bad starting modulus, try again.
continue
}

fmt.Println("E", e)
fmt.Println("D", sk.D)
fmt.Println("N", sk.N)
fmt.Println("Primes[0]", sk.Primes[0])
fmt.Println("Primes[1]", sk.Primes[1])

return sk
}
}

func makeTLSCert(t *testing.T, privKey *rsa.PrivateKey) tls.Certificate {
x509Cert := makeRSACertAndKey(t, privKey)

rawX509PEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: x509Cert.cert.Raw})
privateKeyInDER, err := x509.MarshalPKCS8PrivateKey(x509Cert.key)
require.NoError(t, err)

privateKeyInPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyInDER})

tlsCertServer, err := tls.X509KeyPair(rawX509PEM, privateKeyInPEM)
require.NoError(t, err)

return tlsCertServer
}

type certAndKey struct {
cert x509.Certificate
key *rsa.PrivateKey
}

func makeRSACertAndKey(t *testing.T, privKey *rsa.PrivateKey) certAndKey {
// Create a self-signed cert
basicCert := basicCert()
certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &privKey.PublicKey, privKey)
require.NoError(t, err)

cert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)

return certAndKey{
cert: *cert,
key: privKey,
}
}

func basicCert() *x509.Certificate {
return &x509.Certificate{
SerialNumber: big.NewInt(0).SetInt64(100),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour).UTC(),
BasicConstraintsValid: true,
}
}

type mockPrometheusCounter struct {
t *testing.T
prometheus.Counter
onIncrement func()
}

func (m *mockPrometheusCounter) Inc() {
m.onIncrement()
}
30 changes: 18 additions & 12 deletions staking/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,8 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto
return nil, ErrInvalidRSAPublicExponent
}

if pub.N.Sign() <= 0 {
return nil, ErrRSAModulusNotPositive
}

if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen)
}
if pub.N.Bit(0) == 0 {
return nil, ErrRSAModulusIsEven
}
if pub.E != allowedRSAPublicExponentValue {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E)
if err := ValidateRSAPublicKeyIsWellFormed(pub); err != nil {
return nil, err
}
return pub, nil
case oid.Equal(oidPublicKeyECDSA):
Expand All @@ -165,3 +155,19 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto
return nil, ErrUnknownPublicKeyAlgorithm
}
}

func ValidateRSAPublicKeyIsWellFormed(pub *rsa.PublicKey) error {
if pub.N.Sign() <= 0 {
return ErrRSAModulusNotPositive
}
if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen {
return fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen)
}
if pub.N.Bit(0) == 0 {
return ErrRSAModulusIsEven
}
if pub.E != allowedRSAPublicExponentValue {
return fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E)
}
return nil
}

0 comments on commit 7d08cc6

Please sign in to comment.