Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and a few optimizations #671

Merged
merged 4 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
hintrunner "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/zero"
cairoversion "github.com/NethermindEth/cairo-vm-go/pkg/parsers/cairo_version"
"github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet"
zero "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero"
"github.com/NethermindEth/cairo-vm-go/pkg/runner"
Expand Down Expand Up @@ -110,18 +109,14 @@ func main() {
if pathToFile == "" {
return fmt.Errorf("path to cairo file not set")
}
cairoVersion, err := cairoversion.GetCairoVersion(pathToFile)
if err != nil {
return fmt.Errorf("cannot get cairo version: %w", err)
}
fmt.Printf("Loading program at %s\n", pathToFile)
zeroProgram, err := zero.ZeroProgramFromFile(pathToFile)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
}

var hints map[uint64][]hinter.Hinter
if cairoVersion > 0 {
if zeroProgram.CompilerVersion[0] == '1' {
cairoProgram, err := starknet.StarknetProgramFromFile(pathToFile)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
Expand Down
Sh0g0-1758 marked this conversation as resolved.
Show resolved Hide resolved
File renamed without changes.
119 changes: 102 additions & 17 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -53,7 +52,7 @@ func (f *Filter) filtered(testFile string) bool {
return false
}

func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][2]int, benchmark bool, errorExpected bool) {
func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][3]int, benchmark bool, errorExpected bool) {
t.Logf("testing: %s\n", path)

compiledOutput, err := compileZeroCode(path)
Expand All @@ -73,6 +72,17 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
}
}

elapsedRs, rsTraceFile, rsMemoryFile, err := runRustVm(name, compiledOutput)
if errorExpected {
// we let the code go on so that we can check if the go vm also raises an error
assert.Error(t, err, path)
} else {
if err != nil {
t.Error(err)
return
}
}

elapsedGo, traceFile, memoryFile, _, err := runVm(compiledOutput)
if errorExpected {
assert.Error(t, err, path)
Expand All @@ -85,7 +95,7 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
}

if benchmark {
benchmarkMap[name] = [2]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds())}
benchmarkMap[name] = [3]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds()), int(elapsedRs.Milliseconds())}
}

pyTrace, pyMemory, err := decodeProof(pyTraceFile, pyMemoryFile)
Expand All @@ -100,6 +110,20 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
return
}

rsTrace, rsMemory, err := decodeProof(rsTraceFile, rsMemoryFile)
if err != nil {
t.Error(err)
return
}

if !assert.Equal(t, pyTrace, rsTrace) {
t.Logf("pytrace:\n%s\n", traceRepr(pyTrace))
t.Logf("rstrace:\n%s\n", traceRepr(rsTrace))
}
if !assert.Equal(t, pyMemory, rsMemory) {
t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory))
t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory))
}
Sh0g0-1758 marked this conversation as resolved.
Show resolved Hide resolved
if !assert.Equal(t, pyTrace, trace) {
t.Logf("pytrace:\n%s\n", traceRepr(pyTrace))
t.Logf("trace:\n%s\n", traceRepr(trace))
Expand All @@ -108,6 +132,14 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory))
t.Logf("memory;\n%s\n", memoryRepr(memory))
}
if !assert.Equal(t, rsTrace, trace) {
t.Logf("rstrace:\n%s\n", traceRepr(rsTrace))
t.Logf("trace:\n%s\n", traceRepr(trace))
}
if !assert.Equal(t, rsMemory, memory) {
t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory))
t.Logf("memory;\n%s\n", memoryRepr(memory))
}
}

var zerobench = flag.Bool("zerobench", false, "run integration tests and generate benchmarks file")
Expand All @@ -123,7 +155,7 @@ func TestCairoZeroFiles(t *testing.T) {
filter := Filter{}
filter.init()

benchmarkMap := make(map[string][2]int)
benchmarkMap := make(map[string][3]int)

sem := make(chan struct{}, 5) // semaphore to limit concurrency
var wg sync.WaitGroup // WaitGroup to wait for all goroutines to finish
Expand Down Expand Up @@ -176,35 +208,31 @@ func TestCairoZeroFiles(t *testing.T) {
}
}

