diff --git a/server.go b/server.go index 9625546..513a418 100644 --- a/server.go +++ b/server.go @@ -322,6 +322,7 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) { s.mutex.Lock() defer s.mutex.Unlock() + stat, err := s.serverClient.DeregisterExtension(s.uuid) err = errors.Wrap(err, "deregistering extension") if err == nil && stat.Code != 0 { @@ -333,7 +334,7 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) { s.server = nil // Stop the server asynchronously so that the current request // can complete. Otherwise, this is vulnerable to deadlock if a - // shutdown request is being processed when shutdown is + // shutdown request is being processed when Shutdown is // explicitly called. go func() { server.Stop() diff --git a/server_test.go b/server_test.go index 705f278..2693d4e 100644 --- a/server_test.go +++ b/server_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "os" + "runtime/pprof" "strings" "sync" "syscall" @@ -98,18 +99,19 @@ const parallelTestShutdownDeadlock = 20 func TestShutdownDeadlock(t *testing.T) { for i := 0; i < parallelTestShutdownDeadlock; i++ { + i := i t.Run("", func(t *testing.T) { t.Parallel() - testShutdownDeadlock(t) + testShutdownDeadlock(t, i) }) } } -func testShutdownDeadlock(t *testing.T) { +func testShutdownDeadlock(t *testing.T, uuid int) { tempPath, err := ioutil.TempFile("", "") require.Nil(t, err) defer os.Remove(tempPath.Name()) - retUUID := osquery.ExtensionRouteUUID(0) + retUUID := osquery.ExtensionRouteUUID(uuid) mock := &MockExtensionManager{ RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) { return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil @@ -119,16 +121,22 @@ func testShutdownDeadlock(t *testing.T) { }, CloseFunc: func() {}, } - server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} + server := ExtensionManagerServer{ + serverClient: mock, + sockPath: tempPath.Name(), + timeout: defaultTimeout, + } - wait := sync.WaitGroup{} + var wait sync.WaitGroup - wait.Add(1) go func() { + // We do not wait for this routine to finish because thrift.TServer.Serve + // seems to sometimes hang after shutdowns. (This test is just testing + // the Shutdown doesn't hang.) err := server.Start() - require.Nil(t, err) - wait.Done() + require.NoError(t, err) }() + // Wait for server to be set up server.waitStarted() @@ -138,10 +146,17 @@ func testShutdownDeadlock(t *testing.T) { addr, err := net.ResolveUnixAddr("unix", listenPath) require.Nil(t, err) timeout := 500 * time.Millisecond - trans := thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout) - err = trans.Open() - require.Nil(t, err) - client := osquery.NewExtensionManagerClientFactory(trans, + opened := false + attempt := 0 + var transport *thrift.TSocket + for !opened && attempt < 10 { + transport = thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout) + err = transport.Open() + opened = err == nil + attempt++ + } + require.NoError(t, err) + client := osquery.NewExtensionManagerClientFactory(transport, thrift.NewTBinaryProtocolFactoryDefault()) // Simultaneously call shutdown through a request from the client and @@ -156,7 +171,7 @@ func testShutdownDeadlock(t *testing.T) { go func() { defer wait.Done() err = server.Shutdown(context.Background()) - require.Nil(t, err) + require.NoError(t, err) }() // Track whether shutdown completed @@ -171,7 +186,8 @@ func testShutdownDeadlock(t *testing.T) { select { case <-completed: // Success. Do nothing. - case <-time.After(5 * time.Second): + case <-time.After(10 * time.Second): + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) t.Fatal("hung on shutdown") } }