From a856f672351a0019423cc259e2a2f6835a34402a Mon Sep 17 00:00:00 2001 From: mamonet Date: Thu, 16 May 2024 10:21:36 +0300 Subject: [PATCH] Implement lemma_reverse in Vale.Math.Poly2.Galois module --- vale/code/lib/math/Vale.Math.Poly2.Galois.fst | 223 ++++++++++++++++++ .../code/lib/math/Vale.Math.Poly2.Galois.fsti | 5 + 2 files changed, 228 insertions(+) diff --git a/vale/code/lib/math/Vale.Math.Poly2.Galois.fst b/vale/code/lib/math/Vale.Math.Poly2.Galois.fst index 2fe1a8261a..8eb831b631 100644 --- a/vale/code/lib/math/Vale.Math.Poly2.Galois.fst +++ b/vale/code/lib/math/Vale.Math.Poly2.Galois.fst @@ -589,3 +589,226 @@ let lemma_mul f a b = lemma_fmul_fmul f a b; PL.lemma_mod_small (to_poly (G.fmul a b)) m; () + +let reverse_iter (f:G.field) : G.felem f -> i:nat{i < I.bits f.t} -> G.felem f -> G.felem f = + fun a i u -> + I.logor u (I.shift_left (I.logand (I.shift_right a (I.size i)) (G.one #f)) (I.size (I.bits f.t - 1 - i))) + +let rec reverse_rec (f:G.field) (a:G.felem f) (n:nat{n <= I.bits f.t}) : G.felem f = + if n = 0 then G.zero #f + else + let u = reverse_rec f a (n - 1) in + reverse_iter f a (n - 1) u + +let g_reverse (f:G.field) (a:G.felem f) : G.felem f = + reverse_rec f a (I.bits f.t) + +let f_reverse (f:G.field) (a:G.felem f) : G.felem f = + Lib.LoopCombinators.repeati (I.bits f.t) + (reverse_iter f a) + (G.zero #f) + +#reset-options "--z3rlimit 20" +let lemma_f_g_reverse (f:G.field) (a:G.felem f) : Lemma + (f_reverse f a == g_reverse f a) + = + let pred (n:nat{n <= I.bits f.t}) (pab:G.felem f) : Type0 = reverse_rec f a n == pab in + let _ = Lib.LoopCombinators.repeati_inductive' (I.bits f.t) pred (reverse_iter f a) (G.zero #f) in + () + +#reset-options "--initial_ifuel 1" +let lemma_f_reverse (f:G.field) (a:G.felem f) : Lemma + (G.reverse a == f_reverse f a) + = + let repeati = Lib.LoopCombinators.repeati in + let acc0 = G.zero #f in + let rec lem (n:nat{n <= I.bits f.t}) (f1:(i:nat{i < n} -> G.felem f -> G.felem f)) : Lemma + (requires (forall (i:nat{i < n}) (pab:G.felem f). f1 i pab == reverse_iter f a i pab)) + (ensures repeati n (reverse_iter f a) acc0 == repeati n f1 acc0) + [SMTPat (repeati n f1 acc0)] + = + if n = 0 then + ( + let pred (n:nat) (pab:G.felem f) : Type0 = n == 0 ==> pab == acc0 in + let _ = Lib.LoopCombinators.repeati_inductive' 0 pred (reverse_iter f a) acc0 in + let _ = Lib.LoopCombinators.repeati_inductive' 0 pred f1 acc0 in + () + ) + else + ( + lem (n - 1) f1; + Lib.LoopCombinators.unfold_repeati n (reverse_iter f a) acc0 (n - 1); + Lib.LoopCombinators.unfold_repeati n f1 acc0 (n - 1); + assert (repeati n (reverse_iter f a) acc0 == repeati n f1 acc0); + () + ) + in + () + +let rec s_reverse_rec (a:poly) (n:nat) (n':nat) : Tot (poly) (decreases n') = + if n' = 0 then + let a' = if a.[0] then one else zero in + shift a' n + else + let u = s_reverse_rec a n (n' - 1) in + let a' = if a.[n'] then one else zero in + let a = shift a' (n - n') in + u |. a + +let s_reverse (a:poly) (n:nat) : poly = + s_reverse_rec a n n + +[@"opaque_to_smt"] +let reverse_def (a:poly) (n:nat) : Pure poly + (requires True) + (ensures fun p -> + poly_length p <= n + 1 /\ + (forall (i:nat).{:pattern p.[i]} i <= n ==> p.[i] == a.[n - i]) + ) + = + of_fun (n + 1) (fun (i:nat) -> a.[n - i]) + +let lemma_reverse_def (a:poly) (n:nat) : Lemma + (reverse_def a n == reverse a n) + = + reveal_defs (); + lemma_equal (reverse_def a n) (reverse a n) + +let reverse_element_fun (a:poly) (n:nat) (k i:int) : bool = a.[i] && one.[k - (n - i)] + +let rec or_of_bools (j k:int) (f:int -> bool) : Tot bool (decreases (k - j)) = + if j >= k then f k + else (or_of_bools j (k - 1) f) || f k + +let rec lemma_reverse_elem_s (a:poly) (k:int) (n:nat) (n':nat) : Lemma + (requires n' <= n) + (ensures or_of_bools 0 n' (reverse_element_fun a n k) == (s_reverse_rec a n n').[k]) + (decreases n') + = + PL.lemma_index_all (); + PL.lemma_shift_define_all (); + PL.lemma_or_define_all (); + if n' > 0 then lemma_reverse_elem_s a k n (n' - 1) + +let rec lemma_reverse_elem_i (a:poly) (k:int) (n:nat) (n':nat) : Lemma + (requires True) + (ensures or_of_bools 0 n' (reverse_element_fun a n k) == (if k >= n - n' then a.[n - k] else false)) + (decreases n') + = + reveal_defs (); + if n' > 0 then lemma_reverse_elem_i a k n (n' - 1) + +let lemma_reverse_k (a:poly) (n:nat) (k:int) : Lemma + ((reverse_def a n).[k] == (s_reverse a n).[k]) + = + reveal_defs (); + lemma_reverse_elem_s a k n n; + lemma_reverse_elem_i a k n n + +let lemma_s_reverse_def (a:poly) (n:nat) : Lemma + (reverse_def a n == s_reverse a n) + = + PL.lemma_pointwise_equal (reverse_def a n) (s_reverse a n) (lemma_reverse_k a n) + +let lemma_s_reverse (a:poly) (n:nat) : Lemma + (ensures reverse a n == s_reverse a n) + = + lemma_reverse_def a n; + lemma_s_reverse_def a n + +let rec s_reverse_rec_n (a:poly) (n:nat) (n':nat) : Tot (poly) (decreases n') = + if n' = 0 then zero + else + let u = s_reverse_rec_n a n (n' - 1) in + let a' = if a.[n' - 1] then one else zero in + let a = shift a' (n - n') in + u |. a + +let reverse_n_element_fun (a:poly) (n:nat) (k i:int) : bool = a.[i] && one.[k - (n - 1 - i)] + +let rec or_of_bools_n (j k:int) (f:int -> bool) : Tot bool (decreases (k - j)) = + if j >= k then false + else (or_of_bools_n j (k - 1) f) || f (k - 1) + +let rec lemma_reverse_n_k_base (a:poly) (k:int) (n:nat) (n':nat) : Lemma + (requires n' <= n) + (ensures or_of_bools_n 0 n' (reverse_n_element_fun a n k) == (s_reverse_rec_n a n n').[k]) + (decreases n') + = + PL.lemma_index_all (); + PL.lemma_shift_define_all (); + PL.lemma_or_define_all (); + if n' > 0 then lemma_reverse_n_k_base a k n (n' - 1) + +let rec lemma_reverse_elem_n (a:poly) (k:int) (n:nat) (n':nat) : Lemma + (requires True) + (ensures or_of_bools_n 0 n' (reverse_n_element_fun a (n + 1) k) == + (if n' > 0 then or_of_bools 0 (n' - 1) (reverse_element_fun a n k) else false)) + (decreases n') + = + if n' > 0 then lemma_reverse_elem_n a k n (n' - 1) + +let lemma_reverse_n_k (a:poly) (n:nat) (k:int) : Lemma + ((s_reverse_rec a n n).[k] == (s_reverse_rec_n a (n + 1) (n + 1)).[k]) + = + lemma_reverse_elem_n a k n n; + lemma_reverse_elem_s a k n n; + lemma_reverse_n_k_base a k (n + 1) (n + 1) + +let lemma_reverse_n (a:poly) (n:nat) : Lemma + (s_reverse_rec a n n == s_reverse_rec_n a (n + 1) (n + 1)) + = + PL.lemma_pointwise_equal (s_reverse_rec a n n) (s_reverse_rec_n a (n + 1) (n + 1)) (lemma_reverse_n_k a n) + +let lemma_shift_right_and (f:G.field) (a:G.felem f) (i:nat{i < I.bits f.t}) : Lemma + (to_poly (I.logand (I.shift_right a (I.size i)) (G.one #f)) == (if (to_poly a).[i] then one else zero)) + = + PL.lemma_zero_define (); + PL.lemma_one_define (); + PL.lemma_and_define_all (); + PL.lemma_shift_define_all (); + if (to_poly a).[i] then + lemma_equal ((shift (to_poly a) (-i)) &. one) one + else + lemma_equal ((shift (to_poly a) (-i)) &. one) zero + +#reset-options "--z3rlimit 20" +let rec lemma_s_g_reverse_rec (f:G.field) (e:G.felem f) (n:nat{n <= I.bits f.t}) : Lemma + (to_poly (reverse_rec f e n) == s_reverse_rec_n (to_poly e) (I.bits f.t) n) + = + let a = to_poly e in + if n > 0 then + ( + lemma_s_g_reverse_rec f e (n - 1); + let sa = s_reverse_rec_n a (I.bits f.t) (n - 1) in + let ga = reverse_rec f e (n - 1) in + //assert (to_poly ga == sa); + + lemma_shift_right_and f e (n - 1); + //assert ((to_poly (I.logand (I.shift_right e (I.size (n - 1))) (G.one #f)) == + // (if a.[n - 1] then one else zero))); + + lemma_shift_left f (I.logand (I.shift_right e (I.size (n - 1))) (G.one #f)) (I.size 1); + if a.[n - 1] then + PL.lemma_mod_small (shift one (I.bits f.t - n)) (monomial (I.bits f.G.t)) + else + PL.lemma_mod_small (shift zero (I.bits f.t - n)) (monomial (I.bits f.G.t)); + //assert (to_poly (I.shift_left (I.logand (I.shift_right e (I.size (n - 1))) (G.one #f)) (I.size (I.bits f.t - n))) == + // shift (if a.[n - 1] then one else zero) (I.bits f.t - n)); + lemma_or f ga (I.shift_left (I.logand (I.shift_right e (I.size (n - 1))) (G.one #f)) (I.size (I.bits f.t - n))); + //assert (to_poly (I.logor ga (I.shift_left (I.logand (I.shift_right e (I.size (n - 1))) (G.one #f)) (I.size (I.bits f.t - n)))) == + // (sa |. (shift (if a.[n - 1] then one else zero) (I.bits f.t - n)))); + () + ) + +let lemma_s_g_reverse (f:G.field) (a:G.felem f) : Lemma + (to_poly (g_reverse f a) == s_reverse (to_poly a) (I.bits f.t - 1)) + = + lemma_s_g_reverse_rec f a (I.bits f.t); + lemma_reverse_n (to_poly a) (I.bits f.t - 1) + +let lemma_reverse f e = + lemma_f_g_reverse f e; + lemma_f_reverse f e; + lemma_s_g_reverse f e; + lemma_s_reverse (to_poly e) (I.bits f.t - 1) diff --git a/vale/code/lib/math/Vale.Math.Poly2.Galois.fsti b/vale/code/lib/math/Vale.Math.Poly2.Galois.fsti index fb317aa292..3a01e98bd9 100644 --- a/vale/code/lib/math/Vale.Math.Poly2.Galois.fsti +++ b/vale/code/lib/math/Vale.Math.Poly2.Galois.fsti @@ -83,3 +83,8 @@ val lemma_mul (f:G.field) (a b:G.felem f) : Lemma (requires True) (ensures to_poly (G.fmul a b) == (to_poly a *. to_poly b) %. (irred_poly f)) [SMTPat (to_poly (G.fmul a b))] + +val lemma_reverse (f:G.field) (e:G.felem f) : Lemma + (requires True) + (ensures to_poly (G.reverse e) == reverse (to_poly e) (I.bits f.t - 1)) + [SMTPat (to_poly (G.reverse e))]