Skip to content

Commit

Permalink
Refactoring Kyber reference implementation. (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf authored Sep 29, 2023
1 parent fdc4a1a commit 2d41bdc
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 44 deletions.
19 changes: 11 additions & 8 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber768.Compress.fst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ let compress
=
Core.Array.map_under_impl_23 re
.Libcrux.Kem.Kyber768.Arithmetic.KyberPolynomialRingElement.f_coefficients
(fun coefficient -> compress_q coefficient bits_per_compressed_coefficient <: i32)
(fun coefficient ->
compress_q (Libcrux.Kem.Kyber768.Conversions.to_unsigned_representative coefficient
<:
u16)
bits_per_compressed_coefficient
<:
i32)
}
in
re
Expand All @@ -34,25 +40,22 @@ let decompress
in
re

let compress_q (fe: i32) (to_bit_size: usize) : i32 =
let compress_q (fe: u16) (to_bit_size: usize) : i32 =
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.(to_bit_size <=. Libcrux.Kem.Kyber768.Parameters.v_BITS_PER_COEFFICIENT <: bool)
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: to_bit_size <= parameters::BITS_PER_COEFFICIENT"
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: to_bit_size <= BITS_PER_COEFFICIENT"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let two_pow_bit_size:u32 = 1ul >>. to_bit_size in
let fe_unsigned:i32 =
fe +. ((fe <<. 15l <: i32) &. Libcrux.Kem.Kyber768.Parameters.v_FIELD_MODULUS <: i32)
in
let compressed:u32 = cast fe_unsigned *. (two_pow_bit_size >>. 1l <: u32) in
let compressed:u32 = cast fe *. (two_pow_bit_size >>. 1l <: u32) in
let compressed:Prims.unit = compressed +. cast Libcrux.Kem.Kyber768.Parameters.v_FIELD_MODULUS in
let compressed:Prims.unit =
compressed /. cast (Libcrux.Kem.Kyber768.Parameters.v_FIELD_MODULUS >>. 1l <: i32)
Expand All @@ -66,7 +69,7 @@ let decompress_q (fe: i32) (to_bit_size: usize) : i32 =
let _:Prims.unit =
if ~.(to_bit_size <=. Libcrux.Kem.Kyber768.Parameters.v_BITS_PER_COEFFICIENT <: bool)
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: to_bit_size <= parameters::BITS_PER_COEFFICIENT"
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: to_bit_size <= BITS_PER_COEFFICIENT"

<:
Rust_primitives.Hax.t_Never)
Expand Down
5 changes: 4 additions & 1 deletion proofs/fstar/extraction/Libcrux.Kem.Kyber768.Conversions.fst
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,7 @@ let impl (#len: usize) : t_UpdatingArray (t_UpdatableArray v_LEN) =
}
in
self
}
}

let to_unsigned_representative (fe: i32) : u16 =
cast (fe +. ((fe <<. 15l <: i32) &. Libcrux.Kem.Kyber768.Parameters.v_FIELD_MODULUS <: i32))
22 changes: 8 additions & 14 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber768.Serialize.fst
Original file line number Diff line number Diff line change
Expand Up @@ -353,34 +353,28 @@ let serialize_little_endian_12_ (re: Libcrux.Kem.Kyber768.Arithmetic.t_KyberPoly
_)
serialized
(fun serialized (i, chunks) ->
let coefficient1:i32 = chunks.[ 0sz ] in
let coefficient1:Prims.unit =
coefficient1 +.
((coefficient1 <<. 15l <: i32) &. Libcrux.Kem.Kyber768.Parameters.v_FIELD_MODULUS <: i32
)
let coefficient1:u16 =
Libcrux.Kem.Kyber768.Conversions.to_unsigned_representative (chunks.[ 0sz ] <: i32)
in
let coefficient2:i32 = chunks.[ 1sz ] in
let coefficient2:Prims.unit =
coefficient2 +.
((coefficient2 <<. 15l <: i32) &. Libcrux.Kem.Kyber768.Parameters.v_FIELD_MODULUS <: i32
)
let coefficient2:u16 =
Libcrux.Kem.Kyber768.Conversions.to_unsigned_representative (chunks.[ 1sz ] <: i32)
in
let serialized:array u8 384sz =
Rust_primitives.Hax.update_at serialized
(3sz *. i <: usize)
(cast (coefficient1 &. 255l <: i32))
(cast (coefficient1 &. 255us <: u16))
in
let serialized:array u8 384sz =
Rust_primitives.Hax.update_at serialized
((3sz *. i <: usize) +. 1sz <: usize)
(cast ((coefficient1 <<. 8l <: i32) |. ((coefficient2 &. 15l <: i32) >>. 4l <: i32)
(cast ((coefficient1 <<. 8l <: u16) |. ((coefficient2 &. 15us <: u16) >>. 4l <: u16)
<:
i32))
u16))
in
let serialized:array u8 384sz =
Rust_primitives.Hax.update_at serialized
((3sz *. i <: usize) +. 2sz <: usize)
(cast ((coefficient2 <<. 4l <: i32) &. 255l <: i32))
(cast ((coefficient2 <<. 4l <: u16) &. 255us <: u16))
in
serialized)
in
Expand Down
30 changes: 15 additions & 15 deletions src/kem/kyber768/compress.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use crate::kem::kyber768::{
arithmetic::{KyberFieldElement, KyberPolynomialRingElement},
parameters::{self, FIELD_MODULUS},
conversions::to_unsigned_representative,
parameters::{BITS_PER_COEFFICIENT, FIELD_MODULUS},
};

