Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

w3vm: fix txIndex during rollback #182

Merged
merged 6 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions w3vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re
Logs: v.db.GetLogs(txHash, 0, w3.Hash0),
}

// zero out the log tx hashes, indices and normalize the log indices
for i, log := range receipt.Logs {
log.Index = uint(i)
log.TxHash = w3.Hash0
log.TxIndex = 0
}

if err := result.Err; err != nil {
if reason, unpackErr := abi.UnpackRevert(result.ReturnData); unpackErr != nil {
receipt.Err = ErrRevert
Expand Down Expand Up @@ -254,7 +261,10 @@ func (vm *VM) SetStorageAt(addr common.Address, slot, val common.Hash) {
func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() }

// Rollback the state of the VM to the given snapshot.
func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot }
func (vm *VM) Rollback(snapshot *state.StateDB) {
vm.db = snapshot
vm.txIndex = uint64(snapshot.TxIndex()) + 1
}

func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) {
nonce := msg.Nonce
Expand Down Expand Up @@ -474,7 +484,10 @@ func WithState(state w3types.State) Option {
//
// The state DB can originate from a snapshot of the VM.
func WithStateDB(db *state.StateDB) Option {
return func(vm *VM) { vm.db = db }
return func(vm *VM) {
vm.db = db
vm.txIndex = uint64(db.TxIndex() + 1)
}
}

// WithNoBaseFee forces the EIP-1559 base fee to 0 for the VM.
Expand Down
173 changes: 173 additions & 0 deletions w3vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,179 @@ func TestVMSnapshot(t *testing.T) {
}
}

func TestVMSnapshot_Logs(t *testing.T) {
var (
preState = w3types.State{
addrWETH: {
Code: codeWETH,
Storage: w3types.Storage{
w3vm.WETHBalanceSlot(addr0): common.BigToHash(w3.I("10 ether")),
}},
}
transferMsg = &w3types.Message{
From: addr0,
To: &addrWETH,
Func: funcTransfer,
Args: []any{addr1, w3.I("1 ether")},
}
)

tests := []struct {
Name string
F func() (receipt0, receipt1 *w3vm.Receipt, err error)
}{
{
Name: "rollback_0",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm.Rollback(snap)

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "rollback_1",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

if _, err = vm.Apply(transferMsg); err != nil {
return
}

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm.Rollback(snap)

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "rollback_2",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

snap := vm.Snapshot()
vm.Rollback(snap)

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "rollback_3",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

if _, err = vm.Apply(transferMsg); err != nil {
return
}

snap := vm.Snapshot()
receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm2, _ := w3vm.New(w3vm.WithState(preState))
vm2.Rollback(snap)

receipt1, err = vm2.Apply(transferMsg)
return
},
},
{
Name: "new_0",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm, _ = w3vm.New(w3vm.WithStateDB(snap))

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "new_1",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

if _, err = vm.Apply(transferMsg); err != nil {
return
}

snap := vm.Snapshot()

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

vm, _ = w3vm.New(w3vm.WithStateDB(snap))

receipt1, err = vm.Apply(transferMsg)
return
},
},
{
Name: "new_2",
F: func() (receipt0, receipt1 *w3vm.Receipt, err error) {
vm, _ := w3vm.New(w3vm.WithState(preState))

receipt0, err = vm.Apply(transferMsg)
if err != nil {
return
}

snap := vm.Snapshot()
vm, _ = w3vm.New(w3vm.WithStateDB(snap))

receipt1, err = vm.Apply(transferMsg)
return
},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
receipt0, receipt1, err := test.F()
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(receipt0.Logs, receipt1.Logs); diff != "" {
t.Fatalf("(-want +got)\n%s", diff)
}
})
}
}

func TestVMCall(t *testing.T) {
tests := []struct {
PreState w3types.State
Expand Down