Skip to content

Commit

Permalink
Added some more preconditions using hax::implies and hax::forall and …
Browse files Browse the repository at this point in the history
…slight refactoring. (#138)

* Added a whole bunch of pre- and post-conditions using hax_lib::forall and hax_lib::implies.
* Don't use map() at all.
* Fix hax.yml.

---------

Co-authored-by: Franziskus Kiefer <[email protected]>
  • Loading branch information
xvzcf and franziskuskiefer authored Dec 1, 2023
1 parent 599b4f4 commit e40b911
Show file tree
Hide file tree
Showing 12 changed files with 445 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/hax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
./hax-driver.py --kyber-reference
env FSTAR_HOME=${{ github.workspace }}/fstar \
HACL_HOME=${{ github.workspace }}/hacl-star \
HAX_LIBS_HOME=${{ github.workspace }}/hax/proof-libs/fstar \
HAX_HOME=${{ github.workspace }}/hax \
PATH="${PATH}:${{ github.workspace }}/fstar/bin" \
./hax-driver.py typecheck --admit
Expand Down
56 changes: 49 additions & 7 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ let v_MONTGOMERY_SHIFT: u8 = 16uy

let v_MONTGOMERY_R: i32 = 1l <<! v_MONTGOMERY_SHIFT

let get_montgomery_r_least_significant_bits (value: u32)
let get_n_least_significant_bits (n: u8) (value: u32)
: Prims.Pure u32
Prims.l_True
(requires n =. 4uy || n =. 5uy || n =. 10uy || n =. 11uy || n =. v_MONTGOMERY_SHIFT)
(ensures
fun result ->
let result:u32 = result in
result <. (Core.Num.impl__u32__pow 2ul (cast (v_MONTGOMERY_SHIFT <: u8) <: u32) <: u32)) =
value &. ((1ul <<! v_MONTGOMERY_SHIFT <: u32) -! 1ul <: u32)
result <. (Core.Num.impl__u32__pow 2ul (Core.Convert.f_into n <: u32) <: u32)) =
let _:Prims.unit = () <: Prims.unit in
value &. ((1ul <<! n <: u32) -! 1ul <: u32)

let barrett_reduce (value: i32)
: Prims.Pure i32
Expand Down Expand Up @@ -77,10 +78,10 @@ let montgomery_reduce (value: i32)
let _:i32 = v_MONTGOMERY_R in
let _:Prims.unit = () <: Prims.unit in
let t:u32 =
(get_montgomery_r_least_significant_bits (cast (value <: i32) <: u32) <: u32) *!
(get_n_least_significant_bits v_MONTGOMERY_SHIFT (cast (value <: i32) <: u32) <: u32) *!
v_INVERSE_OF_MODULUS_MOD_R
in
let k:i16 = cast (get_montgomery_r_least_significant_bits t <: u32) <: i16 in
let k:i16 = cast (get_n_least_significant_bits v_MONTGOMERY_SHIFT t <: u32) <: i16 in
let k_times_modulus:i32 =
(cast (k <: i16) <: i32) *! Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
in
Expand Down Expand Up @@ -113,7 +114,48 @@ type t_PolynomialRingElement = { f_coefficients:t_Array i32 (sz 256) }
let impl__PolynomialRingElement__ZERO: t_PolynomialRingElement =
{ f_coefficients = Rust_primitives.Hax.repeat 0l (sz 256) } <: t_PolynomialRingElement

