Skip to content

Commit

Permalink
Merge pull request #636 from cryspen/jonas/ml-kem-target-feature
Browse files Browse the repository at this point in the history
ML-KEM: AVX2 target feature edition
  • Loading branch information
franziskuskiefer authored Oct 21, 2024
2 parents e6b2142 + e474e81 commit 13a5dea
Show file tree
Hide file tree
Showing 54 changed files with 2,303 additions and 794 deletions.
51 changes: 51 additions & 0 deletions libcrux-intrinsics/src/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,70 +7,83 @@ pub type Vec256 = __m256i;
pub type Vec128 = __m128i;
pub type Vec256Float = __m256;

#[inline(always)]
pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) {
debug_assert_eq!(output.len(), 32);
unsafe {
_mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
}
}
#[inline(always)]
pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) {
debug_assert_eq!(output.len(), 16);
unsafe {
_mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
}
}
#[inline(always)]
pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) {
debug_assert_eq!(output.len(), 8);
unsafe {
_mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
}
}

#[inline(always)]
pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) {
debug_assert!(output.len() >= 8);
unsafe {
_mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
}
}
#[inline(always)]
pub fn mm_storeu_si128_i32(output: &mut [i32], vector: Vec128) {
debug_assert_eq!(output.len(), 4);
unsafe {
_mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
}
}

#[inline(always)]
pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: Vec128) {
debug_assert_eq!(output.len(), 16);
unsafe {
_mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
}
}

#[inline(always)]
pub fn mm_loadu_si128(input: &[u8]) -> Vec128 {
debug_assert_eq!(input.len(), 16);
unsafe { _mm_loadu_si128(input.as_ptr() as *const Vec128) }
}

#[inline(always)]
pub fn mm256_loadu_si256_u8(input: &[u8]) -> Vec256 {
debug_assert_eq!(input.len(), 32);
unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}
#[inline(always)]
pub fn mm256_loadu_si256_i16(input: &[i16]) -> Vec256 {
debug_assert_eq!(input.len(), 16);
unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}
#[inline(always)]
pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 {
debug_assert_eq!(input.len(), 8);
unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}

#[inline(always)]
pub fn mm256_setzero_si256() -> Vec256 {
unsafe { _mm256_setzero_si256() }
}
#[inline(always)]
pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 {
unsafe { _mm256_set_m128i(hi, lo) }
}

#[inline(always)]
pub fn mm_set_epi8(
byte15: u8,
byte14: u8,
Expand Down Expand Up @@ -111,6 +124,7 @@ pub fn mm_set_epi8(
}
}

#[inline(always)]
pub fn mm256_set_epi8(
byte31: i8,
byte30: i8,
Expand Down Expand Up @@ -154,9 +168,11 @@ pub fn mm256_set_epi8(
}
}

#[inline(always)]
pub fn mm256_set1_epi16(constant: i16) -> Vec256 {
unsafe { _mm256_set1_epi16(constant) }
}
#[inline(always)]
pub fn mm256_set_epi16(
input15: i16,
input14: i16,
Expand Down Expand Up @@ -242,21 +258,26 @@ pub fn mm256_abs_epi32(a: Vec256) -> Vec256 {
unsafe { _mm256_abs_epi32(a) }
}

#[inline(always)]
pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_sub_epi16(lhs, rhs) }
}
#[inline(always)]
pub fn mm256_sub_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_sub_epi32(lhs, rhs) }
}

#[inline(always)]
pub fn mm_sub_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unsafe { _mm_sub_epi16(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_mullo_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_mullo_epi16(lhs, rhs) }
}

#[inline(always)]
pub fn mm_mullo_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unsafe { _mm_mullo_epi16(lhs, rhs) }
}
Expand Down Expand Up @@ -289,18 +310,22 @@ pub fn mm256_movemask_ps(a: Vec256Float) -> i32 {
unsafe { _mm256_movemask_ps(a) }
}

