Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More kyber code refactoring #135

Merged
merged 24 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
df46ee4
Break out matrix multiplication code in ntt.rs.
xvzcf Nov 20, 2023
da0f6cc
KyberPolynomialRingElement -> PolynomialRingElement.
xvzcf Nov 20, 2023
0e4d182
KyberFieldElement -> FieldElement.
xvzcf Nov 20, 2023
81369f8
Added some types related to montgomery reduction
xvzcf Nov 20, 2023
0f79ee7
FieldElement -> StandardFieldElement.
xvzcf Nov 20, 2023
c9bd0db
More annotating with type aliases.
xvzcf Nov 21, 2023
710f698
cbd -> sample_vector_cbd_then_ntt.
xvzcf Nov 21, 2023
1da97ad
Make message compression and decompression constant time.
xvzcf Nov 22, 2023
24be1ba
Refine preconditions in compress.
xvzcf Nov 22, 2023
2c3b985
Convert (invert_)ntt_at_layer macros to functions.
xvzcf Nov 22, 2023
9731506
Review comments.
xvzcf Nov 23, 2023
96787ce
Revert "Convert (invert_)ntt_at_layer macros to functions."
xvzcf Nov 23, 2023
7579051
Re-extract Kyber fstar code with latest karthik/core-for-kyber hax co…
xvzcf Nov 23, 2023
19044b8
Fix fstar lax-typechecking.
xvzcf Nov 24, 2023
0637fca
Add comments to ntt.rs.
xvzcf Nov 24, 2023
aaba346
Add lax typechecking to CI.
xvzcf Nov 24, 2023
c23ba2e
3329 -> FIELD_MODULUS in ntt.rs and update paths in hax.yml
xvzcf Nov 24, 2023
ce756a6
Add cfg guard in ntt.rs
xvzcf Nov 24, 2023
180e67c
Fixed silly mistake in hax.yml
xvzcf Nov 24, 2023
063f95f
Get Fstar binaries in hax.yml.
xvzcf Nov 24, 2023
f50ef7e
Remove spurious slash in hax.yml
xvzcf Nov 24, 2023
48c942d
Debug hax ci
xvzcf Nov 24, 2023
7bbb472
Debug hax ci
xvzcf Nov 24, 2023
be406dd
CI should work now?
xvzcf Nov 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ open Core
open FStar.Mul

unfold
let t_FieldElementTimesMontgomeryR = i32
let t_FieldElement = i32

unfold
let t_MontgomeryFieldElement = i32
let t_FieldElementTimesMontgomeryR = i32

unfold
let t_StandardFieldElement = i32
let t_MontgomeryFieldElement = i32

let v_BARRETT_MULTIPLIER: i64 = 20159L

Expand All @@ -20,6 +20,8 @@ let v_BARRETT_R: i64 = 1L <<! v_BARRETT_SHIFT

let v_INVERSE_OF_MODULUS_MOD_R: u32 = 62209ul

let v_MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS: i32 = 1353l

let v_MONTGOMERY_SHIFT: u8 = 16uy

let v_MONTGOMERY_R: i32 = 1l <<! v_MONTGOMERY_SHIFT
Expand Down Expand Up @@ -88,7 +90,8 @@ let montgomery_reduce (value: i32)

let montgomery_multiply_sfe_by_fer (fe fer: i32) : i32 = montgomery_reduce (fe *! fer <: i32)

let to_standard_domain (mfe: i32) : i32 = montgomery_reduce (mfe *! 1353l <: i32)
let to_standard_domain (mfe: i32) : i32 =
montgomery_reduce (mfe *! v_MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS <: i32)

