Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-KEM: AVX2 target feature edition #636

Merged
merged 18 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions libcrux-intrinsics/src/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,70 +7,83 @@ pub type Vec256 = __m256i;
pub type Vec128 = __m128i;
pub type Vec256Float = __m256;

/// Writes the 256-bit `vector` into the first 32 bytes of `output` through the
/// unaligned-store intrinsic `_mm256_storeu_si256`.
///
/// # Panics
///
/// Panics in every build profile if `output` holds fewer than 32 bytes. Debug
/// builds additionally panic unless `output.len()` is exactly 32.
#[inline(always)]
pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function write past the end of a short slice.
    assert!(output.len() >= 32, "output must hold at least 32 bytes");
    debug_assert_eq!(output.len(), 32);
    // SAFETY: the assertion above guarantees 32 writable bytes behind the
    // pointer, and the `storeu` form places no alignment requirement on it.
    // NOTE(review): CPU support for the instruction is not checked here.
    unsafe {
        _mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
    }
}
/// Writes the 256-bit `vector` into the first 16 `i16` elements (32 bytes) of
/// `output` through the unaligned-store intrinsic `_mm256_storeu_si256`.
///
/// # Panics
///
/// Panics in every build profile if `output` holds fewer than 16 elements.
/// Debug builds additionally panic unless `output.len()` is exactly 16.
#[inline(always)]
pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function write past the end of a short slice.
    assert!(output.len() >= 16, "output must hold at least 16 i16 elements");
    debug_assert_eq!(output.len(), 16);
    // SAFETY: the assertion above guarantees 16 * 2 = 32 writable bytes behind
    // the pointer, and the `storeu` form places no alignment requirement on it.
    // NOTE(review): CPU support for the instruction is not checked here.
    unsafe {
        _mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
    }
}
/// Writes the 256-bit `vector` into the first 8 `i32` elements (32 bytes) of
/// `output` through the unaligned-store intrinsic `_mm256_storeu_si256`.
///
/// # Panics
///
/// Panics in every build profile if `output` holds fewer than 8 elements.
/// Debug builds additionally panic unless `output.len()` is exactly 8.
#[inline(always)]
pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function write past the end of a short slice.
    assert!(output.len() >= 8, "output must hold at least 8 i32 elements");
    debug_assert_eq!(output.len(), 8);
    // SAFETY: the assertion above guarantees 8 * 4 = 32 writable bytes behind
    // the pointer, and the `storeu` form places no alignment requirement on it.
    // NOTE(review): CPU support for the instruction is not checked here.
    unsafe {
        _mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
    }
}

/// Writes the 128-bit `vector` into the first 8 `i16` elements (16 bytes) of
/// `output` through the unaligned-store intrinsic `_mm_storeu_si128`.
///
/// Unlike the sibling store helpers, a longer `output` is accepted; only the
/// first 8 elements are overwritten.
///
/// # Panics
///
/// Panics in every build profile if `output` holds fewer than 8 elements.
#[inline(always)]
pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) {
    // Promoted from `debug_assert!`: the check is what keeps the raw store in
    // bounds, so it must also be present in release builds.
    assert!(output.len() >= 8, "output must hold at least 8 i16 elements");
    // SAFETY: the assertion above guarantees 8 * 2 = 16 writable bytes behind
    // the pointer, and the `storeu` form places no alignment requirement on it.
    unsafe {
        _mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
    }
}
/// Writes the 128-bit `vector` into the first 4 `i32` elements (16 bytes) of
/// `output` through the unaligned-store intrinsic `_mm_storeu_si128`.
///
/// # Panics
///
/// Panics in every build profile if `output` holds fewer than 4 elements.
/// Debug builds additionally panic unless `output.len()` is exactly 4.
#[inline(always)]
pub fn mm_storeu_si128_i32(output: &mut [i32], vector: Vec128) {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function write past the end of a short slice.
    assert!(output.len() >= 4, "output must hold at least 4 i32 elements");
    debug_assert_eq!(output.len(), 4);
    // SAFETY: the assertion above guarantees 4 * 4 = 16 writable bytes behind
    // the pointer, and the `storeu` form places no alignment requirement on it.
    unsafe {
        _mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
    }
}

