Skip to content

Commit

Permalink
compiler: fix compiledModule leak (#1608)
Browse files Browse the repository at this point in the history
Signed-off-by: Nuno Cruces <[email protected]>
Co-authored-by: Achille Roussel <[email protected]>
  • Loading branch information
ncruces and achille-roussel authored Aug 2, 2023
1 parent 2f2b6a9 commit 90f58bc
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 57 deletions.
4 changes: 2 additions & 2 deletions internal/engine/compiler/compiler_controlflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ func TestCompiler_callIndirect_largeTypeIndex(t *testing.T) {

makeExecutable(code1.Bytes())
f := function{
parent: &compiledFunction{parent: &compiledModule{executable: code1}},
parent: &compiledFunction{parent: &compiledCode{executable: code1}},
codeInitialAddress: uintptr(unsafe.Pointer(&code1.Bytes()[0])),
moduleInstance: env.moduleInstance,
}
Expand Down Expand Up @@ -896,7 +896,7 @@ func TestCompiler_compileCall(t *testing.T) {

makeExecutable(code.Bytes())
me.functions = append(me.functions, function{
parent: &compiledFunction{parent: &compiledModule{executable: code}},
parent: &compiledFunction{parent: &compiledCode{executable: code}},
codeInitialAddress: uintptr(unsafe.Pointer(&code.Bytes()[0])),
moduleInstance: env.moduleInstance,
})
Expand Down
6 changes: 3 additions & 3 deletions internal/engine/compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (j *compilerEnv) callEngine() *callEngine {
}

func (j *compilerEnv) exec(machineCode []byte) {
cm := new(compiledModule)
cm := &compiledModule{compiledCode: &compiledCode{}}
if err := cm.executable.Map(len(machineCode)); err != nil {
panic(err)
}
Expand All @@ -211,7 +211,7 @@ func (j *compilerEnv) exec(machineCode []byte) {
makeExecutable(executable)

f := &function{
parent: &compiledFunction{parent: cm},
parent: &compiledFunction{parent: cm.compiledCode},
codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])),
moduleInstance: j.moduleInstance,
}
Expand Down Expand Up @@ -268,7 +268,7 @@ func newCompilerEnvironment() *compilerEnv {
Globals: []*wasm.GlobalInstance{},
Engine: me,
},
ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledModule{}}}),
ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}),
}
}

Expand Down
64 changes: 42 additions & 22 deletions internal/engine/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ type (
// as the underlying memory region is accessed by assembly directly by using
// codesElement0Address.
functions []function

// Keep a reference to the compiled module to prevent the GC from reclaiming
// it while the code may still be needed.
module *compiledModule
}

// callEngine holds context per moduleEngine.Call, and shared across all the
Expand Down Expand Up @@ -130,11 +134,13 @@ type (
// initialFn is the initial function for this call engine.
initialFn *function

// Keep a reference to the compiled module to prevent the GC from reclaiming
// it while the code may still be needed.
module *compiledModule

// stackIterator provides a way to iterate over the stack for Listeners.
// It is setup and valid only during a call to a Listener hook.
stackIterator stackIterator

ensureTermination bool
}

// moduleContext holds the per-function call specific module information.
Expand Down Expand Up @@ -264,12 +270,27 @@ type (
}

compiledModule struct {
executable asm.CodeSegment
functions []compiledFunction
source *wasm.Module
// The data that need to be accessed by compiledFunction.parent are
// separated in an embedded field because we use finalizers to manage
// the lifecycle of compiledModule instances and having cyclic pointers
// prevents the Go runtime from calling them, which results in memory
// leaks since the memory mapped code segments cannot be released.
//
// The indirection guarantees that the finalizer set on compiledModule
// instances can run when all references are gone, and the Go GC can
// manage to reclaim the compiledCode when all compiledFunction objects
// referencing it have been freed.
*compiledCode
functions []compiledFunction

ensureTermination bool
}