let to_unsigned_representative (fe: i32)
: Prims.Pure u16
Expand Down Expand Up @@ -134,7 +137,7 @@ let add_to_ring_element (v_K: usize) (lhs rhs: t_PolynomialRingElement) : t_Poly
lhs with
f_coefficients
=
Rust_primitives.Hax.update_at lhs.f_coefficients
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize lhs.f_coefficients
i
((lhs.f_coefficients.[ i ] <: i32) +! (rhs.f_coefficients.[ i ] <: i32) <: i32)
<:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ let select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8)
(fun out i ->
let out:t_Array u8 (sz 32) = out in
let i:usize = i in
Rust_primitives.Hax.update_at out
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
i
((out.[ i ] <: u8) |.
(((lhs.[ i ] <: u8) &. mask <: u8) |. ((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8)
Expand Down
4 changes: 2 additions & 2 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Conversions.fst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ let into_padded_array (v_LEN: usize) (slice: t_Slice u8) : t_Array u8 v_LEN =
in
let out:t_Array u8 v_LEN = Rust_primitives.Hax.repeat 0uy v_LEN in
let out:t_Array u8 v_LEN =
Rust_primitives.Hax.update_at out
Rust_primitives.Hax.Monomorphized_update_at.update_at_range out
({ Core.Ops.Range.f_start = sz 0; Core.Ops.Range.f_end = Core.Slice.impl__len slice <: usize }
<:
Core.Ops.Range.t_Range usize)
Expand Down Expand Up @@ -60,7 +60,7 @@ let impl_1 (v_LEN: usize) : t_UpdatingArray (t_UpdatableArray v_LEN) =
self with
f_value
=
Rust_primitives.Hax.update_at self.f_value
Rust_primitives.Hax.Monomorphized_update_at.update_at_range self.f_value
({
Core.Ops.Range.f_start = self.f_pointer;
Core.Ops.Range.f_end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ let v_XOFx4 (v_LEN v_K: usize) (input: t_Array (t_Array u8 (sz 34)) v_K)
(fun out i ->
let out:t_Array (t_Array u8 v_LEN) v_K = out in
let i:usize = i in
Rust_primitives.Hax.update_at out
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
i
(Libcrux.Digest.shake128 v_LEN
(Rust_primitives.unsize (input.[ i ] <: t_Array u8 (sz 34)) <: t_Slice u8)
Expand Down
48 changes: 29 additions & 19 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Ind_cpa.fst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ let sample_vector_cbd_then_ntt
in
let i:usize = i in
let prf_input:t_Array u8 (sz 33) =
Rust_primitives.Hax.update_at prf_input (sz 32) domain_separator
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize prf_input
(sz 32)
domain_separator
in
let domain_separator:u8 = domain_separator +! 1uy in
let (prf_output: t_Array u8 v_ETA_RANDOMNESS_SIZE):t_Array u8 v_ETA_RANDOMNESS_SIZE =
Expand All @@ -73,7 +75,7 @@ let sample_vector_cbd_then_ntt
(Rust_primitives.unsize prf_output <: t_Slice u8)
in
let re_as_ntt:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at re_as_ntt
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re_as_ntt
i
(Libcrux.Kem.Kyber.Ntt.ntt_binomially_sampled_ring_element r
<:
Expand Down Expand Up @@ -138,18 +140,22 @@ let sample_matrix_A (v_K: usize) (seed: t_Array u8 (sz 34)) (transpose: bool)
let seeds:t_Array (t_Array u8 (sz 34)) v_K = seeds in
let j:usize = j in
let seeds:t_Array (t_Array u8 (sz 34)) v_K =
Rust_primitives.Hax.update_at seeds
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize seeds
j
(Rust_primitives.Hax.update_at (seeds.[ j ] <: t_Array u8 (sz 34))
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (seeds.[ j ]
<:
t_Array u8 (sz 34))
(sz 32)
(cast (i <: usize) <: u8)
<:
t_Array u8 (sz 34))
in
let seeds:t_Array (t_Array u8 (sz 34)) v_K =
Rust_primitives.Hax.update_at seeds
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize seeds
j
(Rust_primitives.Hax.update_at (seeds.[ j ] <: t_Array u8 (sz 34))
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (seeds.[ j ]
<:
t_Array u8 (sz 34))
(sz 33)
(cast (j <: usize) <: u8)
<:
Expand Down Expand Up @@ -197,9 +203,10 @@ let sample_matrix_A (v_K: usize) (seed: t_Array u8 (sz 34)) (transpose: bool)
then
let v_A_transpose:t_Array
(t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K) v_K =
Rust_primitives.Hax.update_at v_A_transpose
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize v_A_transpose
j
(Rust_primitives.Hax.update_at (v_A_transpose.[ j ]
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (v_A_transpose.[ j
]
<:
t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K)
i
Expand All @@ -214,9 +221,10 @@ let sample_matrix_A (v_K: usize) (seed: t_Array u8 (sz 34)) (transpose: bool)
else
let v_A_transpose:t_Array
(t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K) v_K =
Rust_primitives.Hax.update_at v_A_transpose
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize v_A_transpose
i
(Rust_primitives.Hax.update_at (v_A_transpose.[ i ]
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (v_A_transpose.[ i
]
<:
t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K)
j
Expand Down Expand Up @@ -254,7 +262,7 @@ let compress_then_encode_u
(fun out temp_1_ ->
let out:t_Array u8 v_OUT_LEN = out in
let i, re:(usize & Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) = temp_1_ in
Rust_primitives.Hax.update_at out
Rust_primitives.Hax.Monomorphized_update_at.update_at_range out
({
Core.Ops.Range.f_start = i *! (v_OUT_LEN /! v_K <: usize) <: usize;
Core.Ops.Range.f_end = (i +! sz 1 <: usize) *! (v_OUT_LEN /! v_K <: usize) <: usize
Expand Down Expand Up @@ -306,7 +314,7 @@ let serialize_key
(fun out temp_1_ ->
let out:t_Array u8 v_OUT_LEN = out in
let i, re:(usize & Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) = temp_1_ in
Rust_primitives.Hax.update_at out
Rust_primitives.Hax.Monomorphized_update_at.update_at_range out
({
Core.Ops.Range.f_start
=
Expand Down Expand Up @@ -391,7 +399,7 @@ let decrypt
u_bytes
in
let u_as_ntt:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at u_as_ntt
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize u_as_ntt
i
(Libcrux.Kem.Kyber.Ntt.ntt_vector_u v_U_COMPRESSION_FACTOR u
<:
Expand Down Expand Up @@ -425,7 +433,7 @@ let decrypt
secret_as_ntt
in
let i, secret_bytes:(usize & t_Slice u8) = temp_1_ in
Rust_primitives.Hax.update_at secret_as_ntt
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize secret_as_ntt
i
(Libcrux.Kem.Kyber.Serialize.deserialize_to_uncompressed_ring_element secret_bytes
<:
Expand Down Expand Up @@ -471,7 +479,7 @@ let encrypt
tt_as_ntt
in
let i, tt_as_ntt_bytes:(usize & t_Slice u8) = temp_1_ in
Rust_primitives.Hax.update_at tt_as_ntt
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize tt_as_ntt
i
(Libcrux.Kem.Kyber.Serialize.deserialize_to_uncompressed_ring_element tt_as_ntt_bytes
<:
Expand Down Expand Up @@ -524,15 +532,17 @@ let encrypt
in
let i:usize = i in
let prf_input:t_Array u8 (sz 33) =
Rust_primitives.Hax.update_at prf_input (sz 32) domain_separator
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize prf_input
(sz 32)
domain_separator
in
let domain_separator:u8 = domain_separator +! 1uy in
let (prf_output: t_Array u8 v_ETA2_RANDOMNESS_SIZE):t_Array u8 v_ETA2_RANDOMNESS_SIZE =
Libcrux.Kem.Kyber.Hash_functions.v_PRF v_ETA2_RANDOMNESS_SIZE
(Rust_primitives.unsize prf_input <: t_Slice u8)
in
let error_1_:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at error_1_
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize error_1_
i
(Libcrux.Kem.Kyber.Sampling.sample_from_binomial_distribution v_ETA2
(Rust_primitives.unsize prf_output <: t_Slice u8)
Expand All @@ -545,7 +555,7 @@ let encrypt
t_Array u8 (sz 33)))
in
let prf_input:t_Array u8 (sz 33) =
Rust_primitives.Hax.update_at prf_input (sz 32) domain_separator
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize prf_input (sz 32) domain_separator
in
let (prf_output: t_Array u8 v_ETA2_RANDOMNESS_SIZE):t_Array u8 v_ETA2_RANDOMNESS_SIZE =
Libcrux.Kem.Kyber.Hash_functions.v_PRF v_ETA2_RANDOMNESS_SIZE
Expand Down Expand Up @@ -581,7 +591,7 @@ let encrypt
(Rust_primitives.unsize c1 <: t_Slice u8)
in
let ciphertext:t_Array u8 v_CIPHERTEXT_SIZE =
Rust_primitives.Hax.update_at ciphertext
Rust_primitives.Hax.Monomorphized_update_at.update_at_range_from ciphertext
({ Core.Ops.Range.f_start = v_C1_LEN } <: Core.Ops.Range.t_RangeFrom usize)
(Core.Slice.impl__copy_from_slice (ciphertext.[ { Core.Ops.Range.f_start = v_C1_LEN }
<:
Expand Down
56 changes: 15 additions & 41 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Matrix.fst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ let compute_As_plus_e
(s_as_ntt.[ j ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
in
let result:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at result
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
i
(Libcrux.Kem.Kyber.Arithmetic.add_to_ring_element v_K
(result.[ i ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
Expand All @@ -73,14 +73,7 @@ let compute_As_plus_e
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end
=
Core.Slice.impl__len (Rust_primitives.unsize (result.[ i ]
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
<:
t_Slice i32)
<:
usize
Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
}
<:
Core.Ops.Range.t_Range usize)
Expand All @@ -100,13 +93,13 @@ let compute_As_plus_e
<:
i32)
in
Rust_primitives.Hax.update_at result
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
i
({
(result.[ i ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.update_at (result.[ i ]
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (result.[ i ]
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
Expand Down Expand Up @@ -168,14 +161,7 @@ let compute_message
let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end
=
Core.Slice.impl__len (Rust_primitives.unsize result
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
<:
t_Slice i32)
<:
usize
Core.Ops.Range.f_end = Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
}
<:
Core.Ops.Range.t_Range usize)
Expand All @@ -199,7 +185,8 @@ let compute_message
result with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.update_at result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
i
(Libcrux.Kem.Kyber.Arithmetic.barrett_reduce ((v
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ]
Expand Down Expand Up @@ -256,14 +243,7 @@ let compute_ring_element_v
let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end
=
Core.Slice.impl__len (Rust_primitives.unsize result
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
<:
t_Slice i32)
<:
usize
Core.Ops.Range.f_end = Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
}
<:
Core.Ops.Range.t_Range usize)
Expand All @@ -287,7 +267,8 @@ let compute_ring_element_v
result with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.update_at result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
i
(Libcrux.Kem.Kyber.Arithmetic.barrett_reduce ((coefficient_normal_form +!
(error_2_.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] <: i32)
Expand Down Expand Up @@ -361,7 +342,7 @@ let compute_vector_u
(r_as_ntt.[ j ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
in
let result:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at result
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
i
(Libcrux.Kem.Kyber.Arithmetic.add_to_ring_element v_K
(result.[ i ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
Expand All @@ -372,7 +353,7 @@ let compute_vector_u
result)
in
let result:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at result
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
i
(Libcrux.Kem.Kyber.Ntt.invert_ntt_montgomery v_K
(result.[ i ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
Expand All @@ -383,14 +364,7 @@ let compute_vector_u
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end
=
Core.Slice.impl__len (Rust_primitives.unsize (result.[ i ]
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
<:
t_Slice i32)
<:
usize
Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
}
<:
Core.Ops.Range.t_Range usize)
Expand All @@ -414,13 +388,13 @@ let compute_vector_u
i32)
in
let result:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.update_at result
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize result
i
({
(result.[ i ] <: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.update_at (result.[ i ]
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (result.[ i ]
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
Expand Down
Loading