Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Refactor Hasher" #134

Merged
merged 1 commit into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions crypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,28 @@ pub enum CryptoError {
CryptoLibError,
Size,
NotImplemented,
NotInitialized,
HashError,
}

pub trait Hasher: Sized {}
pub trait Hasher: Sized {
/// Adds a chunk to the running hash.
///
/// # Arguments
///
/// * `bytes` - Value to add to hash.
fn update(&mut self, bytes: &[u8]) -> Result<(), CryptoError>;

/// Finish a running hash operation and return the result.
///
/// Once this function has been called, the object can no longer be used and
/// a new one must be created to hash more data.
fn finish(self) -> Result<Digest, CryptoError>;
}

pub type Digest = CryptoBuf;

pub trait Crypto {
type Cdi;
type HashCtx;
type Hasher: Hasher;
type PrivKey;

/// Fills the buffer with random values.
Expand All @@ -67,9 +78,9 @@ pub trait Crypto {
/// * `algs` - Which length of algorithm to use.
/// * `bytes` - Value to be hashed.
fn hash(&mut self, algs: AlgLen, bytes: &[u8]) -> Result<Digest, CryptoError> {
let mut hash_ctx = self.hash_initialize(algs)?;
self.hash_update(&mut hash_ctx, bytes)?;
self.hash_finish(&mut hash_ctx)
let mut hasher = self.hash_initialize(algs)?;
hasher.update(bytes)?;
hasher.finish()
}

/// Compute the serial number of an ECDSA public key by computing the hash
Expand All @@ -92,11 +103,11 @@ pub trait Crypto {
return Err(CryptoError::CryptoLibError);
}

let mut hash_ctx = self.hash_initialize(algs)?;
self.hash_update(&mut hash_ctx, &[0x4u8])?;
self.hash_update(&mut hash_ctx, pub_key.x.bytes())?;
self.hash_update(&mut hash_ctx, pub_key.y.bytes())?;
let digest = self.hash_finish(&mut hash_ctx)?;
let mut hasher = self.hash_initialize(algs)?;
hasher.update(&[0x4u8])?;
hasher.update(pub_key.x.bytes())?;
hasher.update(pub_key.y.bytes())?;
let digest = hasher.finish()?;

let mut w = BufWriter {
buf: serial,
Expand All @@ -112,20 +123,7 @@ pub trait Crypto {
/// # Arguments
///
/// * `algs` - Which length of algorithm to use.
fn hash_initialize(&mut self, algs: AlgLen) -> Result<Self::HashCtx, CryptoError>;

/// Adds a chunk to the running hash.
///
/// # Arguments
///
/// * `bytes` - Value to add to hash.
fn hash_update(&mut self, ctx: &mut Self::HashCtx, bytes: &[u8]) -> Result<(), CryptoError>;

/// Finish a running hash operation and return the result.
///
/// Once this function has been called, the object can no longer be used and
/// a new one must be created to hash more data.
fn hash_finish(&mut self, ctx: &mut Self::HashCtx) -> Result<Digest, CryptoError>;
fn hash_initialize(&mut self, algs: AlgLen) -> Result<Self::Hasher, CryptoError>;

/// Derive a CDI based on the current base CDI and measurements
///
Expand Down
49 changes: 23 additions & 26 deletions crypto/src/openssl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Licensed under the Apache-2.0 license

use crate::{AlgLen, Crypto, CryptoBuf, CryptoError, Digest, EcdsaPub, HmacSig};
use crate::{AlgLen, Crypto, CryptoBuf, CryptoError, Digest, EcdsaPub, Hasher, HmacSig};
use hkdf::Hkdf;
use openssl::{
bn::{BigNum, BigNumContext},
Expand All @@ -14,6 +14,23 @@ use openssl::{
};
use sha2::{Sha256, Sha384};

pub struct OpensslHasher(openssl::hash::Hasher, AlgLen);

impl Hasher for OpensslHasher {
fn update(&mut self, bytes: &[u8]) -> Result<(), CryptoError> {
self.0
.update(bytes)
.map_err(|_| CryptoError::CryptoLibError)
}

fn finish(mut self) -> Result<Digest, CryptoError> {
Digest::new(
&self.0.finish().map_err(|_| CryptoError::CryptoLibError)?,
self.1,
)
}
}

pub struct OpensslCrypto;

impl OpensslCrypto {
Expand Down Expand Up @@ -57,14 +74,9 @@ type OpensslCdi = Vec<u8>;

type OpensslPrivKey = CryptoBuf;

pub struct OpensslHasher {
hasher: openssl::hash::Hasher,
algs: AlgLen,
}

impl Crypto for OpensslCrypto {
type Cdi = OpensslCdi;
type HashCtx = OpensslHasher;
type Hasher = OpensslHasher;
type PrivKey = OpensslPrivKey;

#[cfg(feature = "deterministic_rand")]
Expand All @@ -80,27 +92,12 @@ impl Crypto for OpensslCrypto {
openssl::rand::rand_bytes(dst).map_err(|_| CryptoError::CryptoLibError)
}

fn hash_initialize(&mut self, algs: AlgLen) -> Result<Self::HashCtx, CryptoError> {
fn hash_initialize(&mut self, algs: AlgLen) -> Result<Self::Hasher, CryptoError> {
let md = Self::get_digest(algs);
Ok(OpensslHasher {
hasher: openssl::hash::Hasher::new(md).map_err(|_| CryptoError::CryptoLibError)?,
Ok(OpensslHasher(
openssl::hash::Hasher::new(md).map_err(|_| CryptoError::CryptoLibError)?,
algs,
})
}

fn hash_update(&mut self, ctx: &mut Self::HashCtx, bytes: &[u8]) -> Result<(), CryptoError> {
ctx.hasher
.update(bytes)
.map_err(|_| CryptoError::CryptoLibError)
}

fn hash_finish(&mut self, ctx: &mut Self::HashCtx) -> Result<Digest, CryptoError> {
Digest::new(
&ctx.hasher
.finish()
.map_err(|_| CryptoError::CryptoLibError)?,
ctx.algs,
)
))
}

fn derive_cdi(
Expand Down
93 changes: 35 additions & 58 deletions dpe/src/dpe_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
DPE_PROFILE, INTERNAL_INPUT_INFO_SIZE, MAX_HANDLES,
};
use core::mem::size_of;
use crypto::{Crypto, Digest};
use crypto::{Crypto, Digest, Hasher};
use platform::{Platform, MAX_CHUNK_SIZE};
use zerocopy::AsBytes;

Expand Down Expand Up @@ -263,20 +263,17 @@ impl DpeInstance<'_> {
}

// Derive the new TCI as HASH(TCI_CUMULATIVE || INPUT_DATA).
let mut hash_ctx = env
let mut hasher = env
.crypto()
.hash_initialize(DPE_PROFILE.alg_len())
.map_err(|_| DpeErrorCode::HashError)?;
env.crypto()
.hash_update(&mut hash_ctx, &context.tci.tci_cumulative.0)
hasher
.update(&context.tci.tci_cumulative.0)
.map_err(|_| DpeErrorCode::HashError)?;
env.crypto()
.hash_update(&mut hash_ctx, &measurement.0)
.map_err(|_| DpeErrorCode::HashError)?;
let digest = env
.crypto()
.hash_finish(&mut hash_ctx)
hasher
.update(&measurement.0)
.map_err(|_| DpeErrorCode::HashError)?;
let digest = hasher.finish().map_err(|_| DpeErrorCode::HashError)?;

context.tci.tci_cumulative.0.copy_from_slice(digest.bytes());
context.tci.tci_current = *measurement;
Expand Down Expand Up @@ -316,7 +313,7 @@ impl DpeInstance<'_> {
env: &mut impl DpeEnv,
start_idx: usize,
) -> Result<Digest, DpeErrorCode> {
let mut hash_ctx = env
let mut hasher = env
.crypto()
.hash_initialize(DPE_PROFILE.alg_len())
.map_err(|_| DpeErrorCode::HashError)?;
Expand All @@ -330,8 +327,8 @@ impl DpeInstance<'_> {

let mut tci_bytes = [0u8; size_of::<TciNodeData>()];
let len = context.tci.serialize(&mut tci_bytes)?;
env.crypto()
.hash_update(&mut hash_ctx, &tci_bytes[..len])
hasher
.update(&tci_bytes[..len])
.map_err(|_| DpeErrorCode::HashError)?;

// Check if any context uses internal inputs
Expand All @@ -343,11 +340,8 @@ impl DpeInstance<'_> {
if uses_internal_input_info {
let mut internal_input_info = [0u8; INTERNAL_INPUT_INFO_SIZE];
self.serialize_internal_input_info(env, &mut internal_input_info)?;
env.crypto()
.hash_update(
&mut hash_ctx,
&internal_input_info[..INTERNAL_INPUT_INFO_SIZE],
)
hasher
.update(&internal_input_info[..INTERNAL_INPUT_INFO_SIZE])
.map_err(|_| DpeErrorCode::HashError)?;
}

Expand All @@ -359,16 +353,14 @@ impl DpeInstance<'_> {
env.platform()
.get_certificate_chain(offset, MAX_CHUNK_SIZE as u32, &mut cert_chunk)
{
env.crypto()
.hash_update(&mut hash_ctx, &cert_chunk[..len as usize])
hasher
.update(&cert_chunk[..len as usize])
.map_err(|_| DpeErrorCode::HashError)?;
offset += len;
}
}

env.crypto()
.hash_finish(&mut hash_ctx)
.map_err(|_| DpeErrorCode::HashError)
hasher.finish().map_err(|_| DpeErrorCode::HashError)
}
}

Expand Down Expand Up @@ -560,12 +552,10 @@ pub mod tests {
assert_eq!(data, context.tci.tci_current.0);

// Compute cumulative.
let mut hash_ctx = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
env.crypto()
.hash_update(&mut hash_ctx, &[0; DPE_PROFILE.get_hash_size()])
.unwrap();
env.crypto().hash_update(&mut hash_ctx, &data).unwrap();
let first_cumulative = env.crypto().hash_finish(&mut hash_ctx).unwrap();
let mut hasher = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
hasher.update(&[0; DPE_PROFILE.get_hash_size()]).unwrap();
hasher.update(&data).unwrap();
let first_cumulative = hasher.finish().unwrap();

// Make sure the cumulative was computed correctly.
assert_eq!(first_cumulative.bytes(), context.tci.tci_cumulative.0);
Expand All @@ -577,12 +567,10 @@ pub mod tests {
let context = &dpe.contexts[0];
assert_eq!(data, context.tci.tci_current.0);

let mut hash_ctx = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
env.crypto()
.hash_update(&mut hash_ctx, first_cumulative.bytes())
.unwrap();
env.crypto().hash_update(&mut hash_ctx, &data).unwrap();
let second_cumulative = env.crypto().hash_finish(&mut hash_ctx).unwrap();
let mut hasher = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
hasher.update(first_cumulative.bytes()).unwrap();
hasher.update(&data).unwrap();
let second_cumulative = hasher.finish().unwrap();

// Make sure the cumulative was computed correctly.
assert_eq!(second_cumulative.bytes(), context.tci.tci_cumulative.0);
Expand Down Expand Up @@ -680,19 +668,17 @@ pub mod tests {
last_cdi = curr_cdi;
}

let mut hash_ctx = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
let mut hasher = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
let leaf_idx = dpe
.get_active_context_pos(&ContextHandle::default(), TEST_LOCALITIES[0])
.unwrap();

for result in ChildToRootIter::new(leaf_idx, &dpe.contexts) {
let context = result.unwrap();
env.crypto()
.hash_update(&mut hash_ctx, context.tci.as_bytes())
.unwrap();
hasher.update(context.tci.as_bytes()).unwrap();
}

let digest = env.crypto().hash_finish(&mut hash_ctx).unwrap();
let digest = hasher.finish().unwrap();
let answer = env
.crypto()
.derive_cdi(DPE_PROFILE.alg_len(), &digest, b"DPE")
Expand Down Expand Up @@ -738,23 +724,18 @@ pub mod tests {
let context = &dpe.contexts[parent_context_idx];
assert!(context.uses_internal_input_info);

let mut hash_ctx = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
let mut hasher = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();

env.crypto()
.hash_update(&mut hash_ctx, context.tci.as_bytes())
.unwrap();
hasher.update(context.tci.as_bytes()).unwrap();
let mut internal_input_info = [0u8; INTERNAL_INPUT_INFO_SIZE];
dpe.serialize_internal_input_info(&mut env, &mut internal_input_info)
.unwrap();

env.crypto()
.hash_update(
&mut hash_ctx,
&internal_input_info[..INTERNAL_INPUT_INFO_SIZE],
)
hasher
.update(&internal_input_info[..INTERNAL_INPUT_INFO_SIZE])
.unwrap();

let digest = env.crypto().hash_finish(&mut hash_ctx).unwrap();
let digest = hasher.finish().unwrap();
let answer = env
.crypto()
.derive_cdi(DPE_PROFILE.alg_len(), &digest, b"DPE")
Expand Down Expand Up @@ -800,16 +781,12 @@ pub mod tests {
let context = &dpe.contexts[parent_context_idx];
assert!(context.uses_internal_input_dice);

let mut hash_ctx = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();
let mut hasher = env.crypto().hash_initialize(DPE_PROFILE.alg_len()).unwrap();

env.crypto()
.hash_update(&mut hash_ctx, context.tci.as_bytes())
.unwrap();
env.crypto()
.hash_update(&mut hash_ctx, &TEST_CERT_CHAIN[..MAX_CHUNK_SIZE])
.unwrap();
hasher.update(context.tci.as_bytes()).unwrap();
hasher.update(&TEST_CERT_CHAIN[..MAX_CHUNK_SIZE]).unwrap();

let digest = env.crypto().hash_finish(&mut hash_ctx).unwrap();
let digest = hasher.finish().unwrap();
let answer = env
.crypto()
.derive_cdi(DPE_PROFILE.alg_len(), &digest, b"DPE")
Expand Down