Skip to content

Commit

Permalink
Fix dict implementation (#575)
Browse files Browse the repository at this point in the history
* Update implementation with MemoryValue

* Fix more resolves
  • Loading branch information
har777 authored Jul 24, 2024
1 parent 13a2f14 commit b26fe18
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 136 deletions.
124 changes: 65 additions & 59 deletions integration_tests/BenchMarks.txt

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions integration_tests/cairo_zero_hint_tests/dict_store_cast_ptr.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// from lambdaclass vm tests

from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.default_dict import default_dict_new
from starkware.cairo.common.dict import dict_write, dict_read

struct Structure {
a: felt,
b: felt,
c: felt,
}

func main() {
// Create dictionary
let (dictionary: DictAccess*) = default_dict_new(default_value=0);
// Create & initialize struct_pointer
let (struct_ptr: Structure*) = alloc();
assert struct_ptr[0] = Structure(1, 2, 3);
// Cast ptr to felt and store it in the dictionary
tempvar struct_ptr_cast_felt = cast(struct_ptr, felt);
dict_write{dict_ptr=dictionary}(key=0, new_value=struct_ptr_cast_felt);
// Read the casted ptr from the dictionary and compare it to the one we stored
let read_struct_ptr_cast_felt: felt = dict_read{dict_ptr=dictionary}(key=0);
assert struct_ptr_cast_felt = read_struct_ptr_cast_felt;
// Cast he ptr back to Structure* and check its value
let read_struct_ptr_cast_struct_ptr = cast(read_struct_ptr_cast_felt, Structure*);
assert struct_ptr = read_struct_ptr_cast_struct_ptr;
// Confirm that the ptr still leads to the data we initialized
assert read_struct_ptr_cast_struct_ptr[0].a = 1;
assert read_struct_ptr_cast_struct_ptr[0].b = 2;
assert read_struct_ptr_cast_struct_ptr[0].c = 3;
// Now we do the same, but we read the struct_ptr from the dictionary as a Struct*
// without an explicit cast
let read_struct_ptr: Structure* = dict_read{dict_ptr=dictionary}(key=0);
assert struct_ptr = read_struct_ptr;
// Confirm that the ptr still leads to the data we initialized
assert read_struct_ptr[0].a = 1;
assert read_struct_ptr[0].b = 2;
assert read_struct_ptr[0].c = 3;
return ();
}
15 changes: 7 additions & 8 deletions pkg/hintrunner/hinter/zero_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@ import (

VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

// Used to keep track of all Dictionaries data
type ZeroDictionary struct {
// The Data contained in a dictionary
Data *map[fp.Element]mem.MemoryValue
Data *map[mem.MemoryValue]mem.MemoryValue
// Default value for key not present in the dictionary
DefaultValue *mem.MemoryValue
// first free offset in memory segment of dictionary
FreeOffset *uint64
}

// Gets the memory value at certain key
func (d *ZeroDictionary) at(key fp.Element) (mem.MemoryValue, error) {
func (d *ZeroDictionary) at(key mem.MemoryValue) (mem.MemoryValue, error) {
if value, ok := (*d.Data)[key]; ok {
return value, nil
}
Expand All @@ -30,7 +29,7 @@ func (d *ZeroDictionary) at(key fp.Element) (mem.MemoryValue, error) {
}

// Given a key and a value, it sets the value at the given key
func (d *ZeroDictionary) set(key fp.Element, value mem.MemoryValue) {
func (d *ZeroDictionary) set(key mem.MemoryValue, value mem.MemoryValue) {
(*d.Data)[key] = value
}

Expand Down Expand Up @@ -59,7 +58,7 @@ func NewZeroDictionaryManager() ZeroDictionaryManager {
// It creates a new segment which will hold dictionary values. It links this
// segment with the current dictionary and returns the address that points
// to the start of this segment. initial dictionary data is set from the data argument.
func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[fp.Element]mem.MemoryValue) mem.MemoryAddress {
func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[mem.MemoryValue]mem.MemoryValue) mem.MemoryAddress {
newDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.Dictionaries[newDictAddr.SegmentIndex] = ZeroDictionary{
Expand All @@ -76,7 +75,7 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[f
// querying the defaultValue will be returned instead.
func (dm *ZeroDictionaryManager) NewDefaultDictionary(vm *VM.VirtualMachine, defaultValue mem.MemoryValue) mem.MemoryAddress {
newDefaultDictAddr := vm.Memory.AllocateEmptySegment()
newData := make(map[fp.Element]mem.MemoryValue)
newData := make(map[mem.MemoryValue]mem.MemoryValue)
freeOffset := uint64(0)
dm.Dictionaries[newDefaultDictAddr.SegmentIndex] = ZeroDictionary{
Data: &newData,
Expand All @@ -101,7 +100,7 @@ func (dm *ZeroDictionaryManager) GetDictionary(dictAddr mem.MemoryAddress) (Zero

// Given a memory address and a key it returns the value held at that position. The address is used
// to locate the correct dictionary and the key to index on it
func (dm *ZeroDictionaryManager) At(dictAddr mem.MemoryAddress, key fp.Element) (mem.MemoryValue, error) {
func (dm *ZeroDictionaryManager) At(dictAddr mem.MemoryAddress, key mem.MemoryValue) (mem.MemoryValue, error) {
dict, err := dm.GetDictionary(dictAddr)
if err != nil {
return mem.UnknownValue, err
Expand All @@ -114,7 +113,7 @@ func (dm *ZeroDictionaryManager) At(dictAddr mem.MemoryAddress, key fp.Element)
}

// Given a memory address,a key and a value it stores the value at the correct position.
func (dm *ZeroDictionaryManager) Set(dictAddr mem.MemoryAddress, key fp.Element, value mem.MemoryValue) error {
func (dm *ZeroDictionaryManager) Set(dictAddr mem.MemoryAddress, key mem.MemoryValue, value mem.MemoryValue) error {
dict, err := dm.GetDictionary(dictAddr)
if err != nil {
return err
Expand Down
70 changes: 33 additions & 37 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func newDictNewHint() hinter.Hinter {
}
}

initialDict, err := hinter.GetVariableAs[map[fp.Element]memory.MemoryValue](&ctx.ScopeManager, "initial_dict")
initialDict, err := hinter.GetVariableAs[map[memory.MemoryValue]memory.MemoryValue](&ctx.ScopeManager, "initial_dict")
if err != nil {
return err
}
Expand Down Expand Up @@ -101,13 +101,12 @@ func newDefaultDictNewHint(defaultValue hinter.ResOperander) hinter.Hinter {
}

//> memory[ap] = __dict_manager.new_default_dict(segments, ids.default_value)
defaultValue, err := hinter.ResolveAsFelt(vm, defaultValue)
defaultValue, err := defaultValue.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", defaultValue, err)
}

defaultValueMv := memory.MemoryValueFromFieldElement(defaultValue)
newDefaultDictionaryAddr := dictionaryManager.NewDefaultDictionary(vm, defaultValueMv)
newDefaultDictionaryAddr := dictionaryManager.NewDefaultDictionary(vm, defaultValue)
newDefaultDictionaryAddrMv := memory.MemoryValueFromMemoryAddress(&newDefaultDictionaryAddr)
apAddr := vm.Context.AddressAp()

Expand Down Expand Up @@ -151,11 +150,11 @@ func newDictReadHint(dictPtr, key, value hinter.ResOperander) hinter.Hinter {
}

//> ids.value = dict_tracker.data[ids.key]
key, err := hinter.ResolveAsFelt(vm, key)
key, err := key.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", key, err)
}
keyValue, err := dictionaryManager.At(*dictPtr, *key)
keyValue, err := dictionaryManager.At(*dictPtr, key)
if err != nil {
return err
}
Expand Down Expand Up @@ -225,18 +224,21 @@ func newDictSquashCopyDictHint(dictAccessesEnd hinter.ResOperander) hinter.Hinte
return err
}

dictionaryDataCopy := make(map[fp.Element]memory.MemoryValue)
dictionaryDataCopy := make(map[memory.MemoryValue]memory.MemoryValue)
for k, v := range *dictionary.Data {
// Copy the key
keyCopy := fp.Element{}
keyCopy.Set(&k)
keyFeltCopy := fp.Element{}
keyFeltCopy.Set(&k.Felt)
keyCopy := memory.MemoryValue{
Felt: keyFeltCopy,
Kind: k.Kind,
}

// Copy the value
feltCopy := fp.Element{}
feltCopy.Set(&v.Felt)

valueFeltCopy := fp.Element{}
valueFeltCopy.Set(&v.Felt)
valueCopy := memory.MemoryValue{
Felt: feltCopy,
Felt: valueFeltCopy,
Kind: v.Kind,
}

Expand Down Expand Up @@ -285,9 +287,9 @@ func newDictWriteHint(dictPtr, key, newValue hinter.ResOperander) hinter.Hinter
return fmt.Errorf("__dict_manager not in scope")
}

key, err := hinter.ResolveAsFelt(vm, key)
key, err := key.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", key, err)
}

//> ids.dict_ptr.prev_value = dict_tracker.data[ids.key]
Expand All @@ -297,7 +299,7 @@ func newDictWriteHint(dictPtr, key, newValue hinter.ResOperander) hinter.Hinter
//> prev_value: felt,
//> new_value: felt,
//> }
prevKeyValue, err := dictionaryManager.At(*dictPtr, *key)
prevKeyValue, err := dictionaryManager.At(*dictPtr, key)
if err != nil {
return err
}
Expand All @@ -307,12 +309,11 @@ func newDictWriteHint(dictPtr, key, newValue hinter.ResOperander) hinter.Hinter
}

//> dict_tracker.data[ids.key] = ids.new_value
newValue, err := hinter.ResolveAsFelt(vm, newValue)
newValue, err := newValue.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", newValue, err)
}
newValueMv := memory.MemoryValueFromFieldElement(newValue)
err = dictionaryManager.Set(*dictPtr, *key, newValueMv)
err = dictionaryManager.Set(*dictPtr, key, newValue)
if err != nil {
return err
}
Expand Down Expand Up @@ -372,39 +373,34 @@ func newDictUpdateHint(dictPtr, key, newValue, prevValue hinter.ResOperander) hi
return fmt.Errorf("__dict_manager not in scope")
}

key, err := hinter.ResolveAsFelt(vm, key)
key, err := key.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", key, err)
}

//> current_value = dict_tracker.data[ids.key]
currentValueMv, err := dictionaryManager.At(*dictPtr, *key)
currentValue, err := dictionaryManager.At(*dictPtr, key)
if err != nil {
return err
}
currentValue, err := currentValueMv.FieldElement()
if err != nil {
return err
return fmt.Errorf("%s: %w", key, err)
}

//> assert current_value == ids.prev_value, \
//> f'Wrong previous value in dict. Got {ids.prev_value}, expected {current_value}.'
prevValue, err := hinter.ResolveAsFelt(vm, prevValue)
prevValue, err := prevValue.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", prevValue, err)
}
if !currentValue.Equal(prevValue) {
if !currentValue.Equal(&prevValue) {
return fmt.Errorf("wrong previous value in dict. Got %s, expected %s", prevValue, currentValue)
}

//> # Update value.
//> dict_tracker.data[ids.key] = ids.new_value
newValue, err := hinter.ResolveAsFelt(vm, newValue)
newValue, err := newValue.Resolve(vm)
if err != nil {
return err
return fmt.Errorf("%s: %w", newValue, err)
}
newValueMv := memory.MemoryValueFromFieldElement(newValue)
err = dictionaryManager.Set(*dictPtr, *key, newValueMv)
err = dictionaryManager.Set(*dictPtr, key, newValue)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit b26fe18

Please sign in to comment.