pub fn compress(
mut re: KyberPolynomialRingElement,
bits_per_compressed_coefficient: usize,
) -> KyberPolynomialRingElement {
re.coefficients = re
.coefficients
.map(|coefficient| compress_q(coefficient, bits_per_compressed_coefficient));
re.coefficients = re.coefficients.map(|coefficient| {
compress_q(
to_unsigned_representative(coefficient),
bits_per_compressed_coefficient,
)
});
re
}

Expand All @@ -23,26 +27,22 @@ pub fn decompress(
re
}

fn compress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement {
debug_assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);
fn compress_q(fe: u16, to_bit_size: usize) -> KyberFieldElement {
debug_assert!(to_bit_size <= BITS_PER_COEFFICIENT);

let two_pow_bit_size = 1u32 << to_bit_size;

// Convert from canonical signed representative to canonical unsigned
// representative.
let fe_unsigned = fe + ((fe >> 15) & FIELD_MODULUS);

let mut compressed = (fe_unsigned as u32) * (two_pow_bit_size << 1);
compressed += parameters::FIELD_MODULUS as u32;
compressed /= (parameters::FIELD_MODULUS << 1) as u32;
let mut compressed = (fe as u32) * (two_pow_bit_size << 1);
compressed += FIELD_MODULUS as u32;
compressed /= (FIELD_MODULUS << 1) as u32;

(compressed & (two_pow_bit_size - 1)) as KyberFieldElement
}

fn decompress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement {
debug_assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);
debug_assert!(to_bit_size <= BITS_PER_COEFFICIENT);

let mut decompressed = (fe as u32) * (parameters::FIELD_MODULUS as u32);
let mut decompressed = (fe as u32) * (FIELD_MODULUS as u32);
decompressed = (decompressed << 1) + (1 << to_bit_size);
decompressed >>= to_bit_size + 1;

Expand Down
7 changes: 7 additions & 0 deletions src/kem/kyber768/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::kem::kyber768::{arithmetic::KyberFieldElement, parameters::FIELD_MODULUS};

#[inline(always)]
pub(super) fn into_padded_array<const LEN: usize>(slice: &[u8]) -> [u8; LEN] {
debug_assert!(slice.len() <= LEN);
Expand Down Expand Up @@ -35,3 +37,8 @@ impl<const LEN: usize> UpdatingArray for UpdatableArray<LEN> {
self
}
}

#[inline(always)]
pub(crate) fn to_unsigned_representative(fe: KyberFieldElement) -> u16 {
(fe + ((fe >> 15) & FIELD_MODULUS)) as u16
}
10 changes: 4 additions & 6 deletions src/kem/kyber768/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::kem::kyber768::{
arithmetic::{KyberFieldElement, KyberPolynomialRingElement},
parameters::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS},
conversions::to_unsigned_representative,
parameters::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT},
};

/// This file contains instantiations of the functions
Expand Down Expand Up @@ -141,11 +142,8 @@ pub fn serialize_little_endian_12(re: KyberPolynomialRingElement) -> [u8; BYTES_
let mut serialized = [0u8; BYTES_PER_RING_ELEMENT];

for (i, chunks) in re.coefficients.chunks_exact(2).enumerate() {
let mut coefficient1 = chunks[0];
coefficient1 += (coefficient1 >> 15) & FIELD_MODULUS;

let mut coefficient2 = chunks[1];
coefficient2 += (coefficient2 >> 15) & FIELD_MODULUS;
let coefficient1 = to_unsigned_representative(chunks[0]);
let coefficient2 = to_unsigned_representative(chunks[1]);

serialized[3 * i] = (coefficient1 & 0xFF) as u8;
serialized[3 * i + 1] = ((coefficient1 >> 8) | ((coefficient2 & 0xF) << 4)) as u8;
Expand Down

0 comments on commit 2d41bdc

Please sign in to comment.