#[inline(always)]
pub fn mm_mulhi_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unsafe { _mm_mulhi_epi16(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_mullo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_mullo_epi32(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_mulhi_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_mulhi_epi16(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_mul_epu32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_mul_epu32(lhs, rhs) }
}
Expand All @@ -320,102 +345,126 @@ pub fn mm256_or_si256(a: Vec256, b: Vec256) -> Vec256 {
unsafe { _mm256_or_si256(a, b) }
}

#[inline(always)]
pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 {
unsafe { _mm256_testz_si256(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_xor_si256(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_srai_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }
}
#[inline(always)]
pub fn mm256_srai_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
unsafe { _mm256_srai_epi32(vector, SHIFT_BY) }
}

#[inline(always)]
pub fn mm256_srli_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
unsafe { _mm256_srli_epi16(vector, SHIFT_BY) }
}
#[inline(always)]
pub fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
unsafe { _mm256_srli_epi32(vector, SHIFT_BY) }
}

#[inline(always)]
pub fn mm_srli_epi64<const SHIFT_BY: i32>(vector: Vec128) -> Vec128 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
unsafe { _mm_srli_epi64(vector, SHIFT_BY) }
}
#[inline(always)]
pub fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
unsafe { _mm256_srli_epi64(vector, SHIFT_BY) }
}

#[inline(always)]
pub fn mm256_slli_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }
}

#[inline(always)]
pub fn mm256_slli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
unsafe { _mm256_slli_epi32(vector, SHIFT_BY) }
}

#[inline(always)]
pub fn mm_shuffle_epi8(vector: Vec128, control: Vec128) -> Vec128 {
unsafe { _mm_shuffle_epi8(vector, control) }
}
#[inline(always)]
pub fn mm256_shuffle_epi8(vector: Vec256, control: Vec256) -> Vec256 {
unsafe { _mm256_shuffle_epi8(vector, control) }
}
#[inline(always)]
pub fn mm256_shuffle_epi32<const CONTROL: i32>(vector: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
unsafe { _mm256_shuffle_epi32(vector, CONTROL) }
}

#[inline(always)]
pub fn mm256_permute4x64_epi64<const CONTROL: i32>(vector: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
unsafe { _mm256_permute4x64_epi64(vector, CONTROL) }
}

#[inline(always)]
pub fn mm256_unpackhi_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_unpackhi_epi64(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_unpacklo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_unpacklo_epi32(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_unpackhi_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_unpackhi_epi32(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_castsi256_si128(vector: Vec256) -> Vec128 {
unsafe { _mm256_castsi256_si128(vector) }
}
#[inline(always)]
pub fn mm256_castsi128_si256(vector: Vec128) -> Vec256 {
unsafe { _mm256_castsi128_si256(vector) }
}

#[inline(always)]
pub fn mm256_cvtepi16_epi32(vector: Vec128) -> Vec256 {
unsafe { _mm256_cvtepi16_epi32(vector) }
}

#[inline(always)]
pub fn mm_packs_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unsafe { _mm_packs_epi16(lhs, rhs) }
}
#[inline(always)]
pub fn mm256_packs_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_packs_epi32(lhs, rhs) }
}

#[inline(always)]
pub fn mm256_extracti128_si256<const CONTROL: i32>(vector: Vec256) -> Vec128 {
debug_assert!(CONTROL == 0 || CONTROL == 1);
unsafe { _mm256_extracti128_si256(vector, CONTROL) }
}

#[inline(always)]
pub fn mm256_inserti128_si256<const CONTROL: i32>(vector: Vec256, vector_i128: Vec128) -> Vec256 {
debug_assert!(CONTROL == 0 || CONTROL == 1);
unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) }
Expand Down Expand Up @@ -465,9 +514,11 @@ pub fn mm256_srlv_epi64(vector: Vec256, counts: Vec256) -> Vec256 {
unsafe { _mm256_srlv_epi64(vector, counts) }
}

#[inline(always)]
pub fn mm_sllv_epi32(vector: Vec128, counts: Vec128) -> Vec128 {
unsafe { _mm_sllv_epi32(vector, counts) }
}
#[inline(always)]
pub fn mm256_sllv_epi32(vector: Vec256, counts: Vec256) -> Vec256 {
unsafe { _mm256_sllv_epi32(vector, counts) }
}
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/code_gen.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
This code was generated with the following revisions:
Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_core_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_mlkem_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_mlkem_avx2_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_mlkem_portable.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_mlkem_portable_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_sha3_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_sha3_avx2_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_sha3_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_sha3_internal_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/libcrux_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#include "internal/libcrux_core.h"
Expand Down
Loading

0 comments on commit 13a5dea

Please sign in to comment.