Skip to content

Commit

Permalink
Add test for modulo
Browse files Browse the repository at this point in the history
  • Loading branch information
Sh0g0-1758 committed Sep 19, 2024
1 parent a219b2d commit 23129af
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 135 deletions.
195 changes: 60 additions & 135 deletions pkg/vm/builtins/modulo.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,20 @@ import (
"math/big"

"github.com/NethermindEth/cairo-vm-go/pkg/utils"

"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

const ModuloName = "Mod"

const VALUES_PTR_OFFSET = 4
const OFFSETS_PTR_OFFSET = 5
const N_OFFSET = 6
const N_WORDS = 4
const CELLS_PER_MOD = 7
const FILL_MEMORY_MAX = 100000

type ModInstanceDef struct {
ratio *uint32
wordBitLen uint32
batchSize uint
}

func (*ModInstanceDef) New(ratio *uint32, wordBitLen uint32, batchSize uint) *ModInstanceDef {
return &ModInstanceDef{
ratio: ratio,
wordBitLen: wordBitLen,
batchSize: batchSize,
}
}

type ModBuiltinInputs struct {
p big.Int
pValues [N_WORDS]fp.Element
Expand All @@ -48,82 +37,58 @@ const (
type Operation string

const (
AddOp Operation = "add"
SubOp Operation = "sub"
MulOp Operation = "mul"
DivModOp Operation = "divmod"
addOp Operation = "add"
subOp Operation = "invAdd"
mulOp Operation = "mul"
divOp Operation = "invMul"
)

type ModBuiltin struct {
builtinType ModBuiltinType
base uint64
stop_ptr *uint
instanceDef ModInstanceDef
included bool
zeroSegmentIndex uint
zeroSegmentSize uint
shift big.Int
shiftPowers [N_WORDS]big.Int
}

func (m *ModBuiltin) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error {
return nil
}

func (m *ModBuiltin) InferValue(segment *memory.Segment, offset uint64) error {
return nil
ratio uint64
modBuiltinType ModBuiltinType
wordBitLen uint64
batchSize uint64
shift big.Int
shiftPowers [N_WORDS]big.Int
}

func (m *ModBuiltin) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) {
return 0, nil
}

func max(a, b uint) uint {
if a > b {
return a
}
return b
}

func (*ModBuiltin) New(builtinType ModBuiltinType, included bool, instanceDef ModInstanceDef) *ModBuiltin {
shift := new(big.Int).Lsh(big.NewInt(1), uint(instanceDef.wordBitLen))
func NewModBuiltin(ratio uint64, modBuiltinType ModBuiltinType, wordBitLen uint64, batchSize uint64) *ModBuiltin {
shift := new(big.Int).Lsh(big.NewInt(1), uint(wordBitLen))
shiftPowers := [N_WORDS]big.Int{}
for i := 0; i < N_WORDS; i++ {
shiftPowers[i] = *new(big.Int).Exp(shift, big.NewInt(int64(i)), nil)
shiftPowers[i].Exp(shift, big.NewInt(int64(i)), nil)
}
return &ModBuiltin{
builtinType: builtinType,
base: 0,
stop_ptr: nil,
instanceDef: instanceDef,
included: included,
zeroSegmentIndex: 0,
zeroSegmentSize: max(N_WORDS, instanceDef.batchSize*3),
shift: *shift,
shiftPowers: shiftPowers,
ratio: ratio,
modBuiltinType: modBuiltinType,
wordBitLen: wordBitLen,
batchSize: batchSize,
shift: *shift,
shiftPowers: shiftPowers,
}
}

func (mbr *ModBuiltin) NewAddMod(instanceDef *ModInstanceDef, included bool) *ModBuiltin {
return mbr.New(Add, included, *instanceDef)
func (m *ModBuiltin) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error {
return nil
}

func (mbr *ModBuiltin) NewMulMod(instanceDef *ModInstanceDef, included bool) *ModBuiltin {
return mbr.New(Mul, included, *instanceDef)
func (m *ModBuiltin) InferValue(segment *memory.Segment, offset uint64) error {
return nil
}

func (mbr *ModBuiltin) String() string {
switch mbr.builtinType {
case Add:
return "add_mod_builtin"
case Mul:
return "mul_mod_builtin"
default:
return "unknown"
func (m *ModBuiltin) String() string {
if m.modBuiltinType == Add {
return string(Add) + ModuloName
} else {
return string(Mul) + ModuloName
}
}

func (mbr *ModBuiltin) ReadNWordsValue(memory *memory.Memory, addr memory.MemoryAddress) ([N_WORDS]fp.Element, big.Int, error) {
func (m *ModBuiltin) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) {
return 0, nil
}

func (m *ModBuiltin) readNWordsValue(memory *memory.Memory, addr memory.MemoryAddress) ([N_WORDS]fp.Element, big.Int, error) {
var words [N_WORDS]fp.Element
value := new(big.Int).SetInt64(0)

Expand All @@ -140,12 +105,12 @@ func (mbr *ModBuiltin) ReadNWordsValue(memory *memory.Memory, addr memory.Memory

var word big.Int
wordFelt.BigInt(&word)
if word.Cmp(&mbr.shift) >= 0 {
if word.Cmp(&m.shift) >= 0 {
return words, *value, fmt.Errorf("word exceeds mod builtin word bit length")
}

words[i] = wordFelt
value.Add(value, new(big.Int).Mul(&word, &mbr.shiftPowers[i]))
value.Add(value, new(big.Int).Mul(&word, &m.shiftPowers[i]))
}

return words, *value, nil
Expand Down Expand Up @@ -176,7 +141,7 @@ func (m *ModBuiltin) readInputs(mem *memory.Memory, addr memory.MemoryAddress) (
if n < 1 {
return ModBuiltinInputs{}, fmt.Errorf("moduloBuiltin: n must be at least 1")
}
pValues, p, err := m.ReadNWordsValue(mem, addr)
pValues, p, err := m.readNWordsValue(mem, addr)
if err != nil {
return ModBuiltinInputs{}, err
}
Expand All @@ -189,52 +154,12 @@ func (m *ModBuiltin) readInputs(mem *memory.Memory, addr memory.MemoryAddress) (
}, nil
}

func (mbr *ModBuiltin) ComputeValue(memory memory.Memory, valuesPtr memory.MemoryAddress, offsetsPtr memory.MemoryAddress, indexInBatch uint, index uint) (big.Int, error) {
newOffsetPtr, err := offsetsPtr.AddOffset(int16(index + 3*indexInBatch))
if err != nil {
return big.Int{}, err
}
offset, err := memory.ReadFromAddress(&newOffsetPtr)
if err != nil {
return big.Int{}, err
}
offsetFelt, err := offset.Uint64()
if err != nil {
return big.Int{}, err
}
valueAddr, err := valuesPtr.AddOffset(int16(offsetFelt))
if err != nil {
return big.Int{}, err
}
_, value, err := mbr.ReadNWordsValue(&memory, valueAddr)
if err != nil {
return big.Int{}, err
}
return value, nil
}

func (mbr *ModBuiltin) ReadMemoryVars(memory memory.Memory, valuesPtr memory.MemoryAddress, offsetsPtr memory.MemoryAddress, indexInBatch uint) (big.Int, big.Int, big.Int, error) {
a, err := mbr.ComputeValue(memory, valuesPtr, offsetsPtr, indexInBatch, 0)
if err != nil {
return big.Int{}, big.Int{}, big.Int{}, err
}
b, err := mbr.ComputeValue(memory, valuesPtr, offsetsPtr, indexInBatch, 1)
if err != nil {
return big.Int{}, big.Int{}, big.Int{}, err
}
c, err := mbr.ComputeValue(memory, valuesPtr, offsetsPtr, indexInBatch, 2)
if err != nil {
return big.Int{}, big.Int{}, big.Int{}, err
}
return a, b, c, nil
}

func (m *ModBuiltin) fillInputs(mem *memory.Memory, builtinPtr memory.MemoryAddress, inputs ModBuiltinInputs) error {
if inputs.n > FILL_MEMORY_MAX {
return fmt.Errorf("fill memory max exceeded")
}

nInstances, err := utils.SafeDivUint64(inputs.n, uint64(m.instanceDef.batchSize))
nInstances, err := utils.SafeDivUint64(inputs.n, m.batchSize)
if err != nil {
return err
}
Expand Down Expand Up @@ -269,7 +194,7 @@ func (m *ModBuiltin) fillInputs(mem *memory.Memory, builtinPtr memory.MemoryAddr
if err != nil {
return err
}
newAddr, err := inputs.offsetsPtr.AddOffset(3 * int16(instance) * int16(m.instanceDef.batchSize))
newAddr, err := inputs.offsetsPtr.AddOffset(3 * int16(instance) * int16(m.batchSize))
if err != nil {
return err
}
Expand All @@ -282,7 +207,7 @@ func (m *ModBuiltin) fillInputs(mem *memory.Memory, builtinPtr memory.MemoryAddr
if err != nil {
return err
}
val := fp.NewElement(inputs.n - uint64(m.instanceDef.batchSize)*uint64(instance))
val := fp.NewElement(inputs.n - m.batchSize*uint64(instance))
mv = memory.MemoryValueFromFieldElement(&val)
if err := mem.WriteToAddress(&addr, &mv); err != nil {
return err
Expand Down Expand Up @@ -361,7 +286,7 @@ func (m *ModBuiltin) fillValue(mem *memory.Memory, inputs ModBuiltinInputs, inde
return false, err
}
addresses = append(addresses, addr)
_, value, err := m.ReadNWordsValue(mem, addr)
_, value, err := m.readNWordsValue(mem, addr)
if err != nil {
return false, err
}
Expand All @@ -372,13 +297,13 @@ func (m *ModBuiltin) fillValue(mem *memory.Memory, inputs ModBuiltinInputs, inde

applyOp := func(a, b *big.Int, op Operation) big.Int {
switch op {
case AddOp:
case addOp:
return *new(big.Int).Add(a, b)
case SubOp:
case subOp:
return *new(big.Int).Sub(a, b)
case MulOp:
case mulOp:
return *new(big.Int).Mul(a, b)
case DivModOp:
case divOp:
return *new(big.Int).Div(a, b)
default:
return *new(big.Int)
Expand Down Expand Up @@ -423,43 +348,43 @@ func FillMemory(mem *memory.Memory, addModBuiltinAddr memory.MemoryAddress, nAdd
if ok {
return fmt.Errorf("MulMod builtin segment doesn't exist")
}
addModBuiltin, ok := addModBuiltinSegment.BuiltinRunner.(*ModBuiltin)
addModBuiltinRunner, ok := addModBuiltinSegment.BuiltinRunner.(*ModBuiltin)
if !ok {
return fmt.Errorf("addModBuiltin is not a ModBuiltin")
return fmt.Errorf("addModBuiltinRunner is not a ModBuiltin")
}
mulModBuiltin, ok := mulModBuiltinSegment.BuiltinRunner.(*ModBuiltin)
mulModBuiltinRunner, ok := mulModBuiltinSegment.BuiltinRunner.(*ModBuiltin)
if !ok {
return fmt.Errorf("mulModBuiltin is not a ModBuiltin")
return fmt.Errorf("mulModBuiltinRunner is not a ModBuiltin")
}
if addModBuiltin.instanceDef.wordBitLen != mulModBuiltin.instanceDef.wordBitLen {
if addModBuiltinRunner.wordBitLen != mulModBuiltinRunner.wordBitLen {
return fmt.Errorf("AddMod and MulMod wordBitLen mismatch")
}

addModBuiltinInputs, err := addModBuiltin.readInputs(mem, addModBuiltinAddr)
addModBuiltinInputs, err := addModBuiltinRunner.readInputs(mem, addModBuiltinAddr)
if err != nil {
return err
}
if err := addModBuiltin.fillInputs(mem, addModBuiltinAddr, addModBuiltinInputs); err != nil {
if err := addModBuiltinRunner.fillInputs(mem, addModBuiltinAddr, addModBuiltinInputs); err != nil {
return err
}
if err := addModBuiltin.fillOffsets(mem, addModBuiltinInputs.offsetsPtr, nAddModsIndex, addModBuiltinInputs.n-nAddModsIndex); err != nil {
if err := addModBuiltinRunner.fillOffsets(mem, addModBuiltinInputs.offsetsPtr, nAddModsIndex, addModBuiltinInputs.n-nAddModsIndex); err != nil {
return err
}

mulModBuiltinInputs, err := mulModBuiltin.readInputs(mem, mulModBuiltinAddr)
mulModBuiltinInputs, err := mulModBuiltinRunner.readInputs(mem, mulModBuiltinAddr)
if err != nil {
return err
}
if err := mulModBuiltin.fillInputs(mem, mulModBuiltinAddr, mulModBuiltinInputs); err != nil {
if err := mulModBuiltinRunner.fillInputs(mem, mulModBuiltinAddr, mulModBuiltinInputs); err != nil {
return err
}
if err := mulModBuiltin.fillOffsets(mem, mulModBuiltinInputs.offsetsPtr, nMulModsIndex, mulModBuiltinInputs.n-nMulModsIndex); err != nil {
if err := mulModBuiltinRunner.fillOffsets(mem, mulModBuiltinInputs.offsetsPtr, nMulModsIndex, mulModBuiltinInputs.n-nMulModsIndex); err != nil {
return err
}

addModIndex, mulModIndex := uint64(0), uint64(0)
for addModIndex < nAddModsIndex {
ok, err := addModBuiltin.fillValue(mem, addModBuiltinInputs, int(addModIndex), AddOp, SubOp)
ok, err := addModBuiltinRunner.fillValue(mem, addModBuiltinInputs, int(addModIndex), addOp, subOp)
if err != nil {
return err
}
Expand All @@ -469,7 +394,7 @@ func FillMemory(mem *memory.Memory, addModBuiltinAddr memory.MemoryAddress, nAdd
}

for mulModIndex < nMulModsIndex {
ok, err = mulModBuiltin.fillValue(mem, mulModBuiltinInputs, int(mulModIndex), MulOp, DivModOp)
ok, err = mulModBuiltinRunner.fillValue(mem, mulModBuiltinInputs, int(mulModIndex), mulOp, divOp)
if err != nil {
return err
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/vm/builtins/modulo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package builtins

import (
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
)

func TestModuloBuiltin(t *testing.T) {
mod := &ModBuiltin{ratio: 2048, modBuiltinType: Add}
segment := memory.EmptySegmentWithLength(9)
segment.WithBuiltinRunner(mod)
}

0 comments on commit 23129af

Please sign in to comment.