Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reordered functions in Kyber reference so definitions come before calls. #113

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions src/kem/kyber/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ use super::{
conversions::to_unsigned_representative,
};

fn compress_q<const COEFFICIENT_BITS: usize>(fe: u16) -> KyberFieldElement {
debug_assert!(COEFFICIENT_BITS <= BITS_PER_COEFFICIENT);

let mut compressed = (fe as u32) << (COEFFICIENT_BITS + 1);
compressed += FIELD_MODULUS as u32;
compressed /= (FIELD_MODULUS << 1) as u32;

(compressed & ((1u32 << COEFFICIENT_BITS) - 1)) as KyberFieldElement
}
pub fn compress<const COEFFICIENT_BITS: usize>(
mut re: KyberPolynomialRingElement,
) -> KyberPolynomialRingElement {
Expand All @@ -13,27 +22,6 @@ pub fn compress<const COEFFICIENT_BITS: usize>(
re
}

pub fn decompress<const COEFFICIENT_BITS: usize>(
mut re: KyberPolynomialRingElement,
) -> KyberPolynomialRingElement {
re.coefficients = re
.coefficients
.map(|coefficient| decompress_q::<COEFFICIENT_BITS>(coefficient));
re
}

fn compress_q<const COEFFICIENT_BITS: usize>(fe: u16) -> KyberFieldElement {
debug_assert!(COEFFICIENT_BITS <= BITS_PER_COEFFICIENT);

let two_pow_bit_size = 1u32 << COEFFICIENT_BITS;

let mut compressed = (fe as u32) * (two_pow_bit_size << 1);
compressed += FIELD_MODULUS as u32;
compressed /= (FIELD_MODULUS << 1) as u32;

(compressed & (two_pow_bit_size - 1)) as KyberFieldElement
}

fn decompress_q<const COEFFICIENT_BITS: usize>(fe: KyberFieldElement) -> KyberFieldElement {
debug_assert!(COEFFICIENT_BITS <= BITS_PER_COEFFICIENT);

Expand All @@ -43,3 +31,11 @@ fn decompress_q<const COEFFICIENT_BITS: usize>(fe: KyberFieldElement) -> KyberFi

decompressed as KyberFieldElement
}
pub fn decompress<const COEFFICIENT_BITS: usize>(
mut re: KyberPolynomialRingElement,
) -> KyberPolynomialRingElement {
re.coefficients = re
.coefficients
.map(|coefficient| decompress_q::<COEFFICIENT_BITS>(coefficient));
re
}
25 changes: 12 additions & 13 deletions src/kem/kyber/ind_cpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ use super::{

// The PKE Private Key
impl_generic_struct!(PrivateKey);
pub fn serialize_secret_key<const SERIALIZED_KEY_LEN: usize>(
private_key: &[u8],
public_key: &[u8],
implicit_rejection_value: &[u8],
) -> [u8; SERIALIZED_KEY_LEN] {
UpdatableArray::new([0u8; SERIALIZED_KEY_LEN])
.push(private_key)
.push(public_key)
.push(&H(public_key))
.push(implicit_rejection_value)
.array()
}

#[inline(always)]
#[allow(non_snake_case)]
Expand Down Expand Up @@ -168,19 +180,6 @@ pub(crate) fn generate_keypair<
)
}

pub fn serialize_secret_key<const SERIALIZED_KEY_LEN: usize>(
private_key: &[u8],
public_key: &[u8],
implicit_rejection_value: &[u8],
) -> [u8; SERIALIZED_KEY_LEN] {
UpdatableArray::new([0u8; SERIALIZED_KEY_LEN])
.push(private_key)
.push(public_key)
.push(&H(public_key))
.push(implicit_rejection_value)
.array()
}

fn compress_then_encode_u<
const K: usize,
const OUT_LEN: usize,
Expand Down
26 changes: 13 additions & 13 deletions src/kem/kyber/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,6 @@ pub fn sample_from_uniform_distribution<const SEED_SIZE: usize>(
(out, Some(BadRejectionSamplingRandomnessError))
}

#[inline(always)]
pub(super) fn sample_from_binomial_distribution<const ETA: usize>(
randomness: &[u8],
) -> KyberPolynomialRingElement {
debug_assert_eq!(randomness.len(), ETA * 64);

match ETA as u32 {
2 => sample_from_binomial_distribution_2(randomness),
3 => sample_from_binomial_distribution_3(randomness),
_ => unreachable!("factor {ETA}"),
}
}

/// Given a series of uniformly random bytes in `|randomness|`, sample
/// a ring element from a binomial distribution centered at 0 that uses two sets
/// of `|sampling_coins|` coin flips. If, for example,
Expand Down Expand Up @@ -133,3 +120,16 @@ fn sample_from_binomial_distribution_3(randomness: &[u8]) -> KyberPolynomialRing

sampled
}