let add_to_ring_element (v_K: usize) (lhs rhs: t_PolynomialRingElement) : t_PolynomialRingElement =
let add_to_ring_element (v_K: usize) (lhs rhs: t_PolynomialRingElement)
: Prims.Pure t_PolynomialRingElement
(requires
Hax_lib.v_forall (fun i ->
let i:usize = i in
Hax_lib.implies (i <. Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
<:
bool)
(((Core.Num.impl__i32__abs (lhs.f_coefficients.[ i ] <: i32) <: i32) <=.
(((cast (v_K <: usize) <: i32) -! 1l <: i32) *!
Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
i32)
<:
bool) &&
((Core.Num.impl__i32__abs (rhs.f_coefficients.[ i ] <: i32) <: i32) <=.
Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
bool))
<:
bool))
(ensures
fun result ->
let result:t_PolynomialRingElement = result in
Hax_lib.v_forall (fun i ->
let i:usize = i in
Hax_lib.implies (i <.
(Core.Slice.impl__len (Rust_primitives.unsize result.f_coefficients
<:
t_Slice i32)
<:
usize)
<:
bool)
((Core.Num.impl__i32__abs (result.f_coefficients.[ i ] <: i32) <: i32) <=.
((cast (v_K <: usize) <: i32) *! Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
i32)
<:
bool)
<:
bool)) =
let _:Prims.unit = () <: Prims.unit in
let _:Prims.unit = () <: Prims.unit in
let lhs:t_PolynomialRingElement =
Expand Down
26 changes: 13 additions & 13 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,19 @@ open FStar.Mul
let compress_message_coefficient (fe: u16)
: Prims.Pure u8
(requires fe <. (cast (Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <: i32) <: u16))
(fun _ -> Prims.l_True) =
(ensures
fun result ->
let result:u8 = result in
Hax_lib.implies ((833us <=. fe <: bool) && (fe <=. 2596us <: bool))
(result =. 1uy <: bool) &&
Hax_lib.implies (~.((833us <=. fe <: bool) && (fe <=. 2596us <: bool)) <: bool)
(result =. 0uy <: bool)) =
let (shifted: i16):i16 = 1664s -! (cast (fe <: u16) <: i16) in
let shifted_to_positive:i16 = (shifted >>! 15l <: i16) ^. shifted in
let mask:i16 = shifted >>! 15l in
let shifted_to_positive:i16 = mask ^. shifted in
let shifted_positive_in_range:i16 = shifted_to_positive -! 832s in
cast ((shifted_positive_in_range >>! 15l <: i16) &. 1s <: i16) <: u8

let get_n_least_significant_bits (n: u8) (value: u32)
: Prims.Pure u32
(requires n =. 4uy || n =. 5uy || n =. 10uy || n =. 11uy)
(ensures
fun result ->
let result:u32 = result in
result <. (Core.Num.impl__u32__pow 2ul (Core.Convert.f_into n <: u32) <: u32)) =
let _:Prims.unit = () <: Prims.unit in
value &. ((1ul <<! n <: u32) -! 1ul <: u32)

let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16)
: Prims.Pure i32
(requires
Expand All @@ -42,7 +39,10 @@ let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16)
let compressed:u32 =
compressed /! (cast (Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <<! 1l <: i32) <: u32)
in
cast (get_n_least_significant_bits coefficient_bits compressed <: u32) <: i32
cast (Libcrux.Kem.Kyber.Arithmetic.get_n_least_significant_bits coefficient_bits compressed <: u32
)
<:
i32

let decompress_ciphertext_coefficient (coefficient_bits: u8) (fe: i32)
: Prims.Pure i32
Expand Down
11 changes: 6 additions & 5 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Constant_time_ops.fst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ let is_non_zero (value: u8)
(ensures
fun result ->
let result:u8 = result in
(~.(value =. 0uy <: bool) || result =. 0uy) &&
(~.(value <>. 0uy <: bool) || result =. 1uy)) =
Hax_lib.implies (value =. 0uy <: bool) (result =. 0uy <: bool) &&
Hax_lib.implies (value <>. 0uy <: bool) (result =. 1uy <: bool)) =
let value:u16 = cast (value <: u8) <: u16 in
let result:u16 =
((value |. (Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) <: u16) >>! 8l <: u16) &.
Expand All @@ -24,7 +24,8 @@ let compare_ciphertexts_in_constant_time (v_CIPHERTEXT_SIZE: usize) (lhs rhs: t_
(ensures
fun result ->
let result:u8 = result in
(~.(lhs =. rhs <: bool) || result =. 0uy) && (~.(lhs <>. rhs <: bool) || result =. 1uy)) =
Hax_lib.implies (lhs =. rhs <: bool) (result =. 0uy <: bool) &&
Hax_lib.implies (lhs <>. rhs <: bool) (result =. 1uy <: bool)) =
let _:Prims.unit = () <: Prims.unit in
let _:Prims.unit = () <: Prims.unit in
let (r: u8):u8 = 0uy in
Expand All @@ -51,8 +52,8 @@ let select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8)
(ensures
fun result ->
let result:t_Array u8 (sz 32) = result in
(~.(selector =. 0uy <: bool) || result =. lhs) &&
(~.(selector <>. 0uy <: bool) || result =. rhs)) =
Hax_lib.implies (selector =. 0uy <: bool) (result =. lhs <: bool) &&
Hax_lib.implies (selector <>. 0uy <: bool) (result =. rhs <: bool)) =
let _:Prims.unit = () <: Prims.unit in
let _:Prims.unit = () <: Prims.unit in
let mask:u8 = Core.Num.impl__u8__wrapping_sub (is_non_zero selector <: u8) 1uy in
Expand Down
Loading

0 comments on commit e40b911

Please sign in to comment.