Skip to content

Commit

Permalink
fix: Properly initialize RNG context (#33)
Browse files Browse the repository at this point in the history
Fixes an issue where using some ciphers would cause Mbedtls to fail with `MBEDTLS_ERR_SSL_NO_RNG`
  • Loading branch information
AnthonyGrondin authored Jul 16, 2024
1 parent 5fb0ca7 commit c58cad2
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions esp-mbedtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ impl<'a> Certificates<'a> {
min_version: TlsVersion,
) -> Result<
(
*mut mbedtls_ctr_drbg_context,
*mut mbedtls_ssl_context,
*mut mbedtls_ssl_config,
*mut mbedtls_x509_crt,
Expand All @@ -258,9 +259,16 @@ impl<'a> Certificates<'a> {
unsafe {
error_checked!(psa_crypto_init())?;

let drbg_context = calloc(1, size_of::<mbedtls_ctr_drbg_context>() as u32)
as *mut mbedtls_ctr_drbg_context;
if drbg_context.is_null() {
return Err(TlsError::OutOfMemory);
}

let ssl_context =
calloc(1, size_of::<mbedtls_ssl_context>() as u32) as *mut mbedtls_ssl_context;
if ssl_context.is_null() {
free(drbg_context as *const _);
return Err(TlsError::OutOfMemory);
}

Expand All @@ -273,6 +281,7 @@ impl<'a> Certificates<'a> {

let crt = calloc(1, size_of::<mbedtls_x509_crt>() as u32) as *mut mbedtls_x509_crt;
if crt.is_null() {
free(drbg_context as *const _);
free(ssl_context as *const _);
free(ssl_config as *const _);
return Err(TlsError::OutOfMemory);
Expand All @@ -281,6 +290,7 @@ impl<'a> Certificates<'a> {
let certificate =
calloc(1, size_of::<mbedtls_x509_crt>() as u32) as *mut mbedtls_x509_crt;
if certificate.is_null() {
free(drbg_context as *const _);
free(ssl_context as *const _);
free(ssl_config as *const _);
free(crt as *const _);
Expand All @@ -290,6 +300,7 @@ impl<'a> Certificates<'a> {
let private_key =
calloc(1, size_of::<mbedtls_pk_context>() as u32) as *mut mbedtls_pk_context;
if private_key.is_null() {
free(drbg_context as *const _);
free(ssl_context as *const _);
free(ssl_config as *const _);
free(crt as *const _);
Expand All @@ -306,7 +317,9 @@ impl<'a> Certificates<'a> {
// Initialize private key
mbedtls_pk_init(private_key);
(*ssl_config).private_f_dbg = Some(dbg_print);
(*ssl_config).private_f_rng = Some(rng);
// Init RNG
mbedtls_ctr_drbg_init(drbg_context);
mbedtls_ssl_conf_rng(ssl_config, Some(rng), drbg_context as *mut c_void);

error_checked!(mbedtls_ssl_config_defaults(
ssl_config,
Expand Down Expand Up @@ -379,13 +392,21 @@ impl<'a> Certificates<'a> {

mbedtls_ssl_conf_ca_chain(ssl_config, crt, core::ptr::null_mut());
error_checked!(mbedtls_ssl_setup(ssl_context, ssl_config))?;
Ok((ssl_context, ssl_config, crt, certificate, private_key))
Ok((
drbg_context,
ssl_context,
ssl_config,
crt,
certificate,
private_key,
))
}
}
}

pub struct Session<T> {
stream: T,
drbg_context: *mut mbedtls_ctr_drbg_context,
ssl_context: *mut mbedtls_ssl_context,
ssl_config: *mut mbedtls_ssl_config,
crt: *mut mbedtls_x509_crt,
Expand Down Expand Up @@ -420,10 +441,11 @@ impl<T> Session<T> {
min_version: TlsVersion,
certificates: Certificates,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
return Ok(Self {
stream,
drbg_context,
ssl_context,
ssl_config,
crt,
Expand Down Expand Up @@ -561,11 +583,13 @@ impl<T> Drop for Session<T> {
RSA_REF = core::mem::transmute(None::<RSA>);
}
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ctr_drbg_free(self.drbg_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
mbedtls_x509_crt_free(self.crt);
mbedtls_x509_crt_free(self.client_crt);
mbedtls_pk_free(self.private_key);
free(self.drbg_context as *const _);
free(self.ssl_config as *const _);
free(self.ssl_context as *const _);
free(self.crt as *const _);
Expand Down Expand Up @@ -627,6 +651,7 @@ pub mod asynch {

pub struct Session<T, const BUFFER_SIZE: usize = 4096> {
stream: T,
drbg_context: *mut mbedtls_ctr_drbg_context,
ssl_context: *mut mbedtls_ssl_context,
ssl_config: *mut mbedtls_ssl_config,
crt: *mut mbedtls_x509_crt,
Expand Down Expand Up @@ -663,10 +688,11 @@ pub mod asynch {
min_version: TlsVersion,
certificates: Certificates,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
return Ok(Self {
stream,
drbg_context,
ssl_context,
ssl_config,
crt,
Expand Down Expand Up @@ -704,11 +730,13 @@ pub mod asynch {
RSA_REF = core::mem::transmute(None::<RSA>);
}
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ctr_drbg_free(self.drbg_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
mbedtls_x509_crt_free(self.crt);
mbedtls_x509_crt_free(self.client_crt);
mbedtls_pk_free(self.private_key);
free(self.drbg_context as *const _);
free(self.ssl_config as *const _);
free(self.ssl_context as *const _);
free(self.crt as *const _);
Expand Down

0 comments on commit c58cad2

Please sign in to comment.