Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize kyber key generation #37

Merged
merged 6 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions benches/kyber768.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};

use libcrux::kem::Algorithm;
use libcrux::drbg::Drbg;
use libcrux::digest;
use libcrux::drbg::Drbg;
use libcrux::kem::Algorithm;

pub fn comparisons_key_generation(c: &mut Criterion) {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();
let mut group = c.benchmark_group("Kyber768 Key Generation");

group.bench_function("libcrux reference implementation", |b| {
b.iter(
|| {
let (_secret_key, _public_key) = libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
}
)
b.iter(|| {
let (_secret_key, _public_key) =
libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
})
});

group.bench_function("pqclean reference implementation", |b| {
Expand All @@ -30,7 +29,8 @@ pub fn comparisons_encapsulation(c: &mut Criterion) {
b.iter_batched(
|| {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();
let (_secret_key, public_key) = libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
let (_secret_key, public_key) =
libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();

(drbg, public_key)
},
Expand All @@ -50,7 +50,8 @@ pub fn comparisons_encapsulation(c: &mut Criterion) {
public_key
},
|public_key| {
let (_shared_secret, _ciphertext) = pqcrypto_kyber::kyber768::encapsulate(&public_key);
let (_shared_secret, _ciphertext) =
pqcrypto_kyber::kyber768::encapsulate(&public_key);
},
BatchSize::SmallInput,
)
Expand All @@ -64,12 +65,15 @@ pub fn comparisons_decapsulation(c: &mut Criterion) {
b.iter_batched(
|| {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();
let (secret_key, public_key) = libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
let (_shared_secret, ciphertext) = libcrux::kem::encapsulate(Algorithm::Kyber768, &public_key, &mut drbg).unwrap();
let (secret_key, public_key) =
libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
let (_shared_secret, ciphertext) =
libcrux::kem::encapsulate(Algorithm::Kyber768, &public_key, &mut drbg).unwrap();
(secret_key, ciphertext)
},
|(secret_key, ciphertext)| {
let _shared_secret = libcrux::kem::decapsulate(Algorithm::Kyber768, &ciphertext, &secret_key);
let _shared_secret =
libcrux::kem::decapsulate(Algorithm::Kyber768, &ciphertext, &secret_key);
},
BatchSize::SmallInput,
)
Expand All @@ -79,12 +83,14 @@ pub fn comparisons_decapsulation(c: &mut Criterion) {
b.iter_batched(
|| {
let (public_key, secret_key) = pqcrypto_kyber::kyber768::keypair();
let (_shared_secret, ciphertext) = pqcrypto_kyber::kyber768::encapsulate(&public_key);
let (_shared_secret, ciphertext) =
pqcrypto_kyber::kyber768::encapsulate(&public_key);

(ciphertext, secret_key)
},
|(ciphertext, secret_key)| {
let _shared_secret = pqcrypto_kyber::kyber768::decapsulate(&ciphertext, &secret_key);
let _shared_secret =
pqcrypto_kyber::kyber768::decapsulate(&ciphertext, &secret_key);
},
BatchSize::SmallInput,
)
Expand Down
11 changes: 11 additions & 0 deletions examples/kyber768_generate_keypair.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use libcrux::digest;
use libcrux::drbg::Drbg;
use libcrux::kem;

/// Repeatedly generates Kyber768 key pairs and throws them away.
///
/// A single SHA-256 based [`Drbg`] is created up front and reused as the
/// randomness source for all 100000 calls to [`kem::key_gen`].
///
/// NOTE(review): presumably intended as a fixed workload for profiling key
/// generation — confirm against the PR discussion.
///
/// # Panics
///
/// Panics if the DRBG cannot be instantiated or if any key generation
/// returns an error (both results are unwrapped).
fn main() {
    let mut rng = Drbg::new(digest::Algorithm::Sha256).unwrap();

    (0..100000).for_each(|_| {
        // The generated keys are intentionally unused; only the work matters.
        let (_secret_key, _public_key) =
            kem::key_gen(kem::Algorithm::Kyber768, &mut rng).unwrap();
    });
}
10 changes: 5 additions & 5 deletions src/drbg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,13 @@ impl Drbg {
/// Implementation of the [`RngCore`] trait for the [`Drbg`].
impl RngCore for Drbg {
fn next_u32(&mut self) -> u32 {
let mut bytes : [u8; 4] = [0; 4];
let mut bytes: [u8; 4] = [0; 4];
self.generate(&mut bytes).unwrap();

(bytes[0] as u32) |
(bytes[1] as u32) << 8 |
(bytes[2] as u32) << 16 |
(bytes[3] as u32) << 24
(bytes[0] as u32)
| (bytes[1] as u32) << 8
| (bytes[2] as u32) << 16
| (bytes[3] as u32) << 24
}

fn next_u64(&mut self) -> u64 {
Expand Down
4 changes: 2 additions & 2 deletions src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ mod kyber768;
// (and change the visibility of the exported functions to pub(crate)) the
// moment we have an implementation of one. This is tracked by:
// https://github.com/cryspen/libcrux/issues/36
pub use kyber768::generate_keypair as kyber768_generate_keypair_derand;
pub use kyber768::encapsulate as kyber768_encapsulate_derand;
pub use kyber768::decapsulate as kyber768_decapsulate_derand;
pub use kyber768::encapsulate as kyber768_encapsulate_derand;
pub use kyber768::generate_keypair as kyber768_generate_keypair_derand;

/// KEM Algorithms
///
Expand Down
40 changes: 21 additions & 19 deletions src/kem/kyber768/ind_cpa.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::kem::kyber768::utils::{
ArrayConversion, ArrayPadding, PanickingIntegerCasts, UpdatableArray, UpdatingArray, VecUpdate,
ArrayConversion, ArrayPadding, PanickingIntegerCasts, UpdatableArray, UpdatingArray,
};

use crate::kem::kyber768::{
Expand All @@ -10,14 +10,14 @@ use crate::kem::kyber768::{
},
parameters::{
hash_functions::{G, H, PRF, XOF},
KyberPolynomialRingElement, BITS_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT,
CPA_PKE_CIPHERTEXT_SIZE, CPA_PKE_KEY_GENERATION_SEED_SIZE, CPA_PKE_MESSAGE_SIZE,
CPA_PKE_PUBLIC_KEY_SIZE, CPA_PKE_SECRET_KEY_SIZE, CPA_SERIALIZED_KEY_LEN, RANK,
REJECTION_SAMPLING_SEED_SIZE, T_AS_NTT_ENCODED_SIZE, VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_SIZE, VECTOR_V_COMPRESSION_FACTOR,
KyberPolynomialRingElement, BITS_PER_RING_ELEMENT, BYTES_PER_RING_ELEMENT,
COEFFICIENTS_IN_RING_ELEMENT, CPA_PKE_CIPHERTEXT_SIZE, CPA_PKE_KEY_GENERATION_SEED_SIZE,
CPA_PKE_MESSAGE_SIZE, CPA_PKE_PUBLIC_KEY_SIZE, CPA_PKE_SECRET_KEY_SIZE,
CPA_SERIALIZED_KEY_LEN, RANK, REJECTION_SAMPLING_SEED_SIZE, T_AS_NTT_ENCODED_SIZE,
VECTOR_U_COMPRESSION_FACTOR, VECTOR_U_SIZE, VECTOR_V_COMPRESSION_FACTOR,
},
sampling::{sample_from_binomial_distribution, sample_from_uniform_distribution},
serialize::{deserialize_little_endian, serialize_little_endian},
sampling::{sample_from_binomial_distribution_with_2_coins, sample_from_uniform_distribution},
serialize::{deserialize_little_endian, serialize_little_endian, serialize_little_endian_12},
BadRejectionSamplingRandomnessError,
};

Expand Down Expand Up @@ -52,10 +52,12 @@ impl KeyPair {
}
}

fn encode_12(input: [KyberPolynomialRingElement; RANK]) -> Vec<u8> {
let mut out = Vec::new();
for re in input.into_iter() {
out.extend_from_slice(&serialize_little_endian(re, 12));
fn encode_12(input: [KyberPolynomialRingElement; RANK]) -> [u8; RANK * BYTES_PER_RING_ELEMENT] {
let mut out = [0u8; RANK * BYTES_PER_RING_ELEMENT];

for (i, re) in input.into_iter().enumerate() {
out[i * BYTES_PER_RING_ELEMENT..(i + 1) * BYTES_PER_RING_ELEMENT]
.copy_from_slice(&serialize_little_endian_12(re));
}

out
Expand Down Expand Up @@ -93,7 +95,7 @@ pub(crate) fn generate_keypair(
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let secret = sample_from_binomial_distribution(2, &prf_output[..]);
let secret = sample_from_binomial_distribution_with_2_coins(prf_output);
secret_as_ntt[i] = ntt_representation(secret);
}

Expand All @@ -109,7 +111,7 @@ pub(crate) fn generate_keypair(
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let error = sample_from_binomial_distribution(2, &prf_output[..]);
let error = sample_from_binomial_distribution_with_2_coins(prf_output);
error_as_ntt[i] = ntt_representation(error);
}

Expand All @@ -120,13 +122,13 @@ pub(crate) fn generate_keypair(
}

// pk := (Encode_12(tˆ mod^{+}q) || ρ)
let public_key_serialized = encode_12(t_as_ntt).concat(seed_for_A);
let public_key_serialized = [&encode_12(t_as_ntt), seed_for_A].concat();

// sk := Encode_12(sˆ mod^{+}q)
let secret_key_serialized = encode_12(secret_as_ntt);

Ok(KeyPair::new(
secret_key_serialized.into_array(),
secret_key_serialized,
public_key_serialized.into_array(),
))
}
Expand Down Expand Up @@ -167,7 +169,7 @@ fn cbd(mut prf_input: [u8; 33]) -> ([KyberPolynomialRingElement; RANK], u8) {
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let r = sample_from_binomial_distribution(2, &prf_output);
let r = sample_from_binomial_distribution_with_2_coins(prf_output);
r_as_ntt[i] = ntt_representation(r);
}
(r_as_ntt, domain_separator)
Expand Down Expand Up @@ -231,14 +233,14 @@ pub(crate) fn encrypt(

// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);
error_1[i] = sample_from_binomial_distribution(2, &prf_output);
error_1[i] = sample_from_binomial_distribution_with_2_coins(prf_output);
}

// e_2 := CBD{η2}(PRF(r, N))
prf_input[32] = domain_separator;
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);
let error_2 = sample_from_binomial_distribution(2, &prf_output);
let error_2 = sample_from_binomial_distribution_with_2_coins(prf_output);

// u := NTT^{-1}(AˆT ◦ rˆ) + e_1
let mut u = multiply_matrix_by_column(&A_transpose, &r_as_ntt).map(|r| invert_ntt(r));
Expand Down
51 changes: 0 additions & 51 deletions src/kem/kyber768/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
self, KyberFieldElement, KyberPolynomialRingElement, COEFFICIENTS_IN_RING_ELEMENT,
};

/// [ pow(17, br(i), p) for 0 <= i < 128 ]
/// br(i) is the bit reversal of i regarded as a 7-bit number.
const ZETAS: [u16; 128] = [
1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
Expand All @@ -23,8 +21,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
2154,
];

/// [ pow(17, 2 * br(i) + 1, p) for 0 <= i < 128 ]
/// br(i) is the bit reversal of i regarded as a 7-bit number.
const MOD_ROOTS: [u16; 128] = [
17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279,
Expand All @@ -39,26 +35,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {

const NTT_LAYERS: [usize; 7] = [2, 4, 8, 16, 32, 64, 128];

/// Use the Cooley–Tukey butterfly to compute an in-place NTT representation
franziskuskiefer marked this conversation as resolved.
Show resolved Hide resolved
/// of a `KyberPolynomialRingElement`.
///
/// This can be seen (see [CFRG draft]) as 128 applications of the linear map CT where
///
/// CT_i(a, b) => (a + zeta^i * b, a - zeta^i * b) mod q
///
/// for the appropriate i.
///
/// Because the Kyber base field has 256th roots of unity but not 512th roots
/// of unity, the resulting NTT representation is an element in:
///
/// ```plaintext
/// Product(i = 0 to 127) F_{3329}[x] / (x^2 - zeta^{2i+1}),
/// ```
///
/// This is isomorphic to `F_{3329}[x] / (x^{256} + 1)` by the
/// Chinese Remainder Theorem.
///
/// [CFRG draft]: <https://datatracker.ietf.org/doc/draft-cfrg-schwabe-kyber/>
pub fn ntt_representation(mut re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let mut zeta_i = 0;
for layer in NTT_LAYERS.iter().rev() {
Expand All @@ -76,17 +52,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
re
}

/// Use the Gentleman-Sande butterfly to invert, in-place, the NTT representation
/// of a `KyberPolynomialRingElement`. The inverse NTT can be computed (see [CFRG draft]) by
/// replacing CT_i by GS_j and
///
/// ```plaintext
/// GS_j(a, b) => ( (a + b) / 2, zeta^{2*j + 1} * (a - b) / 2 ) mod q
/// ```
///
/// for the appropriate j.
///
/// [CFRG draft]: https://datatracker.ietf.org/doc/draft-cfrg-schwabe-kyber/
pub fn invert_ntt(re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let inverse_of_2: KyberFieldElement =
KyberFieldElement::new((parameters::FIELD_MODULUS + 1) / 2);
Expand Down Expand Up @@ -114,22 +79,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
out
}

/// Two elements `a, b ∈ F_{3329}[x] / (x^2 - zeta^{2i+1})` in the Kyber NTT
/// domain:
///
/// ```plaintext
/// a = a_0 + a_1 * x
/// b = b_0 + b_1 * x
/// ```
///
/// can be multiplied as follows:
///
/// ```plaintext
/// (a_1 * x + a_0)(b_1 * x + b_0) =
/// (a_0 * b_0 + a_1 * b_1 * zeta^{2i + 1}) + (a_0 * b_1 + a_1 * b_0) * x
/// ```
///
/// for the appropriate i.
pub fn ntt_multiply(
left: &KyberPolynomialRingElement,
other: &KyberPolynomialRingElement,
Expand Down
5 changes: 4 additions & 1 deletion src/kem/kyber768/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ pub(crate) const BITS_PER_COEFFICIENT: usize = 12;
/// Coefficients per ring element
pub(crate) const COEFFICIENTS_IN_RING_ELEMENT: usize = 256;

/// Bits required per ring element
/// Bits required per (uncompressed) ring element
pub(crate) const BITS_PER_RING_ELEMENT: usize = COEFFICIENTS_IN_RING_ELEMENT * 12;

/// Bytes required per (uncompressed) ring element
pub(crate) const BYTES_PER_RING_ELEMENT: usize = BITS_PER_RING_ELEMENT / 8;

/// Seed size for rejection sampling.
///
/// See <https://eprint.iacr.org/2023/708> for some background regarding
Expand Down
39 changes: 21 additions & 18 deletions src/kem/kyber768/sampling.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::kem::kyber768::{
parameters::{self, KyberFieldElement, KyberPolynomialRingElement},
utils::bit_vector::LittleEndianBitStream,
BadRejectionSamplingRandomnessError,
};

Expand Down Expand Up @@ -37,28 +36,32 @@ pub fn sample_from_uniform_distribution(
Err(BadRejectionSamplingRandomnessError)
}

pub fn sample_from_binomial_distribution(
sampling_coins: usize,
randomness: &[u8],
pub fn sample_from_binomial_distribution_with_2_coins(
randomness: [u8; 128],
) -> KyberPolynomialRingElement {
assert_eq!(randomness.len(), sampling_coins * 64);

let mut sampled: KyberPolynomialRingElement = KyberPolynomialRingElement::ZERO;

for i in 0..sampled.len() {
let mut coin_tosses: u8 = 0;
for j in 0..sampling_coins {
coin_tosses += randomness.nth_bit(2 * i * sampling_coins + j);
}
let coin_tosses_a: KyberFieldElement = coin_tosses.into();
for (chunk_number, byte_chunk) in randomness.chunks_exact(4).enumerate() {
let random_bits_as_u32: u32 = (byte_chunk[0] as u32)
| (byte_chunk[1] as u32) << 8
| (byte_chunk[2] as u32) << 16
| (byte_chunk[3] as u32) << 24;

coin_tosses = 0;
for j in 0..sampling_coins {
coin_tosses += randomness.nth_bit(2 * i * sampling_coins + sampling_coins + j);
}
let coin_tosses_b: KyberFieldElement = coin_tosses.into();
let even_bits = random_bits_as_u32 & 0x55555555;
let odd_bits = (random_bits_as_u32 >> 1) & 0x55555555;

let coin_toss_outcomes = even_bits + odd_bits;

sampled[i] = coin_tosses_a - coin_tosses_b;
for outcome_set in (0..u32::BITS).step_by(4) {
let outcome_1: u16 = ((coin_toss_outcomes >> outcome_set) & 0x3) as u16;
let outcome_1: KyberFieldElement = outcome_1.into();

let outcome_2: u16 = ((coin_toss_outcomes >> (outcome_set + 2)) & 0x3) as u16;
let outcome_2: KyberFieldElement = outcome_2.into();

let offset = usize::try_from(outcome_set >> 2).unwrap();
sampled[8 * chunk_number + offset] = outcome_1 - outcome_2;
}
}

sampled
Expand Down
Loading