/// Writes the 128-bit `vector` into the first 16 bytes of `output` through the
/// unaligned-store intrinsic `_mm_storeu_si128`.
///
/// # Panics
///
/// Panics in every build profile if `output` holds fewer than 16 bytes. Debug
/// builds additionally panic unless `output.len()` is exactly 16.
#[inline(always)]
pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: Vec128) {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function write past the end of a short slice.
    assert!(output.len() >= 16, "output must hold at least 16 bytes");
    debug_assert_eq!(output.len(), 16);
    // SAFETY: the assertion above guarantees 16 writable bytes behind the
    // pointer, and the `storeu` form places no alignment requirement on it.
    unsafe {
        _mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
    }
}

/// Reads the first 16 bytes of `input` into a 128-bit vector through the
/// unaligned-load intrinsic `_mm_loadu_si128`.
///
/// # Panics
///
/// Panics in every build profile if `input` holds fewer than 16 bytes. Debug
/// builds additionally panic unless `input.len()` is exactly 16.
#[inline(always)]
pub fn mm_loadu_si128(input: &[u8]) -> Vec128 {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function read past the end of a short slice.
    assert!(input.len() >= 16, "input must hold at least 16 bytes");
    debug_assert_eq!(input.len(), 16);
    // SAFETY: the assertion above guarantees 16 readable bytes behind the
    // pointer, and the `loadu` form places no alignment requirement on it.
    unsafe { _mm_loadu_si128(input.as_ptr() as *const Vec128) }
}

/// Reads the first 32 bytes of `input` into a 256-bit vector through the
/// unaligned-load intrinsic `_mm256_loadu_si256`.
///
/// # Panics
///
/// Panics in every build profile if `input` holds fewer than 32 bytes. Debug
/// builds additionally panic unless `input.len()` is exactly 32.
#[inline(always)]
pub fn mm256_loadu_si256_u8(input: &[u8]) -> Vec256 {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function read past the end of a short slice.
    assert!(input.len() >= 32, "input must hold at least 32 bytes");
    debug_assert_eq!(input.len(), 32);
    // SAFETY: the assertion above guarantees 32 readable bytes behind the
    // pointer, and the `loadu` form places no alignment requirement on it.
    // NOTE(review): CPU support for the instruction is not checked here.
    unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}
/// Reads the first 16 `i16` elements (32 bytes) of `input` into a 256-bit
/// vector through the unaligned-load intrinsic `_mm256_loadu_si256`.
///
/// # Panics
///
/// Panics in every build profile if `input` holds fewer than 16 elements.
/// Debug builds additionally panic unless `input.len()` is exactly 16.
#[inline(always)]
pub fn mm256_loadu_si256_i16(input: &[i16]) -> Vec256 {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function read past the end of a short slice.
    assert!(input.len() >= 16, "input must hold at least 16 i16 elements");
    debug_assert_eq!(input.len(), 16);
    // SAFETY: the assertion above guarantees 16 * 2 = 32 readable bytes behind
    // the pointer, and the `loadu` form places no alignment requirement on it.
    // NOTE(review): CPU support for the instruction is not checked here.
    unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}
/// Reads the first 8 `i32` elements (32 bytes) of `input` into a 256-bit
/// vector through the unaligned-load intrinsic `_mm256_loadu_si256`.
///
/// # Panics
///
/// Panics in every build profile if `input` holds fewer than 8 elements.
/// Debug builds additionally panic unless `input.len()` is exactly 8.
#[inline(always)]
pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 {
    // Hard check: a `debug_assert` alone disappears in release builds and would
    // let this safe function read past the end of a short slice.
    assert!(input.len() >= 8, "input must hold at least 8 i32 elements");
    debug_assert_eq!(input.len(), 8);
    // SAFETY: the assertion above guarantees 8 * 4 = 32 readable bytes behind
    // the pointer, and the `loadu` form places no alignment requirement on it.
    // NOTE(review): CPU support for the instruction is not checked here.
    unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}

/// Safe wrapper around the `_mm256_setzero_si256` intrinsic, which returns a
/// 256-bit vector with every bit cleared.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_setzero_si256() -> Vec256 {
// SAFETY: takes no arguments and touches no memory.
unsafe { _mm256_setzero_si256() }
}
/// Safe wrapper around the `_mm256_set_m128i` intrinsic, which builds a 256-bit
/// vector whose upper 128 bits are `hi` and whose lower 128 bits are `lo`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 {
// SAFETY: both halves are passed by value, so no memory is dereferenced.
unsafe { _mm256_set_m128i(hi, lo) }
}

#[inline(always)]
pub fn mm_set_epi8(
byte15: u8,
byte14: u8,
Expand Down Expand Up @@ -111,6 +124,7 @@ pub fn mm_set_epi8(
}
}

