From 0f2491eeb6ed240d162cb27c839f1b12c4a2a4a9 Mon Sep 17 00:00:00 2001 From: Shourya Goel Date: Wed, 21 Aug 2024 01:30:58 +0530 Subject: [PATCH] Implement `Rangecheck96` Builtin (#640) * Added functionality * Added support for starkent * Added Functionality * nit * pass integration tests * nit --- pkg/hintrunner/core/hint_benchmark_test.go | 2 +- pkg/hintrunner/core/hint_test.go | 2 +- pkg/parsers/starknet/starknet.go | 5 +++ pkg/runners/zero/zero.go | 2 +- pkg/vm/builtins/builtin_runner.go | 6 ++-- pkg/vm/builtins/ecdsa.go | 2 +- pkg/vm/builtins/layouts.go | 4 +-- pkg/vm/builtins/range_check.go | 34 +++++++++++++++----- pkg/vm/builtins/range_check_test.go | 36 +++++++++++++++++++--- 9 files changed, 73 insertions(+), 20 deletions(-) diff --git a/pkg/hintrunner/core/hint_benchmark_test.go b/pkg/hintrunner/core/hint_benchmark_test.go index c70fccd90..e121919fe 100644 --- a/pkg/hintrunner/core/hint_benchmark_test.go +++ b/pkg/hintrunner/core/hint_benchmark_test.go @@ -368,7 +368,7 @@ func BenchmarkAssertLeFindSmallArc(b *testing.B) { rand := utils.DefaultRandGenerator() ctx := hinter.SetContextWithScope(map[string]any{"excluded": 0}) - rangeCheckPtr := vm.Memory.AllocateBuiltinSegment(&builtins.RangeCheck{}) + rangeCheckPtr := vm.Memory.AllocateBuiltinSegment(&builtins.RangeCheck{RangeCheckNParts: 8}) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/pkg/hintrunner/core/hint_test.go b/pkg/hintrunner/core/hint_test.go index 0fce1b0ef..9f0c2dbe1 100644 --- a/pkg/hintrunner/core/hint_test.go +++ b/pkg/hintrunner/core/hint_test.go @@ -1115,7 +1115,7 @@ func TestAssertLeFindSmallArc(t *testing.T) { vm.Context.Ap = 0 vm.Context.Fp = 0 // The addr that the range check pointer will point to - addr := vm.Memory.AllocateBuiltinSegment(&builtins.RangeCheck{}) + addr := vm.Memory.AllocateBuiltinSegment(&builtins.RangeCheck{RangeCheckNParts: 8}) utils.WriteTo(vm, VM.ExecutionSegment, vm.Context.Ap, mem.MemoryValueFromMemoryAddress(&addr)) hint := AssertLeFindSmallArc{ diff --git a/pkg/parsers/starknet/starknet.go b/pkg/parsers/starknet/starknet.go index 2423effc8..13868ae71 100644 --- a/pkg/parsers/starknet/starknet.go +++ b/pkg/parsers/starknet/starknet.go @@ -21,6 +21,7 @@ const ( ECOP Poseidon SegmentArena + RangeCheck96 ) func (b Builtin) MarshalJSON() ([]byte, error) { @@ -29,6 +30,8 @@ func (b Builtin) MarshalJSON() ([]byte, error) { return []byte("output"), nil case RangeCheck: return []byte("range_check"), nil + case RangeCheck96: + return []byte("range_check96"), nil case Pedersen: return []byte("pedersen"), nil case ECDSA: @@ -59,6 +62,8 @@ func (b *Builtin) UnmarshalJSON(data []byte) error { *b = Output case "range_check": *b = RangeCheck + case "range_check96": + *b = RangeCheck96 case "pedersen": *b = Pedersen case "ecdsa": diff --git a/pkg/runners/zero/zero.go b/pkg/runners/zero/zero.go index 9980524da..b3e8c161a 100644 --- a/pkg/runners/zero/zero.go +++ b/pkg/runners/zero/zero.go @@ -319,7 +319,7 @@ func (runner *ZeroRunner) checkRangeCheckUsage() error { } rangeCheckSegment, ok := runner.vm.Memory.FindSegmentWithBuiltin(rangeCheckRunner.String()) if ok { - rcUnitsUsedByBuiltins += rangeCheckSegment.Len() * builtins.RANGE_CHECK_N_PARTS + rcUnitsUsedByBuiltins += rangeCheckSegment.Len() * rangeCheckRunner.RangeCheckNParts } } } diff --git a/pkg/vm/builtins/builtin_runner.go b/pkg/vm/builtins/builtin_runner.go index 7e13d7bb2..77b74fa00 100644 --- a/pkg/vm/builtins/builtin_runner.go +++ b/pkg/vm/builtins/builtin_runner.go @@ -14,7 +14,9 @@ func Runner(name starknetParser.Builtin) memory.BuiltinRunner { case starknetParser.Output: return &Output{} case starknetParser.RangeCheck: - return &RangeCheck{} + return &RangeCheck{0, 8} + case starknetParser.RangeCheck96: + return &RangeCheck{0, 6} case starknetParser.Pedersen: return &Pedersen{} case starknetParser.ECDSA: @@ -51,7 +53,7 @@ func GetBuiltinAllocatedInstances(ratio uint64, cellsPerInstance uint64, segment } minSteps := ratio * instancesPerComponent if vmCurrentStep < minSteps { - return 0, fmt.Errorf("Number of steps must be at least %d. Current step: %d", minSteps, vmCurrentStep) + return 0, fmt.Errorf("number of steps must be at least %d. Current step: %d", minSteps, vmCurrentStep) } return vmCurrentStep / ratio, nil } diff --git a/pkg/vm/builtins/ecdsa.go b/pkg/vm/builtins/ecdsa.go index 9b11e6f2c..02fe7f3e5 100644 --- a/pkg/vm/builtins/ecdsa.go +++ b/pkg/vm/builtins/ecdsa.go @@ -156,7 +156,7 @@ func recoverY(x *fp.Element) (fp.Element, fp.Element, error) { x2.Add(x2, &utils.Beta) y := x2.Sqrt(x2) if y == nil { - return fp.Element{}, fp.Element{}, fmt.Errorf("Invalid Public key") + return fp.Element{}, fp.Element{}, fmt.Errorf("invalid Public key") } negY := fp.Element{} negY.Neg(y) diff --git a/pkg/vm/builtins/layouts.go b/pkg/vm/builtins/layouts.go index 32bc3bcbd..2a464ff42 100644 --- a/pkg/vm/builtins/layouts.go +++ b/pkg/vm/builtins/layouts.go @@ -28,7 +28,7 @@ func getSmallLayout() Layout { return Layout{Name: "small", RcUnits: 16, Builtins: []LayoutBuiltin{ {Runner: &Output{}, Builtin: starknet.Output}, {Runner: &Pedersen{ratio: 8}, Builtin: starknet.Pedersen}, - {Runner: &RangeCheck{ratio: 8}, Builtin: starknet.RangeCheck}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: starknet.RangeCheck}, {Runner: &ECDSA{ratio: 512}, Builtin: starknet.ECDSA}, }} } @@ -41,7 +41,7 @@ func getStarknetWithKeccakLayout() Layout { return Layout{Name: "starknet_with_keccak", RcUnits: 4, Builtins: []LayoutBuiltin{ {Runner: &Output{}, Builtin: starknet.Output}, {Runner: &Pedersen{ratio: 32}, Builtin: starknet.Pedersen}, - {Runner: &RangeCheck{ratio: 16}, Builtin: starknet.RangeCheck}, + {Runner: &RangeCheck{ratio: 16, RangeCheckNParts: 8}, Builtin: starknet.RangeCheck}, {Runner: &ECDSA{ratio: 2048}, Builtin: starknet.ECDSA}, {Runner: &Bitwise{ratio: 64}, Builtin: starknet.Bitwise}, {Runner: &EcOp{ratio: 1024, cache: make(map[uint64]fp.Element)}, Builtin: starknet.ECOP}, diff --git a/pkg/vm/builtins/range_check.go b/pkg/vm/builtins/range_check.go index e968cfa79..7670bef41 100644 --- a/pkg/vm/builtins/range_check.go +++ b/pkg/vm/builtins/range_check.go @@ -7,20 +7,20 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -const RangeCheckName = "range_check" const inputCellsPerRangeCheck = 1 const cellsPerRangeCheck = 1 const instancesPerComponentRangeCheck = 1 -// Each range check instance consists of RANGE_CHECK_N_PARTS 16-bit parts. INNER_RC_BOUND_SHIFT and INNER_RC_BOUND_MASK are used to extract 16-bit parts from the field elements stored in the range check segment. +// Each range check instance consists of RangeCheckNParts 16-bit parts. INNER_RC_BOUND_SHIFT and INNER_RC_BOUND_MASK are used to extract 16-bit parts from the field elements stored in the range check segment. const INNER_RC_BOUND_SHIFT = 16 const INNER_RC_BOUND_MASK = (1 << 16) - 1 -const RANGE_CHECK_N_PARTS = 8 type RangeCheck struct { - ratio uint64 + ratio uint64 + RangeCheckNParts uint64 } func (r *RangeCheck) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error { @@ -29,10 +29,24 @@ func (r *RangeCheck) CheckWrite(segment *memory.Segment, offset uint64, value *m return fmt.Errorf("check write: %w", err) } - // felt >= (2^128) - if felt.Cmp(&utils.FeltMax128) != -1 { - return fmt.Errorf("check write: 2**128 < %s", value) + if r.RangeCheckNParts == 6 { + // 2**96 + BOUND_96, err := new(fp.Element).SetString("79228162514264337593543950336") + if err != nil { + return fmt.Errorf("check write: %w", err) + } + + // felt >= (2^96) + if felt.Cmp(BOUND_96) != -1 { + return fmt.Errorf("check write: 2**96 < %s", value) + } + } else { + // felt >= (2^128) + if felt.Cmp(&utils.FeltMax128) != -1 { + return fmt.Errorf("check write: 2**128 < %s", value) + } } + return nil } @@ -41,7 +55,11 @@ func (r *RangeCheck) InferValue(segment *memory.Segment, offset uint64) error { } func (r *RangeCheck) String() string { - return RangeCheckName + if r.RangeCheckNParts == 6 { + return "range_check96" + } else { + return "range_check" + } } func (r *RangeCheck) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) { diff --git a/pkg/vm/builtins/range_check_test.go b/pkg/vm/builtins/range_check_test.go index 57b2244b1..3bc4bb027 100644 --- a/pkg/vm/builtins/range_check_test.go +++ b/pkg/vm/builtins/range_check_test.go @@ -10,13 +10,13 @@ import ( ) func TestRangeCheckWriteMemoryAddress(t *testing.T) { - builtin := RangeCheck{} + builtin := RangeCheck{0, 8} memoryAddress := memory.EmptyMemoryValueAsAddress() assert.Error(t, builtin.CheckWrite(nil, 0, &memoryAddress)) } func TestRangeCheckWriteOutOfRange(t *testing.T) { - builtin := RangeCheck{} + builtin := RangeCheck{0, 8} outOfRangeValueFelt, err := new(fp.Element).SetString("0x100000000000000000000000000000001") require.NoError(t, err) outOfRangeValue := memory.MemoryValueFromFieldElement(outOfRangeValueFelt) @@ -24,7 +24,7 @@ func TestRangeCheckWriteOutOfRange(t *testing.T) { } func TestRangeCheckWrite(t *testing.T) { - builtin := RangeCheck{} + builtin := RangeCheck{0, 8} f, err := new(fp.Element).SetString("0x44") require.NoError(t, err) v := memory.MemoryValueFromFieldElement(f) @@ -32,7 +32,35 @@ func TestRangeCheckWrite(t *testing.T) { } func TestRangeCheckInfer(t *testing.T) { - builtin := RangeCheck{} + builtin := RangeCheck{0, 8} + segment := memory.EmptySegmentWithLength(3) + assert.ErrorContains(t, builtin.InferValue(segment, 0), "cannot infer value") +} + +func TestRangeCheck96WriteMemoryAddress(t *testing.T) { + builtin := RangeCheck{0, 6} + memoryAddress := memory.EmptyMemoryValueAsAddress() + assert.Error(t, builtin.CheckWrite(nil, 0, &memoryAddress)) +} + +func TestRangeCheck96WriteOutOfRange(t *testing.T) { + builtin := RangeCheck{0, 6} + outOfRangeValueFelt, err := new(fp.Element).SetString("40564819207303340847894502572032") + require.NoError(t, err) + outOfRangeValue := memory.MemoryValueFromFieldElement(outOfRangeValueFelt) + assert.Error(t, builtin.CheckWrite(nil, 0, &outOfRangeValue)) +} + +func TestRangeCheck96Write(t *testing.T) { + builtin := RangeCheck{0, 6} + f, err := new(fp.Element).SetString("19342813113834066795298816") + require.NoError(t, err) + v := memory.MemoryValueFromFieldElement(f) + assert.NoError(t, builtin.CheckWrite(nil, 0, &v)) +} + +func TestRangeCheck96Infer(t *testing.T) { + builtin := RangeCheck{0, 6} segment := memory.EmptySegmentWithLength(3) assert.ErrorContains(t, builtin.InferValue(segment, 0), "cannot infer value") }