Skip to content

Commit

Permalink
Validate improper TLS RSA certificates early-on
Browse files Browse the repository at this point in the history
This commit makes validation of TLS certificates with either too big RSA keys, or the wrong exponent, fail as soon as the remote node presents its TLS certificate, in contrast to after the TLS handshake.

Signed-off-by: Yacov Manevich <[email protected]>
  • Loading branch information
yacovm committed Oct 24, 2024
1 parent dcb2f70 commit 60e973f
Show file tree
Hide file tree
Showing 5 changed files with 398 additions and 12 deletions.
29 changes: 29 additions & 0 deletions network/peer/tls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
package peer

import (
"crypto/rsa"
"crypto/tls"
"errors"
"io"

"github.com/ava-labs/avalanchego/staking"
)

// TLSConfig returns the TLS config that will allow secure connections to other
Expand All @@ -26,5 +30,30 @@ func TLSConfig(cert tls.Certificate, keyLogWriter io.Writer) *tls.Config {
InsecureSkipVerify: true, //#nosec G402
MinVersion: tls.VersionTLS13,
KeyLogWriter: keyLogWriter,
VerifyConnection: ValidateRSACertificate,
}
}

// ValidateRSACertificate validates TLS certificates
// with RSA public keys in the leaf of the certificate chain of the given connection state.
func ValidateRSACertificate(cs tls.ConnectionState) error {
if len(cs.PeerCertificates) == 0 {
return errors.New("no certificates sent by peer")
}

if cs.PeerCertificates[0] == nil {
return errors.New("certificate sent by peer is empty")
}

pk := cs.PeerCertificates[0].PublicKey
if pk == nil {
return errors.New("no public key sent by peer")
}

switch rsaKey := pk.(type) {
case *rsa.PublicKey:
return staking.ValidateRSAPublicKeyIsWellFormed(rsaKey)
default:
return nil
}
}
73 changes: 73 additions & 0 deletions network/peer/tls_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package peer_test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/network/peer"
)

func TestValidateRSACertificate(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

x509Cert := makeRSACertAndKey(t, key)

x509CertWithNoPK := makeRSACertAndKey(t, key)
x509CertWithNoPK.cert.PublicKey = nil

ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

basicCert := basicCert()
certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &ecKey.PublicKey, ecKey)
require.NoError(t, err)

ecCert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)

for _, testCase := range []struct {
description string
input tls.ConnectionState
expectedErr error
}{
{
description: "Valid TLS cert",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509Cert.cert}},
},
{
description: "No TLS certs given",
input: tls.ConnectionState{},
expectedErr: errors.New("no certificates sent by peer"),
},
{
description: "No TLS certs given",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{nil}},
expectedErr: errors.New("certificate sent by peer is empty"),
},
{
description: "No public key in the cert given",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509CertWithNoPK.cert}},
expectedErr: errors.New("no public key sent by peer"),
},
{
description: "EC cert",
input: tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}},
},
} {
t.Run(testCase.description, func(t *testing.T) {
require.Equal(t, testCase.expectedErr, peer.ValidateRSACertificate(testCase.input))
})
}
}
212 changes: 212 additions & 0 deletions network/peer/upgrader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package peer_test

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"net"
"sync"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/network/peer"
"github.com/ava-labs/avalanchego/staking"
)

func TestBlockClientsWithIncorrectRSAKeys(t *testing.T) {
privKey2048, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

privKey4096, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)

privKey8192, err := rsa.GenerateKey(rand.Reader, 8192)
require.NoError(t, err)

clientCert2048 := makeTLSCert(t, privKey2048)
clientCert4096 := makeTLSCert(t, privKey4096)
clientCert8192 := makeTLSCert(t, privKey8192)
clientCertBad := makeTLSCert(t, nonStandardRSAKey(t))