#[inline(always)]
pub fn mm256_set_epi8(
byte31: i8,
byte30: i8,
Expand Down Expand Up @@ -154,9 +168,11 @@ pub fn mm256_set_epi8(
}
}

/// Safe wrapper around the `_mm256_set1_epi16` intrinsic, which broadcasts
/// `constant` to every 16-bit lane of the returned 256-bit vector.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_set1_epi16(constant: i16) -> Vec256 {
// SAFETY: the only argument is a plain integer; no memory is dereferenced.
unsafe { _mm256_set1_epi16(constant) }
}
#[inline(always)]
pub fn mm256_set_epi16(
input15: i16,
input14: i16,
Expand Down Expand Up @@ -242,21 +258,26 @@ pub fn mm256_abs_epi32(a: Vec256) -> Vec256 {
unsafe { _mm256_abs_epi32(a) }
}

/// Safe wrapper around the `_mm256_sub_epi16` intrinsic, which subtracts the
/// packed 16-bit integers of `rhs` from those of `lhs`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_sub_epi16(lhs, rhs) }
}
/// Safe wrapper around the `_mm256_sub_epi32` intrinsic, which subtracts the
/// packed 32-bit integers of `rhs` from those of `lhs`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_sub_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_sub_epi32(lhs, rhs) }
}

/// Safe wrapper around the `_mm_sub_epi16` intrinsic, which subtracts the
/// packed 16-bit integers of `rhs` from those of `lhs` (128-bit vectors).
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_sub_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm_sub_epi16(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_mullo_epi16` intrinsic, which multiplies the
/// packed 16-bit integers of `lhs` and `rhs` and keeps the low 16 bits of each
/// product.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_mullo_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_mullo_epi16(lhs, rhs) }
}

/// Safe wrapper around the `_mm_mullo_epi16` intrinsic, which multiplies the
/// packed 16-bit integers of `lhs` and `rhs` (128-bit vectors) and keeps the
/// low 16 bits of each product.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_mullo_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm_mullo_epi16(lhs, rhs) }
}
Expand Down Expand Up @@ -289,18 +310,22 @@ pub fn mm256_movemask_ps(a: Vec256Float) -> i32 {
unsafe { _mm256_movemask_ps(a) }
}

/// Safe wrapper around the `_mm_mulhi_epi16` intrinsic, which multiplies the
/// packed signed 16-bit integers of `lhs` and `rhs` (128-bit vectors) and keeps
/// the high 16 bits of each product.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_mulhi_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm_mulhi_epi16(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_mullo_epi32` intrinsic, which multiplies the
/// packed 32-bit integers of `lhs` and `rhs` and keeps the low 32 bits of each
/// product.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_mullo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_mullo_epi32(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_mulhi_epi16` intrinsic, which multiplies the
/// packed signed 16-bit integers of `lhs` and `rhs` and keeps the high 16 bits
/// of each product.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_mulhi_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_mulhi_epi16(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_mul_epu32` intrinsic, which multiplies the
/// low unsigned 32-bit integer of each 64-bit element of `lhs` and `rhs`,
/// producing unsigned 64-bit products.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_mul_epu32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_mul_epu32(lhs, rhs) }
}
Expand All @@ -320,102 +345,126 @@ pub fn mm256_or_si256(a: Vec256, b: Vec256) -> Vec256 {
unsafe { _mm256_or_si256(a, b) }
}

/// Safe wrapper around the `_mm256_testz_si256` intrinsic, which returns 1 when
/// the bitwise AND of `lhs` and `rhs` is all zeros and 0 otherwise.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_testz_si256(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_xor_si256` intrinsic, which returns the
/// bitwise XOR of the two 256-bit operands.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_xor_si256(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_srai_epi16` intrinsic: arithmetic
/// (sign-filling) right shift of every 16-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`.
///
/// Debug builds assert `0 <= SHIFT_BY < 16`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_srai_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }
}
/// Safe wrapper around the `_mm256_srai_epi32` intrinsic: arithmetic
/// (sign-filling) right shift of every 32-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`.
///
/// Debug builds assert `0 <= SHIFT_BY < 32`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_srai_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_srai_epi32(vector, SHIFT_BY) }
}

