diff --git a/w3vm/vm.go b/w3vm/vm.go index 2ab19cdb..df076f33 100644 --- a/w3vm/vm.go +++ b/w3vm/vm.go @@ -132,6 +132,13 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re Logs: v.db.GetLogs(txHash, 0, w3.Hash0), } + // zero out the log tx hashes, indices and normalize the log indices + for i, log := range receipt.Logs { + log.Index = uint(i) + log.TxHash = w3.Hash0 + log.TxIndex = 0 + } + if err := result.Err; err != nil { if reason, unpackErr := abi.UnpackRevert(result.ReturnData); unpackErr != nil { receipt.Err = ErrRevert @@ -254,7 +261,10 @@ func (vm *VM) SetStorageAt(addr common.Address, slot, val common.Hash) { func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() } // Rollback the state of the VM to the given snapshot. -func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot } +func (vm *VM) Rollback(snapshot *state.StateDB) { + vm.db = snapshot + vm.txIndex = uint64(snapshot.TxIndex()) + 1 +} func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) { nonce := msg.Nonce @@ -474,7 +484,10 @@ func WithState(state w3types.State) Option { // // The state DB can originate from a snapshot of the VM. func WithStateDB(db *state.StateDB) Option { - return func(vm *VM) { vm.db = db } + return func(vm *VM) { + vm.db = db + vm.txIndex = uint64(db.TxIndex() + 1) + } } // WithNoBaseFee forces the EIP-1559 base fee to 0 for the VM. diff --git a/w3vm/vm_test.go b/w3vm/vm_test.go index 799dee54..d6087beb 100644 --- a/w3vm/vm_test.go +++ b/w3vm/vm_test.go @@ -345,6 +345,179 @@ func TestVMSnapshot(t *testing.T) { } } +func TestVMSnapshot_Logs(t *testing.T) { + var ( + preState = w3types.State{ + addrWETH: { + Code: codeWETH, + Storage: w3types.Storage{ + w3vm.WETHBalanceSlot(addr0): common.BigToHash(w3.I("10 ether")), + }}, + } + transferMsg = &w3types.Message{ + From: addr0, + To: &addrWETH, + Func: funcTransfer, + Args: []any{addr1, w3.I("1 ether")}, + } + ) + + tests := []struct { + Name string + F func() (receipt0, receipt1 *w3vm.Receipt, err error) + }{ + { + Name: "rollback_0", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + snap := vm.Snapshot() + + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + vm.Rollback(snap) + + receipt1, err = vm.Apply(transferMsg) + return + }, + }, + { + Name: "rollback_1", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + if _, err = vm.Apply(transferMsg); err != nil { + return + } + + snap := vm.Snapshot() + + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + vm.Rollback(snap) + + receipt1, err = vm.Apply(transferMsg) + return + }, + }, + { + Name: "rollback_2", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + snap := vm.Snapshot() + vm.Rollback(snap) + + receipt1, err = vm.Apply(transferMsg) + return + }, + }, + { + Name: "rollback_3", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + if _, err = vm.Apply(transferMsg); err != nil { + return + } + + snap := vm.Snapshot() + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + vm2, _ := w3vm.New(w3vm.WithState(preState)) + vm2.Rollback(snap) + + receipt1, err = vm2.Apply(transferMsg) + return + }, + }, + { + Name: "new_0", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + snap := vm.Snapshot() + + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + vm, _ = w3vm.New(w3vm.WithStateDB(snap)) + + receipt1, err = vm.Apply(transferMsg) + return + }, + }, + { + Name: "new_1", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + if _, err = vm.Apply(transferMsg); err != nil { + return + } + + snap := vm.Snapshot() + + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + vm, _ = w3vm.New(w3vm.WithStateDB(snap)) + + receipt1, err = vm.Apply(transferMsg) + return + }, + }, + { + Name: "new_2", + F: func() (receipt0, receipt1 *w3vm.Receipt, err error) { + vm, _ := w3vm.New(w3vm.WithState(preState)) + + receipt0, err = vm.Apply(transferMsg) + if err != nil { + return + } + + snap := vm.Snapshot() + vm, _ = w3vm.New(w3vm.WithStateDB(snap)) + + receipt1, err = vm.Apply(transferMsg) + return + }, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + receipt0, receipt1, err := test.F() + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(receipt0.Logs, receipt1.Logs); diff != "" { + t.Fatalf("(-want +got)\n%s", diff) + } + }) + } +} + func TestVMCall(t *testing.T) { tests := []struct { PreState w3types.State