for _, testCase := range []struct {
description string
clientTLSCert tls.Certificate
shouldSucceed bool
expectedErr error
}{
{
description: "Proper key size and private key - 2048",
clientTLSCert: clientCert2048,
shouldSucceed: true,
},
{
description: "Proper key size and private key - 4096",
clientTLSCert: clientCert4096,
shouldSucceed: true,
},
{
description: "Too big key",
clientTLSCert: clientCert8192,
expectedErr: staking.ErrUnsupportedRSAModulusBitLen,
},
{
description: "Improper public exponent",
clientTLSCert: clientCertBad,
expectedErr: staking.ErrUnsupportedRSAPublicExponent,
},
} {
t.Run(testCase.description, func(t *testing.T) {
serverKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

serverCert := makeTLSCert(t, serverKey)

config := peer.TLSConfig(serverCert, nil)

c := prometheus.NewCounter(prometheus.CounterOpts{})

// Initialize upgrader with a mock that fails when it's incremented.
failOnIncrementCounter := &mockPrometheusCounter{
Counter: c,
t: t,
onIncrement: func() {
require.FailNow(t, "should not have invoked")
},
}
upgrader := peer.NewTLSServerUpgrader(config, failOnIncrementCounter)

clientConfig := tls.Config{
ClientAuth: tls.RequireAnyClientCert,
InsecureSkipVerify: true, //#nosec G402
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{testCase.clientTLSCert},
}

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()
conn, err := listener.Accept()
require.NoError(t, err)

_, _, _, err = upgrader.Upgrade(conn)

if testCase.shouldSucceed {
require.NoError(t, err)
} else {
require.ErrorIs(t, err, testCase.expectedErr)
}
}()

conn, err := tls.Dial("tcp", listener.Addr().String(), &clientConfig)
require.NoError(t, err)

require.NoError(t, conn.Handshake())

wg.Wait()
})
}
}

func nonStandardRSAKey(t *testing.T) *rsa.PrivateKey {
for {
sk, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

// This speed up RSA operations, and was initialized during the key-gen.
// If we wish to override the key parameters we need to nullify this,
// otherwise the signer will use these values and the verifier will use
// the values we override, and verification will fail.
sk.Precomputed = rsa.PrecomputedValues{}

// We want a non-standard E, so let's use E = 257 and derive D again.
e := 257
sk.PublicKey.E = e
sk.E = e

p := sk.Primes[0]
q := sk.Primes[1]

pminus1 := new(big.Int).Sub(p, big.NewInt(1))
qminus1 := new(big.Int).Sub(q, big.NewInt(1))

phiN := big.NewInt(0).Mul(pminus1, qminus1)

sk.D = big.NewInt(0).ModInverse(big.NewInt(int64(e)), phiN)

if sk.D == nil {
// If we ended up picking a bad starting modulus, try again.
continue
}

return sk
}
}

func makeTLSCert(t *testing.T, privKey *rsa.PrivateKey) tls.Certificate {
x509Cert := makeRSACertAndKey(t, privKey)

rawX509PEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: x509Cert.cert.Raw})
privateKeyInDER, err := x509.MarshalPKCS8PrivateKey(x509Cert.key)
require.NoError(t, err)

privateKeyInPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyInDER})

tlsCertServer, err := tls.X509KeyPair(rawX509PEM, privateKeyInPEM)
require.NoError(t, err)

return tlsCertServer
}

type certAndKey struct {
cert x509.Certificate
key *rsa.PrivateKey
}

func makeRSACertAndKey(t *testing.T, privKey *rsa.PrivateKey) certAndKey {
// Create a self-signed cert
basicCert := basicCert()
certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &privKey.PublicKey, privKey)
require.NoError(t, err)

cert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)

return certAndKey{
cert: *cert,
key: privKey,
}
}

func basicCert() *x509.Certificate {
return &x509.Certificate{
SerialNumber: big.NewInt(0).SetInt64(100),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour).UTC(),
BasicConstraintsValid: true,
}
}

type mockPrometheusCounter struct {
t *testing.T
prometheus.Counter
onIncrement func()
}

func (m *mockPrometheusCounter) Inc() {
m.onIncrement()
}
31 changes: 19 additions & 12 deletions staking/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,8 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto
return nil, ErrInvalidRSAPublicExponent
}

if pub.N.Sign() <= 0 {
return nil, ErrRSAModulusNotPositive
}

if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen)
}
if pub.N.Bit(0) == 0 {
return nil, ErrRSAModulusIsEven
}
if pub.E != allowedRSAPublicExponentValue {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E)
if err := ValidateRSAPublicKeyIsWellFormed(pub); err != nil {
return nil, err
}
return pub, nil
case oid.Equal(oidPublicKeyECDSA):
Expand All @@ -165,3 +155,20 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto
return nil, ErrUnknownPublicKeyAlgorithm
}
}

// ValidateRSAPublicKeyIsWellFormed validates the given RSA public key
func ValidateRSAPublicKeyIsWellFormed(pub *rsa.PublicKey) error {
if pub.N.Sign() <= 0 {
return ErrRSAModulusNotPositive
}
if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen {
return fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen)
}
if pub.N.Bit(0) == 0 {
return ErrRSAModulusIsEven
}
if pub.E != allowedRSAPublicExponentValue {
return fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E)
}
return nil
}
Loading

0 comments on commit 60e973f

Please sign in to comment.