Skip to content

Commit

Permalink
failed attempt at solving towering / base field perf discrepancy, sim…
Browse files Browse the repository at this point in the history
…ilar to #446
  • Loading branch information
mratsim committed Aug 22, 2024
1 parent b850dc7 commit 112ab49
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 68 deletions.
11 changes: 11 additions & 0 deletions constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ macro addmod_gen[N: static int](r_PIR: var Limbs[N], a_PIR, b_PIR, M_MEM: Limbs[
ctx.finalSubMayOverflowImpl(r, u, M, v, a_in_scratch = true, scratchReg = b.reuseRegister())

result.add ctx.generate()
return nnkBlockStmt.newTree(
newEmptyNode(),
result)

func addmod_asm*(r: var Limbs, a, b, M: Limbs, spareBits: static int) =
## Constant-time modular addition
Expand Down Expand Up @@ -228,6 +231,10 @@ macro submod_gen[N: static int](r_PIR: var Limbs[N], a_PIR, b_PIR, M_MEM: Limbs[

result.add ctx.generate()

return nnkBlockStmt.newTree(
newEmptyNode(),
result)

func submod_asm*(r: var Limbs, a, b, M: Limbs) =
## Constant-time modular substraction
## Warning, does not handle aliasing of a and b
Expand Down Expand Up @@ -277,6 +284,10 @@ macro negmod_gen[N: static int](r_PIR: var Limbs[N], a_MEM, M_MEM: Limbs[N]): un

result.add ctx.generate()

return nnkBlockStmt.newTree(
newEmptyNode(),
result)

func negmod_asm*(r: var Limbs, a, M: Limbs) =
## Constant-time modular negation
negmod_gen(r, a, M)
85 changes: 85 additions & 0 deletions constantine/math/extension_fields/assembly/fp2_asm_x86.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
# Internal
constantine/platforms/abstractions,
constantine/named/algebras,
constantine/math/arithmetic

import constantine/math/arithmetic/assembly/limbs_asm_modular_x86 {.all.}

# ############################################################
# #
# Assembly implementation of 𝔽p2 #
# #
# ############################################################

static: doAssert UseASM_X86_64

# No exceptions allowed
{.push raises: [].}

# 𝔽p2 addition law
# ------------------------------------------------------------

template aliasPtr(coord, name: untyped): untyped =
# The *_gen macros get confused by bracket [] and dot `.` expressions
# when deriving names so create aliases
# Furthermore the C compiler requires asm inputs to be lvalues
# and arrays should be passed as pointers (aren't they aren't if we use a dot expression)
let name {.inject.} = coord.mres.limbs.unsafeAddr()

func fp2_add_asm*(
r: var array[2, Fp],
a, b: array[2, Fp]) =
## Addition on Fp2
# This specialized proc inline calls and limits data movement (for example register pop/push)
const spareBits = Fp.getSpareBits()

aliasPtr r[0], r0
aliasPtr r[1], r1
aliasPtr a[0], a0
aliasPtr a[1], a1
aliasPtr b[0], b0
aliasPtr b[1], b1
let p = Fp.getModulus().limbs.unsafeAddr()

addmod_gen(r0[], a0[], b0[], p[], spareBits)
addmod_gen(r1[], a1[], b1[], p[], spareBits)

func fp2_sub_asm*(
r: var array[2, Fp],
a, b: array[2, Fp]) =
## Substraction on Fp2
# This specialized proc inline calls and limits data movement (for example register pop/push)
aliasPtr r[0], r0
aliasPtr r[1], r1
aliasPtr a[0], a0
aliasPtr a[1], a1
aliasPtr b[0], b0
aliasPtr b[1], b1
let p = Fp.getModulus().limbs.unsafeAddr()

submod_gen(r0[], a0[], b0[], p[])
submod_gen(r1[], a1[], b1[], p[])

func fp2_neg_asm*(
r: var array[2, Fp],
a: array[2, Fp]) =
## Negation on Fp2
# This specialized proc inline calls and limits data movement (for example register pop/push)

aliasPtr r[0], r0
aliasPtr r[1], r1
aliasPtr a[0], a0
aliasPtr a[1], a1
let p = Fp.getModulus().limbs.unsafeAddr()

negmod_gen(r0[], a0[], p[])
negmod_gen(r1[], a1[], p[])
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ static: doAssert UseASM_X86_64
# No exceptions allowed
{.push raises: [].}

template c0*(a: array): auto =
template c0(a: array): auto =
a[0]
template c1*(a: array): auto =
template c1(a: array): auto =
a[1]

func has1extraBit(F: type Fp): bool =
Expand Down
158 changes: 95 additions & 63 deletions constantine/math/extension_fields/towers.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export Fp

when UseASM_X86_64:
import
./assembly/fp2_asm_x86_adx_bmi2
./assembly/[fp2_asm_x86, fp2_asm_x86_adx_bmi2]

# Note: to avoid burdening the Nim compiler, we rely on generic extension
# to complain if the base field procedures don't exist
Expand Down Expand Up @@ -57,8 +57,8 @@ type
CubicExt[Fp2[Name]]

Fp12*[Name: static Algebra] =
CubicExt[Fp4[Name]]
# QuadraticExt[Fp6[Name]]
# CubicExt[Fp4[Name]]
QuadraticExt[Fp6[Name]]

template c0*(a: ExtensionField): auto =
a.coords[0]
Expand All @@ -80,6 +80,56 @@ template Name*(E: type ExtensionField): Algebra =
template getModulus*(E: type ExtensionField): auto =
E.F.getModulus()

# ############################################################
# #
# Cost functions #
# #
# ############################################################

func prefer_3sqr_over_2mul(F: type ExtensionField): bool {.compileTime.} =
## Returns true
## if time(3sqr) < time(2mul) in the extension fields

let a = default(F)
# No shortcut in the VM
when a.c0 is Fp12:
# Benchmarked on BLS12-381
when a.c0.c0 is Fp6:
return true
elif a.c0.c0 is Fp4:
return false
else: return false
else: return false

func has_large_NR_norm(Name: static Algebra): bool =
## Returns true if the non-residue of the extension fields
## has a large norm

const j = Name.getNonResidueFp()
const u = Name.getNonResidueFp2()[0]
const v = Name.getNonResidueFp2()[1]

const norm2 = u*u + (j*v)*(j*v)

# Compute integer square root
var norm = 0
while (norm+1) * (norm+1) <= norm2:
norm += 1

return norm > 5

func has_large_field_elem*(Name: static Algebra): bool =
## Returns true if field element are large
## and necessitate custom routine for assembly in particular
let a = default(Fp[Name])
return a.mres.limbs.len > 6

# ############################################################
# #
# Implementation #
# #
# ############################################################

# Initialization
# -------------------------------------------------------------------

Expand Down Expand Up @@ -148,36 +198,56 @@ func ccopy*(a: var ExtensionField, b: ExtensionField, ctl: SecretBool) =

# Abelian group
# -------------------------------------------------------------------
func hasFp2x86asm(T: type ExtensionField): bool =
T is Fp2 and UseASM_X86_64 and not T.Name.has_large_field_elem()

func neg*(r: var ExtensionField, a: ExtensionField) =
## Field out-of-place negation
staticFor i, 0, a.coords.len:
r.coords[i].neg(a.coords[i])
when a.typeof().hasFp2x86asm():
r.coords.fp2_neg_asm(a.coords)
else:
staticFor i, 0, a.coords.len:
r.coords[i].neg(a.coords[i])

func neg*(a: var ExtensionField) =
## Field in-place negation
staticFor i, 0, a.coords.len:
a.coords[i].neg()
when a.typeof().hasFp2x86asm():
a.coords.fp2_neg_asm(a.coords)
else:
staticFor i, 0, a.coords.len:
a.coords[i].neg()

func `+=`*(a: var ExtensionField, b: ExtensionField) =
## Addition in the extension field
staticFor i, 0, a.coords.len:
a.coords[i] += b.coords[i]
when a.typeof().hasFp2x86asm():
a.coords.fp2_add_asm(a.coords, b.coords)
else:
staticFor i, 0, a.coords.len:
a.coords[i] += b.coords[i]

func `-=`*(a: var ExtensionField, b: ExtensionField) =
## Substraction in the extension field
staticFor i, 0, a.coords.len:
a.coords[i] -= b.coords[i]
when a.typeof().hasFp2x86asm():
a.coords.fp2_sub_asm(a.coords, b.coords)
else:
staticFor i, 0, a.coords.len:
a.coords[i] -= b.coords[i]

func double*(r: var ExtensionField, a: ExtensionField) =
## Field out-of-place doubling
staticFor i, 0, a.coords.len:
r.coords[i].double(a.coords[i])
when a.typeof().hasFp2x86asm():
r.coords.fp2_add_asm(a.coords, a.coords)
else:
staticFor i, 0, a.coords.len:
r.coords[i].double(a.coords[i])

func double*(a: var ExtensionField) =
## Field in-place doubling
staticFor i, 0, a.coords.len:
a.coords[i].double()
when a.typeof().hasFp2x86asm():
a.coords.fp2_add_asm(a.coords, a.coords)
else:
staticFor i, 0, a.coords.len:
a.coords[i].double()

func div2*(a: var ExtensionField) =
## Field in-place division by 2
Expand All @@ -186,13 +256,19 @@ func div2*(a: var ExtensionField) =

func sum*(r: var ExtensionField, a, b: ExtensionField) =
## Sum ``a`` and ``b`` into ``r``
staticFor i, 0, a.coords.len:
r.coords[i].sum(a.coords[i], b.coords[i])
when a.typeof().hasFp2x86asm():
r.coords.fp2_add_asm(a.coords, b.coords)
else:
staticFor i, 0, a.coords.len:
r.coords[i].sum(a.coords[i], b.coords[i])

func diff*(r: var ExtensionField, a, b: ExtensionField) =
## Diff ``a`` and ``b`` into ``r``
staticFor i, 0, a.coords.len:
r.coords[i].diff(a.coords[i], b.coords[i])
when a.typeof().hasFp2x86asm():
r.coords.fp2_sub_asm(a.coords, b.coords)
else:
staticFor i, 0, a.coords.len:
r.coords[i].diff(a.coords[i], b.coords[i])

func conj*(a: var QuadraticExt) =
## Computes the conjugate in-place
Expand Down Expand Up @@ -692,50 +768,6 @@ func prod2x*(

{.pop.} # inline

# ############################################################
# #
# Cost functions #
# #
# ############################################################

func prefer_3sqr_over_2mul(F: type ExtensionField): bool {.compileTime.} =
## Returns true
## if time(3sqr) < time(2mul) in the extension fields

let a = default(F)
# No shortcut in the VM
when a.c0 is Fp12:
# Benchmarked on BLS12-381
when a.c0.c0 is Fp6:
return true
elif a.c0.c0 is Fp4:
return false
else: return false
else: return false

func has_large_NR_norm(Name: static Algebra): bool =
## Returns true if the non-residue of the extension fields
## has a large norm

const j = Name.getNonResidueFp()
const u = Name.getNonResidueFp2()[0]
const v = Name.getNonResidueFp2()[1]

const norm2 = u*u + (j*v)*(j*v)

# Compute integer square root
var norm = 0
while (norm+1) * (norm+1) <= norm2:
norm += 1

return norm > 5

func has_large_field_elem*(Name: static Algebra): bool =
## Returns true if field element are large
## and necessitate custom routine for assembly in particular
let a = default(Fp[Name])
return a.mres.limbs.len > 6

# ############################################################
# #
# Quadratic extensions #
Expand Down
2 changes: 1 addition & 1 deletion constantine/platforms/ast_rebuilder.nim
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ proc rebuildUntypedAst*(ast: NimNode, dropRootStmtList = false): NimNode =
if dropRootStmtList and ast.kind == nnkStmtList:
return rebuild(ast[0])
else:
result = rebuild(ast)
result = rebuild(ast)
10 changes: 10 additions & 0 deletions constantine/platforms/static_for.nim
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode =
return node
of nnkLiterals:
return node

# Rebuild untyped AST
# --------------------
of nnkHiddenStdConv:
if node[1].kind == nnkIntLit:
return node[1]
else:
expectKind(node[1], nnkSym)
return ident($node[1])
# --------------------
else:
var rTree = node.kind.newTree()
for child in node:
Expand Down
4 changes: 2 additions & 2 deletions constantine/platforms/x86/macro_assembler_x86_intel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ const OutputReg = {asmOutputEarlyClobber, asmInputOutput, asmInputOutputEarlyClo

func toString*(nimSymbol: NimNode): string =
# We need to dereference the hidden pointer of var param
let isPtr = nimSymbol.kind in {nnkHiddenDeref, nnkPtrTy}
let isPtr = nimSymbol.kind in {nnkHiddenDeref, nnkPtrTy, nnkDerefExpr}
let isAddr = nimSymbol.kind in {nnkInfix, nnkCall} and (nimSymbol[0].eqIdent"addr" or nimSymbol[0].eqIdent"unsafeAddr")

let nimSymbol = if isPtr: nimSymbol[0]
Expand Down Expand Up @@ -432,7 +432,7 @@ func generate*(a: Assembler_x86): NimNode =
params.add newLit(": ") & inOperands.foldl(a & newLit(", ") & b) & newLit("\n")
else:
params.add newLit(":\n")

let clobbers = [(a.isStackClobbered, "sp"),
(a.areFlagsClobbered, "cc"),
(memClobbered, "memory")]
Expand Down

0 comments on commit 112ab49

Please sign in to comment.