/// Safe wrapper around the `_mm256_srli_epi16` intrinsic: logical
/// (zero-filling) right shift of every 16-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`.
///
/// Debug builds assert `0 <= SHIFT_BY < 16`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_srli_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_srli_epi16(vector, SHIFT_BY) }
}
/// Safe wrapper around the `_mm256_srli_epi32` intrinsic: logical
/// (zero-filling) right shift of every 32-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`.
///
/// Debug builds assert `0 <= SHIFT_BY < 32`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_srli_epi32(vector, SHIFT_BY) }
}

/// Safe wrapper around the `_mm_srli_epi64` intrinsic: logical (zero-filling)
/// right shift of every 64-bit lane of the 128-bit `vector` by the
/// compile-time constant `SHIFT_BY`.
///
/// Debug builds assert `0 <= SHIFT_BY < 64`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_srli_epi64<const SHIFT_BY: i32>(vector: Vec128) -> Vec128 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm_srli_epi64(vector, SHIFT_BY) }
}
/// Safe wrapper around the `_mm256_srli_epi64` intrinsic: logical
/// (zero-filling) right shift of every 64-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`.
///
/// Debug builds assert `0 <= SHIFT_BY < 64`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_srli_epi64(vector, SHIFT_BY) }
}

/// Safe wrapper around the `_mm256_slli_epi16` intrinsic: left shift of every
/// 16-bit lane of `vector` by the compile-time constant `SHIFT_BY`, shifting
/// in zeros.
///
/// Debug builds assert `0 <= SHIFT_BY < 16`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_slli_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }
}

/// Safe wrapper around the `_mm256_slli_epi32` intrinsic: left shift of every
/// 32-bit lane of `vector` by the compile-time constant `SHIFT_BY`, shifting
/// in zeros.
///
/// Debug builds assert `0 <= SHIFT_BY < 32`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_slli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_slli_epi32(vector, SHIFT_BY) }
}

/// Safe wrapper around the `_mm_shuffle_epi8` intrinsic, which rearranges the
/// bytes of `vector` according to the per-byte selectors in `control`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_shuffle_epi8(vector: Vec128, control: Vec128) -> Vec128 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm_shuffle_epi8(vector, control) }
}
/// Safe wrapper around the `_mm256_shuffle_epi8` intrinsic, which rearranges
/// the bytes of `vector` within each 128-bit lane according to the per-byte
/// selectors in `control`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_shuffle_epi8(vector: Vec256, control: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_shuffle_epi8(vector, control) }
}
/// Safe wrapper around the `_mm256_shuffle_epi32` intrinsic, which rearranges
/// the 32-bit integers of `vector` within each 128-bit lane as selected by the
/// compile-time constant `CONTROL`.
///
/// Debug builds assert `0 <= CONTROL < 256`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_shuffle_epi32<const CONTROL: i32>(vector: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_shuffle_epi32(vector, CONTROL) }
}

/// Safe wrapper around the `_mm256_permute4x64_epi64` intrinsic, which
/// rearranges the four 64-bit integers of `vector` (across both 128-bit lanes)
/// as selected by the compile-time constant `CONTROL`.
///
/// Debug builds assert `0 <= CONTROL < 256`; release builds skip that check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_permute4x64_epi64<const CONTROL: i32>(vector: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_permute4x64_epi64(vector, CONTROL) }
}

/// Safe wrapper around the `_mm256_unpackhi_epi64` intrinsic, which interleaves
/// the 64-bit integers from the high half of each 128-bit lane of `lhs` and
/// `rhs`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_unpackhi_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_unpackhi_epi64(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_unpacklo_epi32` intrinsic, which interleaves
/// the 32-bit integers from the low half of each 128-bit lane of `lhs` and
/// `rhs`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_unpacklo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_unpacklo_epi32(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_unpackhi_epi32` intrinsic, which interleaves
/// the 32-bit integers from the high half of each 128-bit lane of `lhs` and
/// `rhs`.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_unpackhi_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_unpackhi_epi32(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_castsi256_si128` intrinsic, which
/// reinterprets `vector` as a 128-bit vector holding its lower 128 bits.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_castsi256_si128(vector: Vec256) -> Vec128 {
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_castsi256_si128(vector) }
}
/// Safe wrapper around the `_mm256_castsi128_si256` intrinsic, which widens
/// `vector` to 256 bits with `vector` in the lower 128 bits.
///
/// Per the intrinsic's documentation the upper 128 bits of the result are
/// unspecified, so callers should not rely on their contents.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_castsi128_si256(vector: Vec128) -> Vec256 {
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_castsi128_si256(vector) }
}