compiledCode struct {
source *wasm.Module
executable asm.CodeSegment
}

// compiledFunction corresponds to a function in a module (not instantiated one). This holds the machine code
// compiled by wazero compiler.
compiledFunction struct {
Expand All @@ -282,7 +303,7 @@ type (
index wasm.Index
goFunc interface{}
listener experimental.FunctionListener
parent *compiledModule
parent *compiledCode
sourceOffsetMap sourceOffsetMap
}

Expand Down Expand Up @@ -496,13 +517,6 @@ func (e *engine) Close() (err error) {
e.mux.Lock()
defer e.mux.Unlock()
// Releasing the references to compiled codes including the memory-mapped machine codes.

for i := range e.codes {
for j := range e.codes[i].functions {
e.codes[i].functions[j].parent = nil
}
}

e.codes = nil
return
}
Expand All @@ -523,9 +537,11 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners
var withGoFunc bool
localFuncs, importedFuncs := len(module.FunctionSection), module.ImportFunctionCount
cm := &compiledModule{
compiledCode: &compiledCode{
source: module,
},
functions: make([]compiledFunction, localFuncs),
ensureTermination: ensureTermination,
source: module,
}

if localFuncs == 0 {
Expand Down Expand Up @@ -559,7 +575,7 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners
funcIndex := wasm.Index(i)
compiledFn := &cm.functions[i]
compiledFn.executableOffset = executable.Size()
compiledFn.parent = cm
compiledFn.parent = cm.compiledCode
compiledFn.index = importedFuncs + funcIndex
if i < ln {
compiledFn.listener = listeners[i]
Expand Down Expand Up @@ -628,6 +644,8 @@ func (e *engine) NewModuleEngine(module *wasm.Module, instance *wasm.ModuleInsta
parent: c,
}
}

me.module = cm
return me, nil
}

Expand Down Expand Up @@ -720,7 +738,7 @@ func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {

func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.initialFn.moduleInstance
if ce.ensureTermination {
if ce.module.ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the call context
Expand All @@ -741,12 +759,14 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
err = m.FailIfClosed()
}
// Ensure that the compiled module will never be GC'd before this method returns.
runtime.KeepAlive(ce.module)
}()

ft := ce.initialFn.funcType
ce.initializeStack(ft, params)

if ce.ensureTermination {
if ce.module.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}
Expand Down Expand Up @@ -959,11 +979,11 @@ var initialStackSize uint64 = 512

func (e *moduleEngine) newCallEngine(stackSize uint64, fn *function) *callEngine {
ce := &callEngine{
stack: make([]uint64, stackSize),
archContext: newArchContext(),
initialFn: fn,
moduleContext: moduleContext{fn: fn},
ensureTermination: fn.parent.parent.ensureTermination,
stack: make([]uint64, stackSize),
archContext: newArchContext(),
initialFn: fn,
moduleContext: moduleContext{fn: fn},
module: e.module,
}

stackHeader := (*reflect.SliceHeader)(unsafe.Pointer(&ce.stack))
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/compiler/engine_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func BenchmarkCallEngine_builtinFunctionFunctionListener(b *testing.B) {
},
},
index: 0,
parent: &compiledModule{
parent: &compiledCode{
source: &wasm.Module{
TypeSection: []wasm.FunctionType{{}},
FunctionSection: []wasm.Index{0},
Expand Down
9 changes: 7 additions & 2 deletions internal/engine/compiler/engine_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
func (e *engine) deleteCompiledModule(module *wasm.Module) {
e.mux.Lock()
defer e.mux.Unlock()

delete(e.codes, module.ID)

// Note: we do not call e.Cache.Delete, as the lifetime of
Expand Down Expand Up @@ -158,14 +159,18 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, modul

ensureTermination := header[cachedVersionEnd] != 0
functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
cm = &compiledModule{functions: make([]compiledFunction, functionsNum), ensureTermination: ensureTermination}
cm = &compiledModule{
compiledCode: new(compiledCode),
functions: make([]compiledFunction, functionsNum),
ensureTermination: ensureTermination,
}

imported := module.ImportFunctionCount

var eightBytes [8]byte
for i := uint32(0); i < functionsNum; i++ {
f := &cm.functions[i]
f.parent = cm
f.parent = cm.compiledCode

// Read the stack pointer ceil.
if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {
Expand Down
54 changes: 35 additions & 19 deletions internal/engine/compiler/engine_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ func TestSerializeCompiledModule(t *testing.T) {
}{
{
in: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
},
Expand All @@ -57,11 +59,13 @@ func TestSerializeCompiledModule(t *testing.T) {
},
{
in: &compiledModule{
ensureTermination: true,
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
},
ensureTermination: true,
},
exp: concat(
[]byte(wazeroMagic),
Expand All @@ -77,12 +81,14 @@ func TestSerializeCompiledModule(t *testing.T) {
},
{
in: &compiledModule{
ensureTermination: true,
executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
{executableOffset: 5, stackPointerCeil: 0xffffffff},
},
ensureTermination: true,
},
exp: concat(
[]byte(wazeroMagic),
Expand Down Expand Up @@ -159,7 +165,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte{1, 2, 3, 4, 5}, // machine code.
),
expCompiledModule: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345, index: 0},
},
Expand All @@ -181,9 +189,11 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte{1, 2, 3, 4, 5}, // code.
),
expCompiledModule: &compiledModule{
ensureTermination: true,
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}},
ensureTermination: true,
},
expStaleCache: false,
expErr: "",
Expand All @@ -208,7 +218,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
),
importedFunctionCount: 1,
expCompiledModule: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345, index: 1},
{executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2},
Expand Down Expand Up @@ -279,8 +291,8 @@ func TestDeserializeCompiledModule(t *testing.T) {
if tc.expCompiledModule != nil {
require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions))
for i := 0; i < len(cm.functions); i++ {
require.Equal(t, cm, cm.functions[i].parent)
tc.expCompiledModule.functions[i].parent = cm
require.Equal(t, cm.compiledCode, cm.functions[i].parent)
tc.expCompiledModule.functions[i].parent = cm.compiledCode
}
}

Expand Down Expand Up @@ -361,13 +373,13 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
},
expHit: true,
expCompiledModule: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
},
functions: []compiledFunction{
{stackPointerCeil: 12345, executableOffset: 0, index: 0},
{stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1},
},
source: nil,
ensureTermination: false,
},
},
}
Expand All @@ -379,7 +391,7 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
if exp := tc.expCompiledModule; exp != nil {
exp.source = m
for i := range tc.expCompiledModule.functions {
tc.expCompiledModule.functions[i].parent = exp
tc.expCompiledModule.functions[i].parent = exp.compiledCode
}
}

Expand Down Expand Up @@ -422,8 +434,10 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
tc := filecache.New(t.TempDir())
e := engine{fileCache: tc}
cm := &compiledModule{
executable: makeCodeSegment(1, 2, 3),
functions: []compiledFunction{{stackPointerCeil: 123}},
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3),
},
functions: []compiledFunction{{stackPointerCeil: 123}},
}
m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module!
err := e.addCompiledModuleToCache(m, cm)
Expand All @@ -438,8 +452,10 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
e := engine{fileCache: tc}
m := &wasm.Module{}
cm := &compiledModule{
executable: makeCodeSegment(1, 2, 3),
functions: []compiledFunction{{stackPointerCeil: 123}},
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3),
},
functions: []compiledFunction{{stackPointerCeil: 123}},
}
err := e.addCompiledModuleToCache(m, cm)
require.NoError(t, err)
Expand Down
Loading

0 comments on commit 90f58bc

Please sign in to comment.