diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 6c437285..2fbad300 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -10,7 +10,6 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" hintrunner "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/zero" - cairoversion "github.com/NethermindEth/cairo-vm-go/pkg/parsers/cairo_version" "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" zero "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero" "github.com/NethermindEth/cairo-vm-go/pkg/runner" @@ -110,10 +109,6 @@ func main() { if pathToFile == "" { return fmt.Errorf("path to cairo file not set") } - cairoVersion, err := cairoversion.GetCairoVersion(pathToFile) - if err != nil { - return fmt.Errorf("cannot get cairo version: %w", err) - } fmt.Printf("Loading program at %s\n", pathToFile) zeroProgram, err := zero.ZeroProgramFromFile(pathToFile) if err != nil { @@ -121,7 +116,7 @@ func main() { } var hints map[uint64][]hinter.Hinter - if cairoVersion > 0 { + if zeroProgram.CompilerVersion[0] == '1' { cairoProgram, err := starknet.StarknetProgramFromFile(pathToFile) if err != nil { return fmt.Errorf("cannot load program: %w", err) diff --git a/integration_tests/cairo_zero_hint_tests/hintrefs.cairo b/integration_tests/cairo_files_not_run_rust_vm/hintrefs.cairo similarity index 100% rename from integration_tests/cairo_zero_hint_tests/hintrefs.cairo rename to integration_tests/cairo_files_not_run_rust_vm/hintrefs.cairo diff --git a/integration_tests/cairozero_test.go b/integration_tests/cairozero_test.go index 0b0861ef..dd0e1fa6 100644 --- a/integration_tests/cairozero_test.go +++ b/integration_tests/cairozero_test.go @@ -6,7 +6,6 @@ import ( "os" "os/exec" "path/filepath" - "strconv" "strings" "sync" "testing" @@ -53,7 +52,7 @@ func (f *Filter) filtered(testFile string) bool { return false } -func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][2]int, benchmark bool, errorExpected bool) { +func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][3]int, benchmark bool, errorExpected bool) { t.Logf("testing: %s\n", path) compiledOutput, err := compileZeroCode(path) @@ -73,6 +72,17 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str } } + elapsedRs, rsTraceFile, rsMemoryFile, err := runRustVm(name, compiledOutput) + if errorExpected { + // we let the code go on so that we can check if the go vm also raises an error + assert.Error(t, err, path) + } else { + if err != nil { + t.Error(err) + return + } + } + elapsedGo, traceFile, memoryFile, _, err := runVm(compiledOutput) if errorExpected { assert.Error(t, err, path) @@ -85,7 +95,7 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str } if benchmark { - benchmarkMap[name] = [2]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds())} + benchmarkMap[name] = [3]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds()), int(elapsedRs.Milliseconds())} } pyTrace, pyMemory, err := decodeProof(pyTraceFile, pyMemoryFile) @@ -100,6 +110,20 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str return } + rsTrace, rsMemory, err := decodeProof(rsTraceFile, rsMemoryFile) + if err != nil { + t.Error(err) + return + } + + if !assert.Equal(t, pyTrace, rsTrace) { + t.Logf("pytrace:\n%s\n", traceRepr(pyTrace)) + t.Logf("rstrace:\n%s\n", traceRepr(rsTrace)) + } + if !assert.Equal(t, pyMemory, rsMemory) { + t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory)) + t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory)) + } if !assert.Equal(t, pyTrace, trace) { t.Logf("pytrace:\n%s\n", traceRepr(pyTrace)) t.Logf("trace:\n%s\n", traceRepr(trace)) @@ -108,6 +132,14 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory)) t.Logf("memory;\n%s\n", memoryRepr(memory)) } + if !assert.Equal(t, rsTrace, trace) { + t.Logf("rstrace:\n%s\n", traceRepr(rsTrace)) + t.Logf("trace:\n%s\n", traceRepr(trace)) + } + if !assert.Equal(t, rsMemory, memory) { + t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory)) + t.Logf("memory;\n%s\n", memoryRepr(memory)) + } } var zerobench = flag.Bool("zerobench", false, "run integration tests and generate benchmarks file") @@ -123,7 +155,7 @@ func TestCairoZeroFiles(t *testing.T) { filter := Filter{} filter.init() - benchmarkMap := make(map[string][2]int) + benchmarkMap := make(map[string][3]int) sem := make(chan struct{}, 5) // semaphore to limit concurrency var wg sync.WaitGroup // WaitGroup to wait for all goroutines to finish @@ -176,18 +208,17 @@ func TestCairoZeroFiles(t *testing.T) { } } -// Save the Benchmarks for the integration tests in `BenchMarks.txt` -func WriteBenchMarksToFile(benchmarkMap map[string][2]int) { - totalWidth := 123 +func WriteBenchMarksToFile(benchmarkMap map[string][3]int) { + totalWidth := 113 // Reduced width to adjust for long file names border := strings.Repeat("=", totalWidth) separator := strings.Repeat("-", totalWidth) var sb strings.Builder - w := tabwriter.NewWriter(&sb, 40, 0, 0, ' ', tabwriter.Debug) + w := tabwriter.NewWriter(&sb, 0, 0, 1, ' ', tabwriter.AlignRight) sb.WriteString(border + "\n") - fmt.Fprintln(w, "| File \t PythonVM (ms) \t GoVM (ms) \t") + fmt.Fprintf(w, "| %-40s | %-20s | %-20s | %-20s |\n", "File", "PythonVM (ms)", "GoVM (ms)", "RustVM (ms)") w.Flush() sb.WriteString(border + "\n") @@ -195,16 +226,13 @@ func WriteBenchMarksToFile(benchmarkMap map[string][2]int) { totalFiles := len(benchmarkMap) for key, values := range benchmarkMap { - row := "| " + key + "\t " - - for iter, value := range values { - row = row + strconv.Itoa(value) + "\t" - if iter == 0 { - row = row + " " - } + // Adjust the key length if it's too long + displayKey := key + if len(displayKey) > 40 { + displayKey = displayKey[:37] + "..." } - fmt.Fprintln(w, row) + fmt.Fprintf(w, "| %-40s | %-20d | %-20d | %-20d |\n", displayKey, values[0], values[1], values[2]) w.Flush() if iterator < totalFiles-1 { @@ -236,6 +264,8 @@ const ( compiledSuffix = "_compiled.json" pyTraceSuffix = "_py_trace" pyMemorySuffix = "_py_memory" + rsTraceSuffix = "_rs_trace" + rsMemorySuffix = "_rs_memory" traceSuffix = "_trace" memorySuffix = "_memory" ) @@ -323,6 +353,61 @@ func runPythonVm(testFilename, path string) (time.Duration, string, string, erro return elapsed, traceOutput, memoryOutput, nil } +// given a path to a compiled cairo zero file, execute it using the +// rust vm and return the trace and memory files location +func runRustVm(testFilename, path string) (time.Duration, string, string, error) { + traceOutput := swapExtenstion(path, rsTraceSuffix) + memoryOutput := swapExtenstion(path, rsMemorySuffix) + + args := []string{ + path, + "--proof_mode", + "--trace_file", + traceOutput, + "--memory_file", + memoryOutput, + } + + // If any other layouts are needed, add the suffix checks here. + // The convention would be: ".$layout.cairo" + // A file without this suffix will use the default ("plain") layout. + if strings.HasSuffix(testFilename, ".small.cairo") { + args = append(args, "--layout", "small") + } else if strings.HasSuffix(testFilename, ".dex.cairo") { + args = append(args, "--layout", "dex") + } else if strings.HasSuffix(testFilename, ".recursive.cairo") { + args = append(args, "--layout", "recursive") + } else if strings.HasSuffix(testFilename, ".starknet_with_keccak.cairo") { + args = append(args, "--layout", "starknet_with_keccak") + } else if strings.HasSuffix(testFilename, ".starknet.cairo") { + args = append(args, "--layout", "starknet") + } else if strings.HasSuffix(testFilename, ".recursive_large_output.cairo") { + args = append(args, "--layout", "recursive_large_output") + } else if strings.HasSuffix(testFilename, ".recursive_with_poseidon.cairo") { + args = append(args, "--layout", "recursive_with_poseidon") + } else if strings.HasSuffix(testFilename, ".all_solidity.cairo") { + args = append(args, "--layout", "all_solidity") + } else if strings.HasSuffix(testFilename, ".all_cairo.cairo") { + args = append(args, "--layout", "all_cairo") + } + + cmd := exec.Command("./../rust_vm_bin/cairo-vm-cli", args...) + + start := time.Now() + + res, err := cmd.CombinedOutput() + + elapsed := time.Since(start) + + if err != nil { + return 0, "", "", fmt.Errorf( + "./../rust_vm_bin/cairo-vm-cli %s: %w\n%s", path, err, string(res), + ) + } + + return elapsed, traceOutput, memoryOutput, nil +} + // given a path to a compiled cairo zero file, execute // it using our vm func runVm(path string) (time.Duration, string, string, string, error) { diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index b241a77b..269c8984 100755 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -24,7 +24,11 @@ func (hint *GenericZeroHinter) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu } func GetZeroHints(cairoZeroJson *zero.ZeroProgram) (map[uint64][]hinter.Hinter, error) { - hints := make(map[uint64][]hinter.Hinter) + numHints := 0 + for _, rawHints := range cairoZeroJson.Hints { + numHints += len(rawHints) + } + hints := make(map[uint64][]hinter.Hinter, numHints) for counter, rawHints := range cairoZeroJson.Hints { pc, err := strconv.ParseUint(counter, 10, 64) if err != nil { diff --git a/pkg/parsers/cairo_version/cairo_version.go b/pkg/parsers/cairo_version/cairo_version.go deleted file mode 100644 index c17104e3..00000000 --- a/pkg/parsers/cairo_version/cairo_version.go +++ /dev/null @@ -1,30 +0,0 @@ -package cairoversion - -import ( - "encoding/json" - "os" - "strconv" - "strings" -) - -type CairoVersion struct { - Version string `json:"compiler_version"` -} - -func GetCairoVersion(pathToFile string) (uint8, error) { - content, err := os.ReadFile(pathToFile) - if err != nil { - return 0, err - } - cv := CairoVersion{} - err = json.Unmarshal(content, &cv) - if err != nil { - return 0, err - } - firstNumberStr := strings.Split(cv.Version, ".")[0] - firstNumber, err := strconv.ParseUint(firstNumberStr, 10, 8) - if err != nil { - return 0, err - } - return uint8(firstNumber), nil -} diff --git a/pkg/parsers/zero/zero.go b/pkg/parsers/zero/zero.go index 1e2305e8..ceb32017 100644 --- a/pkg/parsers/zero/zero.go +++ b/pkg/parsers/zero/zero.go @@ -72,7 +72,7 @@ type ZeroProgram struct { Data []string `json:"data"` Builtins []builtins.BuiltinType `json:"builtins"` Hints map[string][]Hint `json:"hints"` - CompilerVersion string `json:"version"` + CompilerVersion string `json:"compiler_version"` MainScope string `json:"main_scope"` Identifiers map[string]*Identifier `json:"identifiers"` ReferenceManager ReferenceManager `json:"reference_manager"` @@ -95,6 +95,7 @@ type Identifier struct { Value any `json:"value"` } +// TODO: Do we really need this ? func (z ZeroProgram) MarshalToFile(filepath string) error { // Marshal Output struct into JSON bytes data, err := json.MarshalIndent(z, "", " ") diff --git a/pkg/runner/program.go b/pkg/runner/program.go index c00835c4..adc30a71 100644 --- a/pkg/runner/program.go +++ b/pkg/runner/program.go @@ -35,15 +35,7 @@ func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error) bytecode[i] = felt } - entrypoints, err := extractEntrypoints(cairoZeroJson) - if err != nil { - return nil, err - } - - labels, err := extractLabels(cairoZeroJson) - if err != nil { - return nil, err - } + entrypoints, labels := extractEntrypointsAndLabels(cairoZeroJson) return &ZeroProgram{ Bytecode: bytecode, @@ -53,49 +45,22 @@ func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error) }, nil } -func extractEntrypoints(json *zero.ZeroProgram) (map[string]uint64, error) { - result := make(map[string]uint64) - err := scanIdentifiers( - json, - func(key string, ident *zero.Identifier) error { - if ident.IdentifierType == "function" { - name := key[len(json.MainScope)+1:] - result[name] = uint64(ident.Pc) - } - return nil - }, - ) - - if err != nil { - return nil, fmt.Errorf("extracting entrypoints: %w", err) +func extractEntrypointsAndLabels(json *zero.ZeroProgram) (map[string]uint64, map[string]uint64) { + entrypoints := map[string]uint64{} + for key, ident := range json.Identifiers { + if ident.IdentifierType == "function" { + name := key[len(json.MainScope)+1:] + entrypoints[name] = uint64(ident.Pc) + } } - return result, nil -} -func extractLabels(json *zero.ZeroProgram) (map[string]uint64, error) { labels := make(map[string]uint64, 2) - err := scanIdentifiers( - json, - func(key string, ident *zero.Identifier) error { - if ident.IdentifierType == "label" { - name := key[len(json.MainScope)+1:] - labels[name] = uint64(ident.Pc) - } - return nil - }, - ) - if err != nil { - return nil, fmt.Errorf("extracting labels: %w", err) - } - - return labels, nil -} - -func scanIdentifiers(json *zero.ZeroProgram, f func(key string, ident *zero.Identifier) error) error { for key, ident := range json.Identifiers { - if err := f(key, ident); err != nil { - return err + if ident.IdentifierType == "label" { + name := key[len(json.MainScope)+1:] + labels[name] = uint64(ident.Pc) } } - return nil + + return entrypoints, labels } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index b1846307..4d1c2a64 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -84,7 +84,7 @@ func (runner *ZeroRunner) Run() error { return errors.New("cannot re-run using the same runner") } - end, err := runner.InitializeMainEntrypoint() + end, err := runner.initializeMainEntrypoint() if err != nil { return fmt.Errorf("initializing main entry point: %w", err) } @@ -116,9 +116,7 @@ func (runner *ZeroRunner) initializeSegments() (*mem.Memory, error) { return memory, nil } -// TODO: unexport it. It's only used inside this file and tests so far. -// We probably don't want various init API to leak outside (see #237 for more context). -func (runner *ZeroRunner) InitializeMainEntrypoint() (mem.MemoryAddress, error) { +func (runner *ZeroRunner) initializeMainEntrypoint() (mem.MemoryAddress, error) { memory, err := runner.initializeSegments() if err != nil { return mem.UnknownAddress, err @@ -215,7 +213,8 @@ func (runner *ZeroRunner) initializeVm( ) error { executionSegment := memory.Segments[vm.ExecutionSegment] offset := executionSegment.Len() - for idx := range stack { + stackSize := uint64(len(stack)) + for idx := uint64(0); idx < stackSize; idx++ { if err := executionSegment.Write(offset+uint64(idx), &stack[idx]); err != nil { return err } @@ -225,8 +224,8 @@ func (runner *ZeroRunner) initializeVm( // initialize vm runner.vm, err = vm.NewVirtualMachine(vm.Context{ Pc: *initialPC, - Ap: offset + uint64(len(stack)), - Fp: offset + uint64(len(stack)), + Ap: offset + stackSize, + Fp: offset + stackSize, }, memory, vm.VirtualMachineConfig{ProofMode: runner.proofmode, CollectTrace: runner.collectTrace}) return err } diff --git a/pkg/runner/runner_test.go b/pkg/runner/runner_test.go index 9a7274b1..abad43e2 100644 --- a/pkg/runner/runner_test.go +++ b/pkg/runner/runner_test.go @@ -30,7 +30,7 @@ func TestSimpleProgram(t *testing.T) { runner, err := NewRunner(program, hints, false, false, math.MaxUint64, "plain") require.NoError(t, err) - endPc, err := runner.InitializeMainEntrypoint() + endPc, err := runner.initializeMainEntrypoint() require.NoError(t, err) expectedPc := memory.MemoryAddress{SegmentIndex: 3, Offset: 0} @@ -77,7 +77,7 @@ func TestStepLimitExceeded(t *testing.T) { runner, err := NewRunner(program, hints, false, false, 3, "plain") require.NoError(t, err) - endPc, err := runner.InitializeMainEntrypoint() + endPc, err := runner.initializeMainEntrypoint() require.NoError(t, err) expectedPc := memory.MemoryAddress{SegmentIndex: 3, Offset: 0} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 8222b354..9c9dd19d 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -5,7 +5,7 @@ import ( "fmt" "math" - a "github.com/NethermindEth/cairo-vm-go/pkg/assembler" + asmb "github.com/NethermindEth/cairo-vm-go/pkg/assembler" "github.com/NethermindEth/cairo-vm-go/pkg/utils" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -85,7 +85,7 @@ type VirtualMachine struct { Trace []Context config VirtualMachineConfig // instructions cache - instructions map[uint64]*a.Instruction + instructions map[uint64]*asmb.Instruction // RcLimitsMin and RcLimitsMax define the range of values of instructions offsets, used for checking the number of potential range checks holes RcLimitsMin uint16 RcLimitsMax uint16 @@ -99,7 +99,7 @@ func NewVirtualMachine( // Initialize the trace if necesary var trace []Context if config.ProofMode || config.CollectTrace { - trace = make([]Context, 0) + trace = []Context{} } return &VirtualMachine{ @@ -107,7 +107,7 @@ func NewVirtualMachine( Memory: memory, Trace: trace, config: config, - instructions: make(map[uint64]*a.Instruction), + instructions: make(map[uint64]*asmb.Instruction), RcLimitsMin: math.MaxUint16, RcLimitsMax: 0, }, nil @@ -133,7 +133,7 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { return fmt.Errorf("reading instruction: %w", err) } - instruction, err = a.DecodeInstruction(bytecodeInstruction) + instruction, err = asmb.DecodeInstruction(bytecodeInstruction) if err != nil { return fmt.Errorf("decoding instruction: %w", err) } @@ -156,7 +156,7 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { const RC_OFFSET_BITS = 16 -func (vm *VirtualMachine) RunInstruction(instruction *a.Instruction) error { +func (vm *VirtualMachine) RunInstruction(instruction *asmb.Instruction) error { var off0 int = int(instruction.OffDest) + (1 << (RC_OFFSET_BITS - 1)) var off1 int = int(instruction.OffOp0) + (1 << (RC_OFFSET_BITS - 1)) @@ -219,9 +219,9 @@ func (vm *VirtualMachine) RunInstruction(instruction *a.Instruction) error { return nil } -func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddress, error) { +func (vm *VirtualMachine) getDstAddr(instruction *asmb.Instruction) (mem.MemoryAddress, error) { var dstRegister uint64 - if instruction.DstRegister == a.Ap { + if instruction.DstRegister == asmb.Ap { dstRegister = vm.Context.Ap } else { dstRegister = vm.Context.Fp @@ -234,9 +234,9 @@ func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddr return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: addr}, nil } -func (vm *VirtualMachine) getOp0Addr(instruction *a.Instruction) (mem.MemoryAddress, error) { +func (vm *VirtualMachine) getOp0Addr(instruction *asmb.Instruction) (mem.MemoryAddress, error) { var op0Register uint64 - if instruction.Op0Register == a.Ap { + if instruction.Op0Register == asmb.Ap { op0Register = vm.Context.Ap } else { op0Register = vm.Context.Fp @@ -250,10 +250,10 @@ func (vm *VirtualMachine) getOp0Addr(instruction *a.Instruction) (mem.MemoryAddr return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: addr}, nil } -func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.MemoryAddress) (mem.MemoryAddress, error) { +func (vm *VirtualMachine) getOp1Addr(instruction *asmb.Instruction, op0Addr *mem.MemoryAddress) (mem.MemoryAddress, error) { var op1Address mem.MemoryAddress switch instruction.Op1Source { - case a.Op0: + case asmb.Op0: // in this case Op0 is being used as an address, and must be of unwrapped as it op0Value, err := vm.Memory.ReadFromAddress(op0Addr) if err != nil { @@ -265,11 +265,11 @@ func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.Me return mem.UnknownAddress, fmt.Errorf("op0 is not an address: %w", err) } op1Address = mem.MemoryAddress{SegmentIndex: op0Address.SegmentIndex, Offset: op0Address.Offset} - case a.Imm: + case asmb.Imm: op1Address = vm.Context.AddressPc() - case a.FpPlusOffOp1: + case asmb.FpPlusOffOp1: op1Address = vm.Context.AddressFp() - case a.ApPlusOffOp1: + case asmb.ApPlusOffOp1: op1Address = vm.Context.AddressAp() } @@ -283,13 +283,13 @@ func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.Me // when there is an assertion with a substraction or division like : x = y - z // the compiler treats it as y = x + z. This means that the VM knows the -// dstCell value and either op0Cell xor op1Cell. This function infers the +// dstCell value and either op0Cell or op1Cell. This function infers the // unknow operand as well as the `res` auxiliar value func (vm *VirtualMachine) inferOperand( - instruction *a.Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, + instruction *asmb.Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, ) (mem.MemoryValue, error) { - if instruction.Opcode != a.OpCodeAssertEq || - instruction.Res == a.Unconstrained || + if instruction.Opcode != asmb.OpCodeAssertEq || + instruction.Res == asmb.Unconstrained || !vm.Memory.KnownValueAtAddress(dstAddr) { return mem.MemoryValue{}, nil } @@ -310,7 +310,7 @@ func (vm *VirtualMachine) inferOperand( return mem.MemoryValue{}, nil } - if instruction.Res == a.Op1 && !op1Value.Known() { + if instruction.Res == asmb.Op1 && !op1Value.Known() { if err = vm.Memory.WriteToAddress(op1Addr, &dstValue); err != nil { return mem.MemoryValue{}, err } @@ -328,7 +328,7 @@ func (vm *VirtualMachine) inferOperand( } var missingVal mem.MemoryValue - if instruction.Res == a.AddOperands { + if instruction.Res == asmb.AddOperands { missingVal = mem.EmptyMemoryValueAs(dstValue.IsAddress()) err = missingVal.Sub(&dstValue, &knownOpValue) } else { @@ -346,12 +346,12 @@ func (vm *VirtualMachine) inferOperand( } func (vm *VirtualMachine) computeRes( - instruction *a.Instruction, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, + instruction *asmb.Instruction, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, ) (mem.MemoryValue, error) { switch instruction.Res { - case a.Unconstrained: + case asmb.Unconstrained: return mem.MemoryValue{}, nil - case a.Op1: + case asmb.Op1: op1, err := vm.Memory.ReadFromAddress(op1Addr) if err != nil { return mem.UnknownValue, fmt.Errorf("cannot read op1: %w", err) @@ -370,9 +370,9 @@ func (vm *VirtualMachine) computeRes( } res := mem.EmptyMemoryValueAs(op0.IsAddress() || op1.IsAddress()) - if instruction.Res == a.AddOperands { + if instruction.Res == asmb.AddOperands { err = res.Add(&op0, &op1) - } else if instruction.Res == a.MulOperands { + } else if instruction.Res == asmb.MulOperands { err = res.Mul(&op0, &op1) } else { return mem.MemoryValue{}, fmt.Errorf("invalid res flag value: %d", instruction.Res) @@ -382,13 +382,13 @@ func (vm *VirtualMachine) computeRes( } func (vm *VirtualMachine) opcodeAssertions( - instruction *a.Instruction, + instruction *asmb.Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, res *mem.MemoryValue, ) error { switch instruction.Opcode { - case a.OpCodeCall: + case asmb.OpCodeCall: fpAddr := vm.Context.AddressFp() fpMv := mem.MemoryValueFromMemoryAddress(&fpAddr) // Store at [ap] the current fp @@ -404,7 +404,7 @@ func (vm *VirtualMachine) opcodeAssertions( if err := vm.Memory.WriteToAddress(op0Addr, &apMv); err != nil { return err } - case a.OpCodeAssertEq: + case asmb.OpCodeAssertEq: // assert that the calculated res is stored in dst if err := vm.Memory.WriteToAddress(dstAddr, res); err != nil { return err @@ -414,18 +414,18 @@ func (vm *VirtualMachine) opcodeAssertions( } func (vm *VirtualMachine) updatePc( - instruction *a.Instruction, + instruction *asmb.Instruction, dstAddr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, res *mem.MemoryValue, ) (mem.MemoryAddress, error) { switch instruction.PcUpdate { - case a.PcUpdateNextInstr: + case asmb.PcUpdateNextInstr: return mem.MemoryAddress{ SegmentIndex: vm.Context.Pc.SegmentIndex, Offset: vm.Context.Pc.Offset + uint64(instruction.Size()), }, nil - case a.PcUpdateJump: + case asmb.PcUpdateJump: // both address and felt are allowed here. It can be a felt when used // with an immediate or a memory address holding a felt. It can be an address // when a memory address holds a memory address @@ -441,7 +441,7 @@ func (vm *VirtualMachine) updatePc( fmt.Errorf("absolute jump: invalid jump location: %w", err) } - case a.PcUpdateJumpRel: + case asmb.PcUpdateJumpRel: val, err := res.FieldElement() if err != nil { return mem.UnknownAddress, fmt.Errorf("relative jump: %w", err) @@ -449,7 +449,7 @@ func (vm *VirtualMachine) updatePc( newPc := vm.Context.Pc err = newPc.Add(&newPc, val) return newPc, err - case a.PcUpdateJnz: + case asmb.PcUpdateJnz: destMv, err := vm.Memory.ReadFromAddress(dstAddr) if err != nil { return mem.UnknownAddress, err @@ -484,11 +484,11 @@ func (vm *VirtualMachine) updatePc( return mem.UnknownAddress, fmt.Errorf("unkwon pc update value: %d", instruction.PcUpdate) } -func (vm *VirtualMachine) updateAp(instruction *a.Instruction, res *mem.MemoryValue) (uint64, error) { +func (vm *VirtualMachine) updateAp(instruction *asmb.Instruction, res *mem.MemoryValue) (uint64, error) { switch instruction.ApUpdate { - case a.SameAp: + case asmb.SameAp: return vm.Context.Ap, nil - case a.AddRes: + case asmb.AddRes: apFelt := new(f.Element).SetUint64(vm.Context.Ap) // Convert ap value to felt resFelt, err := res.FieldElement() // Extract the f.Element from MemoryValue @@ -501,20 +501,20 @@ func (vm *VirtualMachine) updateAp(instruction *a.Instruction, res *mem.MemoryVa return 0, fmt.Errorf("resulting AP value is too large to fit in uint64") } return newAp.Uint64(), nil // Return the addition as uint64 - case a.Add1: + case asmb.Add1: return vm.Context.Ap + 1, nil - case a.Add2: + case asmb.Add2: return vm.Context.Ap + 2, nil } return 0, fmt.Errorf("cannot update ap, unknown ApUpdate flag: %d", instruction.ApUpdate) } -func (vm *VirtualMachine) updateFp(instruction *a.Instruction, dstAddr *mem.MemoryAddress) (uint64, error) { +func (vm *VirtualMachine) updateFp(instruction *asmb.Instruction, dstAddr *mem.MemoryAddress) (uint64, error) { switch instruction.Opcode { - case a.OpCodeCall: + case asmb.OpCodeCall: // [ap] and [ap + 1] are written to memory return vm.Context.Ap + 2, nil - case a.OpCodeRet: + case asmb.OpCodeRet: // [dst] should be a memory address of the form (executionSegment, fp - 2) destMv, err := vm.Memory.ReadFromAddress(dstAddr) if err != nil { diff --git a/rust_vm_bin/cairo-vm-cli b/rust_vm_bin/cairo-vm-cli new file mode 100755 index 00000000..afa56ea9 Binary files /dev/null and b/rust_vm_bin/cairo-vm-cli differ