-
Notifications
You must be signed in to change notification settings - Fork 0
/
lifecycle.go
117 lines (99 loc) · 2.52 KB
/
lifecycle.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package lifecycle
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// _exit terminates the process with exit status 1. It is a package-level
// function variable rather than a direct os.Exit call, so the exit behavior can
// be replaced (presumably by tests; no replacement is visible in this file).
var _exit = func() { os.Exit(1) }
// Printer is the minimal logging interface RunContext uses to report progress
// and failures.
type Printer interface {
// Printf writes one formatted message. Callers in this file pass fmt-style
// verbs such as %+v in format.
Printf(format string, args ...interface{})
}
// Service is anything with an explicit start and stop step.
type Service interface {
	// Start brings the service up, returning an error if it could not start.
	Start(ctx context.Context) error
	// Stop shuts the service down, returning an error if it could not stop.
	Stop(ctx context.Context) error
}

// Hook is a pair of optional start/stop callbacks. Either callback may be left
// nil, in which case the corresponding step is a no-op.
type Hook struct {
	// OnStart is called when the hook is started; nil means nothing to start.
	OnStart func(ctx context.Context) error
	// OnStop is called when the hook is stopped; nil means nothing to stop.
	OnStop func(ctx context.Context) error
}

// Compile-time check that *Hook satisfies Service.
var _ Service = (*Hook)(nil)

// Start calls OnStart with ctx and returns its result.
//
// A nil OnStart is treated as "nothing to do" and returns nil instead of
// panicking, matching how Lifecycle.Start skips hooks without an OnStart.
func (s *Hook) Start(ctx context.Context) error {
	if s.OnStart == nil {
		return nil
	}
	return s.OnStart(ctx)
}

// Stop calls OnStop with ctx and returns its result.
//
// A nil OnStop is treated as "nothing to do" and returns nil instead of
// panicking, matching how Lifecycle.Stop skips hooks without an OnStop.
func (s *Hook) Stop(ctx context.Context) error {
	if s.OnStop == nil {
		return nil
	}
	return s.OnStop(ctx)
}
// Lifecycle is an ordered collection of hooks. Start runs the hooks in the order
// they were appended; Stop runs them in the reverse order.
type Lifecycle struct {
// hooks holds the registered hooks in append order.
hooks []*Hook
}
// Append registers hook after all previously registered hooks, so it is
// started last and stopped first.
func (lc *Lifecycle) Append(hook *Hook) {
	extended := append(lc.hooks, hook)
	lc.hooks = extended
}
// AppendService registers service after all previously registered hooks.
//
// A *Hook is stored as-is, so its nil callbacks keep being skipped by Start and
// Stop. Any other Service implementation is adapted into a Hook whose callbacks
// are the service's own Start and Stop methods; previously such a value caused
// a panic from an unchecked type assertion.
func (lc *Lifecycle) AppendService(service Service) {
	if hook, ok := service.(*Hook); ok {
		lc.hooks = append(lc.hooks, hook)
		return
	}
	// Bound method values keep the receiver, so the adapter calls straight
	// through to the original service.
	lc.hooks = append(lc.hooks, &Hook{
		OnStart: service.Start,
		OnStop:  service.Stop,
	})
}
// Start runs every registered OnStart callback in append order, passing ctx to
// each. Hooks without an OnStart are skipped. The first callback error aborts
// the sequence and is returned unchanged; later hooks are not started.
func (lc *Lifecycle) Start(ctx context.Context) error {
	// Work on the slice as it was when Start was called.
	pending := lc.hooks
	for idx := 0; idx < len(pending); idx++ {
		current := pending[idx]
		if current.OnStart != nil {
			if startErr := current.OnStart(ctx); startErr != nil {
				return startErr
			}
		}
	}
	return nil
}
// Stop runs every registered OnStop callback in reverse append order, passing
// ctx to each. Hooks without an OnStop are skipped. A failing callback does not
// interrupt the sequence; all failures are gathered and returned as one
// combined error via multierr.Combine.
func (lc *Lifecycle) Stop(ctx context.Context) error {
	var failures []error
	for remaining := len(lc.hooks); remaining > 0; remaining-- {
		if current := lc.hooks[remaining-1]; current.OnStop != nil {
			if stopErr := current.OnStop(ctx); stopErr != nil {
				// Cleanup is best-effort: record the failure and carry on
				// with the hooks that are still waiting to be stopped.
				failures = append(failures, stopErr)
			}
		}
	}
	return multierr.Combine(failures...)
}
// Run is RunContext with a background context as the parent: it starts the
// lifecycle, blocks until a shutdown is requested, then stops it, giving the
// stop phase at most stopTimeout.
func (lc *Lifecycle) Run(logger Printer, stopTimeout time.Duration) {
	parent := context.Background()
	lc.RunContext(parent, logger, stopTimeout)
}
// RunContext starts the lifecycle in a background goroutine, blocks until
// waitForSignal returns (a shutdown signal arrived or ctx finished), and then
// stops the lifecycle.
//
// If starting fails, the error is logged, a rollback Stop is attempted with the
// parent ctx, and _exit is called. The final stop runs under a fresh context
// limited to stopTimeout; if it fails, the error is logged and _exit is called.
// Progress and errors are reported through logger.
func (lc *Lifecycle) RunContext(ctx context.Context, logger Printer, stopTimeout time.Duration) {
	startCtx, cancelStart := context.WithCancel(ctx)

	// startOrExit performs the start phase; on failure it rolls back what it
	// can and exits the process.
	startOrExit := func() {
		startErr := lc.Start(startCtx)
		if startErr == nil {
			return
		}
		logger.Printf("ERROR\t\tcould not start lifecycle: %+v\n", startErr)
		// rollback
		if rollbackErr := lc.Stop(ctx); rollbackErr != nil {
			logger.Printf("ERROR\t\tcould not rollback cleanly: %+v", rollbackErr)
		}
		_exit()
	}
	go startOrExit()

	// The reason for waking up is deliberately discarded: every outcome leads
	// to the same stop sequence below.
	_ = lc.waitForSignal(startCtx)
	// Cancel the start context so a start phase still in flight is told to
	// give up before stopping begins.
	cancelStart()

	logger.Printf("INFO\t\tattempting clean stop...")

	// The stop context is derived from Background rather than ctx, so an
	// already-finished parent does not cut the stop phase short.
	stopCtx, cancelStop := context.WithTimeout(context.Background(), stopTimeout)
	defer cancelStop()

	stopErr := lc.Stop(stopCtx)
	if stopErr != nil {
		logger.Printf("ERROR\t\tcould not stop cleanly: %+v\n", stopErr)
		_exit()
	}
	logger.Printf("INFO\t\tstopped cleanly")
}
// waitForSignal blocks until the process receives SIGINT or SIGTERM, or until
// ctx is done, whichever comes first.
//
// It returns a "shutdown received" error when a signal arrives, and ctx.Err()
// when the context finishes first.
//
// The signal registration is removed before returning. Without that, the
// channel would stay registered for the rest of the process lifetime and later
// SIGINT/SIGTERM deliveries would be swallowed by a channel nobody reads.
func (lc *Lifecycle) waitForSignal(ctx context.Context) error {
	// Buffered with capacity one: package signal does not block when sending,
	// so an unbuffered channel could miss a signal that arrives before the
	// select below is reached.
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(c)

	select {
	case <-c:
		return errors.New("shutdown received")
	case <-ctx.Done():
		return ctx.Err()
	}
}