#[inline(always)]
pub(super) fn sample_from_binomial_distribution<const ETA: usize>(
randomness: &[u8],
) -> KyberPolynomialRingElement {
debug_assert_eq!(randomness.len(), ETA * 64);

match ETA as u32 {
2 => sample_from_binomial_distribution_2(randomness),
3 => sample_from_binomial_distribution_3(randomness),
_ => unreachable!("factor {ETA}"),
}
}
100 changes: 50 additions & 50 deletions src/kem/kyber/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,56 +40,6 @@ use super::{
/// function itself; the rest don't since they are called only after `compress_q`
/// is called, and `compress_q` also performs this conversion.

#[inline(always)]
pub(super) fn serialize_little_endian<const COMPRESSION_FACTOR: usize, const OUT_LEN: usize>(
re: KyberPolynomialRingElement,
) -> [u8; OUT_LEN] {
debug_assert!(
(COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN,
"{} != {}",
(COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8,
OUT_LEN
);

match COMPRESSION_FACTOR as u32 {
1 => serialize_little_endian_1(re),
// VECTOR_V_COMPRESSION_FACTOR_768 & VECTOR_V_COMPRESSION_FACTOR_512
4 => serialize_little_endian_4(re),
// VECTOR_V_COMPRESSION_FACTOR_1024
5 => serialize_little_endian_5(re),
// VECTOR_U_COMPRESSION_FACTOR_768 & VECTOR_U_COMPRESSION_FACTOR_512
10 => serialize_little_endian_10(re),
// VECTOR_U_COMPRESSION_FACTOR_1024
11 => serialize_little_endian_11(re),
12 => serialize_little_endian_12(re),
_ => unreachable!("factor {COMPRESSION_FACTOR}"),
}
}

#[inline(always)]
pub(super) fn deserialize_little_endian<const COMPRESSION_FACTOR: usize>(
serialized: &[u8],
) -> KyberPolynomialRingElement {
debug_assert_eq!(
serialized.len(),
(COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8
);

match COMPRESSION_FACTOR as u32 {
1 => deserialize_little_endian_1(serialized),
// VECTOR_V_COMPRESSION_FACTOR_768 & VECTOR_V_COMPRESSION_FACTOR_512
4 => deserialize_little_endian_4(serialized),
// VECTOR_V_COMPRESSION_FACTOR_1024
5 => deserialize_little_endian_5(serialized),
// VECTOR_U_COMPRESSION_FACTOR_768 & VECTOR_U_COMPRESSION_FACTOR_512
10 => deserialize_little_endian_10(serialized),
// VECTOR_U_COMPRESSION_FACTOR_1024
11 => deserialize_little_endian_11(serialized),
12 => deserialize_little_endian_12(serialized),
_ => unreachable!("factor {COMPRESSION_FACTOR}"),
}
}

#[inline(always)]
fn serialize_little_endian_1<const OUT_LEN: usize>(
re: KyberPolynomialRingElement,
Expand Down Expand Up @@ -350,3 +300,53 @@ fn deserialize_little_endian_12(serialized: &[u8]) -> KyberPolynomialRingElement

re
}

#[inline(always)]
pub(super) fn serialize_little_endian<const COMPRESSION_FACTOR: usize, const OUT_LEN: usize>(
re: KyberPolynomialRingElement,
) -> [u8; OUT_LEN] {
debug_assert!(
(COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN,
"{} != {}",
(COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8,
OUT_LEN
);

match COMPRESSION_FACTOR as u32 {
1 => serialize_little_endian_1(re),
// VECTOR_V_COMPRESSION_FACTOR_768 & VECTOR_V_COMPRESSION_FACTOR_512
4 => serialize_little_endian_4(re),
// VECTOR_V_COMPRESSION_FACTOR_1024
5 => serialize_little_endian_5(re),
// VECTOR_U_COMPRESSION_FACTOR_768 & VECTOR_U_COMPRESSION_FACTOR_512
10 => serialize_little_endian_10(re),
// VECTOR_U_COMPRESSION_FACTOR_1024
11 => serialize_little_endian_11(re),
12 => serialize_little_endian_12(re),
_ => unreachable!("factor {COMPRESSION_FACTOR}"),
}
}

#[inline(always)]
pub(super) fn deserialize_little_endian<const COMPRESSION_FACTOR: usize>(
serialized: &[u8],
) -> KyberPolynomialRingElement {
debug_assert_eq!(
serialized.len(),
(COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8
);

match COMPRESSION_FACTOR as u32 {
1 => deserialize_little_endian_1(serialized),
// VECTOR_V_COMPRESSION_FACTOR_768 & VECTOR_V_COMPRESSION_FACTOR_512
4 => deserialize_little_endian_4(serialized),
// VECTOR_V_COMPRESSION_FACTOR_1024
5 => deserialize_little_endian_5(serialized),
// VECTOR_U_COMPRESSION_FACTOR_768 & VECTOR_U_COMPRESSION_FACTOR_512
10 => deserialize_little_endian_10(serialized),
// VECTOR_U_COMPRESSION_FACTOR_1024
11 => deserialize_little_endian_11(serialized),
12 => deserialize_little_endian_12(serialized),
_ => unreachable!("factor {COMPRESSION_FACTOR}"),
}
}