/// Safe wrapper around the `_mm256_cvtepi16_epi32` intrinsic, which
/// sign-extends the eight 16-bit integers of `vector` to eight 32-bit
/// integers.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_cvtepi16_epi32(vector: Vec128) -> Vec256 {
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_cvtepi16_epi32(vector) }
}

/// Safe wrapper around the `_mm_packs_epi16` intrinsic, which narrows the
/// signed 16-bit integers of `lhs` and `rhs` to 8-bit integers using signed
/// saturation and packs them into one 128-bit vector.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_packs_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm_packs_epi16(lhs, rhs) }
}
/// Safe wrapper around the `_mm256_packs_epi32` intrinsic, which narrows the
/// signed 32-bit integers of `lhs` and `rhs` to 16-bit integers using signed
/// saturation and packs them into one 256-bit vector.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_packs_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_packs_epi32(lhs, rhs) }
}

/// Safe wrapper around the `_mm256_extracti128_si256` intrinsic, which returns
/// one 128-bit half of `vector`: the low half for `CONTROL == 0`, the high
/// half for `CONTROL == 1`.
///
/// Debug builds assert that `CONTROL` is 0 or 1; release builds skip that
/// check.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_extracti128_si256<const CONTROL: i32>(vector: Vec256) -> Vec128 {
debug_assert!(CONTROL == 0 || CONTROL == 1);
// SAFETY: the operand is passed by value, so no memory is dereferenced.
unsafe { _mm256_extracti128_si256(vector, CONTROL) }
}

#[inline(always)]
pub fn mm256_inserti128_si256<const CONTROL: i32>(vector: Vec256, vector_i128: Vec128) -> Vec256 {
debug_assert!(CONTROL == 0 || CONTROL == 1);
unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) }
Expand Down Expand Up @@ -465,9 +514,11 @@ pub fn mm256_srlv_epi64(vector: Vec256, counts: Vec256) -> Vec256 {
unsafe { _mm256_srlv_epi64(vector, counts) }
}

/// Safe wrapper around the `_mm_sllv_epi32` intrinsic, which left-shifts each
/// 32-bit lane of `vector` by the amount held in the matching lane of
/// `counts`, shifting in zeros.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm_sllv_epi32(vector: Vec128, counts: Vec128) -> Vec128 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm_sllv_epi32(vector, counts) }
}
/// Safe wrapper around the `_mm256_sllv_epi32` intrinsic, which left-shifts
/// each 32-bit lane of `vector` by the amount held in the matching lane of
/// `counts`, shifting in zeros.
///
/// NOTE(review): this wrapper does not check that the running CPU supports the
/// instruction; confirm that callers guarantee it.
#[inline(always)]
pub fn mm256_sllv_epi32(vector: Vec256, counts: Vec256) -> Vec256 {
// SAFETY: both operands are passed by value, so no memory is dereferenced.
unsafe { _mm256_sllv_epi32(vector, counts) }
}
Expand Down
10 changes: 6 additions & 4 deletions libcrux-ml-kem/benches/ml-kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ pub fn pk_validation(c: &mut Criterion) {
init!(mlkem1024, "PK Validation", c);
}

pub fn encapsulation(c: &mut Criterion) {
#[target_feature(enable = "avx2")]
jschneider-bensch marked this conversation as resolved.
Show resolved Hide resolved
pub unsafe fn encapsulation(c: &mut Criterion) {
macro_rules! fun {
($name:expr, $p:path, $group:expr) => {
$group.bench_function(format!("libcrux {} (external random)", $name), |b| {
Expand Down Expand Up @@ -160,7 +161,8 @@ pub fn encapsulation(c: &mut Criterion) {
init!(mlkem1024, "Encapsulation", c);
}

pub fn decapsulation(c: &mut Criterion) {
#[target_feature(enable = "avx2")]
pub unsafe fn decapsulation(c: &mut Criterion) {
macro_rules! fun {
($name:expr, $p:path, $group:expr) => {
$group.bench_function(format!("libcrux {}", $name), |b| {
Expand Down Expand Up @@ -219,8 +221,8 @@ pub fn decapsulation(c: &mut Criterion) {
pub fn comparisons(c: &mut Criterion) {
pk_validation(c);
key_generation(c);
encapsulation(c);
decapsulation(c);
unsafe { encapsulation(c) };
unsafe { decapsulation(c) };
}

criterion_group!(benches, comparisons);
Expand Down
Loading
Loading