Skip to content

Commit

Permalink
w3vm: fix txIndex during rollback (#182)
Browse files Browse the repository at this point in the history
* fix: import txIndex during rollback

* added tests

* fix

* added rollback test case

* fix

* fix

---------

Co-authored-by: lmittmann <[email protected]>
  • Loading branch information
wesraph and lmittmann authored Aug 22, 2024
1 parent 1210634 commit 54aeeda
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 2 deletions.
17 changes: 15 additions & 2 deletions w3vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re
Logs: v.db.GetLogs(txHash, 0, w3.Hash0),
}

// zero out the log tx hashes, indices and normalize the log indices
for i, log := range receipt.Logs {
log.Index = uint(i)
log.TxHash = w3.Hash0
log.TxIndex = 0
}

if err := result.Err; err != nil {
if reason, unpackErr := abi.UnpackRevert(result.ReturnData); unpackErr != nil {
receipt.Err = ErrRevert
Expand Down Expand Up @@ -254,7 +261,10 @@ func (vm *VM) SetStorageAt(addr common.Address, slot, val common.Hash) {
func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() }

// Rollback the state of the VM to the given snapshot.
func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot }
func (vm *VM) Rollback(snapshot *state.StateDB) {
vm.db = snapshot
vm.txIndex = uint64(snapshot.TxIndex()) + 1
}

func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) {
nonce := msg.Nonce
Expand Down Expand Up @@ -474,7 +484,10 @@ func WithState(state w3types.State) Option {
//
// The state DB can originate from a snapshot of the VM.
func WithStateDB(db *state.StateDB) Option {
return func(vm *VM) { vm.db = db }
return func(vm *VM) {
vm.db = db
vm.txIndex = uint64(db.TxIndex() + 1)
}
}

// WithNoBaseFee forces the EIP-1559 base fee to 0 for the VM.
Expand Down
173 changes: 173 additions & 0 deletions w3vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,179 @@ func TestVMSnapshot(t *testing.T) {
}
}

func TestVMSnapshot_Logs(t *testing.T) {
var (
preState = w3types.State{
addrWETH: {
Code: codeWETH,
Storage: w3types.Storage{
w3vm.WETHBalanceSlot(addr0): common.BigToHash(w3.I("10 ether")),
}},
}
transferMsg = &w3types.Message{
From: addr0,
To: &addrWETH,
Func: funcTransfer,
Args: []any{addr1, w3.I("1 ether")},
}
)

tests := []struct {
Name string
F func() (receipt0, receipt1 *w3vm.Receipt, err error)
}{
{
Name: "rollback_0",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm.Rollback(snap)

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "rollback_1",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

if _, err = vm.Apply(transferMsg); err != nil {
return
}

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm.Rollback(snap)

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "rollback_2",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

snap := vm.Snapshot()
vm.Rollback(snap)

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "rollback_3",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

if _, err = vm.Apply(transferMsg); err != nil {
return
}

snap := vm.Snapshot()
receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm2, _ := w3vm.New(w3vm.WithState(preState))
vm2.Rollback(snap)

receipt1, err = vm2.Apply(transferMsg)
return
},
},
{
Name: "new_0",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm, _ = w3vm.New(w3vm.WithStateDB(snap))

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "new_1",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

if _, err = vm.Apply(transferMsg); err != nil {
return
}

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm, _ = w3vm.New(w3vm.WithStateDB(snap))

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "new_2",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

snap := vm.Snapshot()
vm, _ = w3vm.New(w3vm.WithStateDB(snap))

receipt1, err = vm.Apply(transferMsg)
return
},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
receipt0, receipt1, err := test.F()
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(receipt0.Logs, receipt1.Logs); diff != "" {
t.Fatalf("(-want +got)\n%s", diff)
}
})
}
}

func TestVMCall(t *testing.T) {
tests := []struct {
PreState w3types.State
Expand Down

0 comments on commit 54aeeda

Please sign in to comment.