// Save the Benchmarks for the integration tests in `BenchMarks.txt`
func WriteBenchMarksToFile(benchmarkMap map[string][2]int) {
totalWidth := 123
func WriteBenchMarksToFile(benchmarkMap map[string][3]int) {
totalWidth := 113 // Reduced width to adjust for long file names

border := strings.Repeat("=", totalWidth)
separator := strings.Repeat("-", totalWidth)

var sb strings.Builder
w := tabwriter.NewWriter(&sb, 40, 0, 0, ' ', tabwriter.Debug)
w := tabwriter.NewWriter(&sb, 0, 0, 1, ' ', tabwriter.AlignRight)

sb.WriteString(border + "\n")
fmt.Fprintln(w, "| File \t PythonVM (ms) \t GoVM (ms) \t")
fmt.Fprintf(w, "| %-40s | %-20s | %-20s | %-20s |\n", "File", "PythonVM (ms)", "GoVM (ms)", "RustVM (ms)")
w.Flush()
sb.WriteString(border + "\n")

iterator := 0
totalFiles := len(benchmarkMap)

for key, values := range benchmarkMap {
row := "| " + key + "\t "

for iter, value := range values {
row = row + strconv.Itoa(value) + "\t"
if iter == 0 {
row = row + " "
}
// Adjust the key length if it's too long
displayKey := key
if len(displayKey) > 40 {
displayKey = displayKey[:37] + "..."
}

fmt.Fprintln(w, row)
fmt.Fprintf(w, "| %-40s | %-20d | %-20d | %-20d |\n", displayKey, values[0], values[1], values[2])
w.Flush()

if iterator < totalFiles-1 {
Expand Down Expand Up @@ -236,6 +264,8 @@ const (
compiledSuffix = "_compiled.json"
pyTraceSuffix = "_py_trace"
pyMemorySuffix = "_py_memory"
rsTraceSuffix = "_rs_trace"
rsMemorySuffix = "_rs_memory"
traceSuffix = "_trace"
memorySuffix = "_memory"
)
Expand Down Expand Up @@ -323,6 +353,61 @@ func runPythonVm(testFilename, path string) (time.Duration, string, string, erro
return elapsed, traceOutput, memoryOutput, nil
}

// given a path to a compiled cairo zero file, execute it using the
// rust vm and return the trace and memory files location
func runRustVm(testFilename, path string) (time.Duration, string, string, error) {
traceOutput := swapExtenstion(path, rsTraceSuffix)
memoryOutput := swapExtenstion(path, rsMemorySuffix)

args := []string{
path,
"--proof_mode",
"--trace_file",
traceOutput,
"--memory_file",
memoryOutput,
}

// If any other layouts are needed, add the suffix checks here.
// The convention would be: ".$layout.cairo"
// A file without this suffix will use the default ("plain") layout.
if strings.HasSuffix(testFilename, ".small.cairo") {
args = append(args, "--layout", "small")
} else if strings.HasSuffix(testFilename, ".dex.cairo") {
args = append(args, "--layout", "dex")
} else if strings.HasSuffix(testFilename, ".recursive.cairo") {
args = append(args, "--layout", "recursive")
} else if strings.HasSuffix(testFilename, ".starknet_with_keccak.cairo") {
args = append(args, "--layout", "starknet_with_keccak")
} else if strings.HasSuffix(testFilename, ".starknet.cairo") {
args = append(args, "--layout", "starknet")
} else if strings.HasSuffix(testFilename, ".recursive_large_output.cairo") {
args = append(args, "--layout", "recursive_large_output")
} else if strings.HasSuffix(testFilename, ".recursive_with_poseidon.cairo") {
args = append(args, "--layout", "recursive_with_poseidon")
} else if strings.HasSuffix(testFilename, ".all_solidity.cairo") {
args = append(args, "--layout", "all_solidity")
} else if strings.HasSuffix(testFilename, ".all_cairo.cairo") {
args = append(args, "--layout", "all_cairo")
}

cmd := exec.Command("./../rust_vm_bin/cairo-vm-cli", args...)

start := time.Now()

res, err := cmd.CombinedOutput()

elapsed := time.Since(start)

if err != nil {
return 0, "", "", fmt.Errorf(
"./../rust_vm_bin/cairo-vm-cli %s: %w\n%s", path, err, string(res),
)
}

return elapsed, traceOutput, memoryOutput, nil
}

// given a path to a compiled cairo zero file, execute
// it using our vm
func runVm(path string) (time.Duration, string, string, string, error) {
Expand Down
6 changes: 5 additions & 1 deletion pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ func (hint *GenericZeroHinter) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu
}

func GetZeroHints(cairoZeroJson *zero.ZeroProgram) (map[uint64][]hinter.Hinter, error) {
hints := make(map[uint64][]hinter.Hinter)
numHints := 0
for _, rawHints := range cairoZeroJson.Hints {
numHints += len(rawHints)
}
hints := make(map[uint64][]hinter.Hinter, numHints)
Sh0g0-1758 marked this conversation as resolved.
Show resolved Hide resolved
for counter, rawHints := range cairoZeroJson.Hints {
pc, err := strconv.ParseUint(counter, 10, 64)
if err != nil {
Expand Down
30 changes: 0 additions & 30 deletions pkg/parsers/cairo_version/cairo_version.go

This file was deleted.

3 changes: 2 additions & 1 deletion pkg/parsers/zero/zero.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type ZeroProgram struct {
Data []string `json:"data"`
Builtins []builtins.BuiltinType `json:"builtins"`
Hints map[string][]Hint `json:"hints"`
CompilerVersion string `json:"version"`
CompilerVersion string `json:"compiler_version"`
MainScope string `json:"main_scope"`
Identifiers map[string]*Identifier `json:"identifiers"`
ReferenceManager ReferenceManager `json:"reference_manager"`
Expand All @@ -95,6 +95,7 @@ type Identifier struct {
Value any `json:"value"`
}

// TODO: Do we really need this ?
func (z ZeroProgram) MarshalToFile(filepath string) error {
// Marshal Output struct into JSON bytes
data, err := json.MarshalIndent(z, "", " ")
Expand Down
61 changes: 13 additions & 48 deletions pkg/runner/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,7 @@ func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error)
bytecode[i] = felt
}

entrypoints, err := extractEntrypoints(cairoZeroJson)
if err != nil {
return nil, err
}

labels, err := extractLabels(cairoZeroJson)
if err != nil {
return nil, err
}
entrypoints, labels := extractEntrypointsAndLabels(cairoZeroJson)

return &ZeroProgram{
Bytecode: bytecode,
Expand All @@ -53,49 +45,22 @@ func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error)
}, nil
}

func extractEntrypoints(json *zero.ZeroProgram) (map[string]uint64, error) {
result := make(map[string]uint64)
err := scanIdentifiers(
json,
func(key string, ident *zero.Identifier) error {
if ident.IdentifierType == "function" {
name := key[len(json.MainScope)+1:]
result[name] = uint64(ident.Pc)
}
return nil
},
)

if err != nil {
return nil, fmt.Errorf("extracting entrypoints: %w", err)
func extractEntrypointsAndLabels(json *zero.ZeroProgram) (map[string]uint64, map[string]uint64) {
entrypoints := map[string]uint64{}
for key, ident := range json.Identifiers {
if ident.IdentifierType == "function" {
name := key[len(json.MainScope)+1:]
entrypoints[name] = uint64(ident.Pc)
}
}
return result, nil
}

func extractLabels(json *zero.ZeroProgram) (map[string]uint64, error) {
labels := make(map[string]uint64, 2)
err := scanIdentifiers(
json,
func(key string, ident *zero.Identifier) error {
if ident.IdentifierType == "label" {
name := key[len(json.MainScope)+1:]
labels[name] = uint64(ident.Pc)
}
return nil
},
)
if err != nil {
return nil, fmt.Errorf("extracting labels: %w", err)
}

return labels, nil
}

func scanIdentifiers(json *zero.ZeroProgram, f func(key string, ident *zero.Identifier) error) error {
for key, ident := range json.Identifiers {
if err := f(key, ident); err != nil {
return err
if ident.IdentifierType == "label" {
name := key[len(json.MainScope)+1:]
labels[name] = uint64(ident.Pc)
}
}
return nil

return entrypoints, labels
}
Loading
Loading