From e63d5c76aaafc506cacd2308fa0df800e503bf38 Mon Sep 17 00:00:00 2001 From: Gyuho Lee <6799218+gyuho@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:53:36 +0800 Subject: [PATCH] feat(gpud): "gpud run --auto-update-exit-code" for daemon set auto update use case (optional) (#122) * feat(gpud): "gpud run --auto-update-exit-code" for daemon set auto update use case (optional) Signed-off-by: Gyuho Lee * exit code Signed-off-by: Gyuho Lee --------- Signed-off-by: Gyuho Lee --- cmd/gpud/command/command.go | 12 +++++- cmd/gpud/command/run.go | 3 ++ config/config.go | 10 +++++ config/config_test.go | 53 ++++++++++++++++++++++++ internal/server/server.go | 38 ++++++++++++++--- internal/session/serve.go | 51 ++++++++++++++++------- internal/session/session.go | 71 +++++++++++++++++++++++++++++--- internal/session/session_test.go | 62 +++++++++++++++++++++++++++- update/update.go | 50 +++++++++++++++------- 9 files changed, 305 insertions(+), 45 deletions(-) diff --git a/cmd/gpud/command/command.go b/cmd/gpud/command/command.go index 4457a2fe..44475a86 100644 --- a/cmd/gpud/command/command.go +++ b/cmd/gpud/command/command.go @@ -41,8 +41,10 @@ var ( pollXidEvents bool pollGPMEvents bool - enableAutoUpdate bool - filesToCheck cli.StringSlice + enableAutoUpdate bool + autoUpdateExitCode int + + filesToCheck cli.StringSlice dockerIgnoreConnectionErrors bool kubeletIgnoreConnectionErrors bool @@ -193,6 +195,12 @@ sudo rm /etc/systemd/system/gpud.service Usage: "enable auto update of gpud (default: true)", Destination: &enableAutoUpdate, }, + &cli.IntFlag{ + Name: "auto-update-exit-code", + Usage: "specifies the exit code to exit with when auto updating (default: -1 to disable exit code)", + Destination: &autoUpdateExitCode, + Value: -1, + }, &cli.StringSliceFlag{ Name: "files-to-check", Usage: "enable 'file' component that returns healthy if and only if all the files exist (default: [], use '--files-to-check=a --files-to-check=b' for multiple files)", diff --git 
a/cmd/gpud/command/run.go b/cmd/gpud/command/run.go index 1a2f9ce4..c06e509f 100644 --- a/cmd/gpud/command/run.go +++ b/cmd/gpud/command/run.go @@ -75,7 +75,10 @@ func cmdRun(cliContext *cli.Context) error { if webRefreshPeriod > 0 { cfg.Web.RefreshPeriod = metav1.Duration{Duration: webRefreshPeriod} } + cfg.EnableAutoUpdate = enableAutoUpdate + cfg.AutoUpdateExitCode = autoUpdateExitCode + if err := cfg.Validate(); err != nil { return err } diff --git a/config/config.go b/config/config.go index 5b98da54..3267452e 100644 --- a/config/config.go +++ b/config/config.go @@ -45,6 +45,11 @@ type Config struct { // Set false to disable auto update EnableAutoUpdate bool `json:"enable_auto_update"` + + // Exit code to exit with when auto updating. + // Only valid when the auto update is enabled. + // Set -1 to disable the auto update by exit code. + AutoUpdateExitCode int `json:"auto_update_exit_code"` } // Configures the local web configuration. @@ -62,6 +67,8 @@ type Web struct { SincePeriod metav1.Duration `json:"since_period"` } +var ErrInvalidAutoUpdateExitCode = errors.New("auto_update_exit_code is only valid when auto_update is enabled") + func (config *Config) Validate() error { if config.Address == "" { return errors.New("address is required") @@ -78,6 +85,9 @@ func (config *Config) Validate() error { if config.Web != nil && config.Web.SincePeriod.Duration < 10*time.Minute { return fmt.Errorf("web_metrics_since_period must be at least 10 minutes, got %d", config.Web.SincePeriod.Duration) } + if !config.EnableAutoUpdate && config.AutoUpdateExitCode != -1 { + return ErrInvalidAutoUpdateExitCode + } return nil } diff --git a/config/config_test.go b/config/config_test.go index d156df4d..a9f7711c 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -2,8 +2,61 @@ package config import ( "testing" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +func TestConfigValidate_AutoUpdateExitCode(t *testing.T) { + tests := []struct { + name string + 
enableAutoUpdate bool + autoUpdateExitCode int + wantErr bool + }{ + { + name: "Valid: Auto update enabled with exit code", + enableAutoUpdate: true, + autoUpdateExitCode: 0, + wantErr: false, + }, + { + name: "Valid: Auto update disabled with default exit code", + enableAutoUpdate: false, + autoUpdateExitCode: -1, + wantErr: false, + }, + { + name: "Invalid: Auto update disabled with non-default exit code", + enableAutoUpdate: false, + autoUpdateExitCode: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{ + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + RefreshComponentsInterval: metav1.Duration{Duration: time.Hour}, + Address: "localhost:8080", // Add a valid address to pass other validations + EnableAutoUpdate: tt.enableAutoUpdate, + AutoUpdateExitCode: tt.autoUpdateExitCode, + } + + err := cfg.Validate() + + if (err != nil) != tt.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr && err != ErrInvalidAutoUpdateExitCode { + t.Errorf("Config.Validate() error = %v, want %v", err, ErrInvalidAutoUpdateExitCode) + } + }) + } +} + func TestLoadConfigYAML(t *testing.T) { t.Parallel() diff --git a/internal/server/server.go b/internal/server/server.go index 000dd7f4..73a958ba 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -97,6 +97,7 @@ type Server struct { fifo *goOS.File session *session.Session enableAutoUpdate bool + autoUpdateExitCode int } func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID string, opts ...gpud_config.OpOption) (_ *Server, retErr error) { @@ -122,9 +123,10 @@ func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID return nil, fmt.Errorf("failed to get fifo path: %w", err) } s := &Server{ - db: db, - fifoPath: fifoPath, - enableAutoUpdate: config.EnableAutoUpdate, + db: db, + fifoPath: fifoPath, + enableAutoUpdate: config.EnableAutoUpdate, + 
autoUpdateExitCode: config.AutoUpdateExitCode, } defer func() { if retErr != nil { @@ -1168,9 +1170,22 @@ func (s *Server) updateToken(ctx context.Context, db *sql.DB, uid string, endpoi if dbToken, err := state.GetLoginInfo(ctx, db, uid); err == nil { userToken = dbToken } + if userToken != "" { - s.session = session.NewSession(ctx, fmt.Sprintf("https://%s/api/v1/session", endpoint), uid, 3*time.Second, s.enableAutoUpdate) + var err error + s.session, err = session.NewSession( + ctx, + fmt.Sprintf("https://%s/api/v1/session", endpoint), + session.WithMachineID(uid), + session.WithPipeInterval(3*time.Second), + session.WithEnableAutoUpdate(s.enableAutoUpdate), + session.WithAutoUpdateExitCode(s.autoUpdateExitCode), + ) + if err != nil { + log.Logger.Errorw("error creating session", "error", err) + } } + if _, err := goOS.Stat(pipePath); err == nil { if err = goOS.Remove(pipePath); err != nil { log.Logger.Errorf("error creating pipe: %v", err) @@ -1203,9 +1218,20 @@ func (s *Server) updateToken(ctx context.Context, db *sql.DB, uid string, endpoi if s.session != nil { s.session.Stop() } - s.session = session.NewSession(ctx, fmt.Sprintf("https://%s/api/v1/session", endpoint), uid, 3*time.Second, s.enableAutoUpdate) + s.session, err = session.NewSession( + ctx, + fmt.Sprintf("https://%s/api/v1/session", endpoint), + session.WithMachineID(uid), + session.WithPipeInterval(3*time.Second), + session.WithEnableAutoUpdate(s.enableAutoUpdate), + session.WithAutoUpdateExitCode(s.autoUpdateExitCode), + ) + if err != nil { + log.Logger.Errorw("error creating session", "error", err) + } } - time.Sleep(1 * time.Second) + + time.Sleep(time.Second) } } diff --git a/internal/session/serve.go b/internal/session/serve.go index 26368722..b1c0bc6a 100644 --- a/internal/session/serve.go +++ b/internal/session/serve.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" + "os" "time" v1 "github.com/leptonai/gpud/api/v1" @@ -53,7 +53,9 @@ func (s *Session) serve() { continue 
} + needExit := -1 response := &Response{} + switch payload.Method { case "metrics": metrics, err := s.getMetrics(ctx, payload) @@ -71,25 +73,39 @@ func (s *Session) serve() { response.Events = events case "update": + if !s.enableAutoUpdate { + log.Logger.Warnw("auto update is disabled -- skipping update") + response.Error = errors.New("auto update is disabled") + break + } + systemdManaged, _ := systemd.IsActive("gpud.service") - if !systemdManaged { - log.Logger.Debugw("gpud is not managed with systemd") + if s.autoUpdateExitCode == -1 && !systemdManaged { + log.Logger.Warnw("gpud is not managed with systemd and auto update by exit code is not set -- skipping update") response.Error = errors.New("gpud is not managed with systemd") - } else if !s.enableAutoUpdate { - log.Logger.Debugw("auto update is disabled") - response.Error = errors.New("auto update is disabled") - } else { - nextVersion := payload.UpdateVersion - if nextVersion == "" { - response.Error = fmt.Errorf("update_version is empty") - } else { - err := update.Update(nextVersion, update.DefaultUpdateURL) - if err != nil { - response.Error = err - } + break + } + + nextVersion := payload.UpdateVersion + if nextVersion == "" { + log.Logger.Warnw("target update_version is empty -- skipping update") + response.Error = errors.New("update_version is empty") + break + } + + if systemdManaged { + response.Error = update.Update(nextVersion, update.DefaultUpdateURL) + break + } + + if s.autoUpdateExitCode != -1 { + response.Error = update.UpdateOnlyBinary(nextVersion, update.DefaultUpdateURL) + if response.Error == nil { + needExit = s.autoUpdateExitCode } } } + cancel() responseRaw, _ := json.Marshal(response) @@ -97,6 +113,11 @@ func (s *Session) serve() { Data: responseRaw, ReqID: body.ReqID, } + + if needExit != -1 { + log.Logger.Infow("exiting with code for auto update", "code", needExit) + os.Exit(s.autoUpdateExitCode) + } } } diff --git a/internal/session/session.go b/internal/session/session.go index 
9ab0022b..60acdb3a 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -12,6 +12,58 @@ import ( "github.com/leptonai/gpud/log" ) +type Op struct { + machineID string + pipeInterval time.Duration + enableAutoUpdate bool + autoUpdateExitCode int +} + +type OpOption func(*Op) + +var ErrAutoUpdateDisabledButExitCodeSet = errors.New("auto update is disabled but auto update by exit code is set") + +func (op *Op) applyOpts(opts []OpOption) error { + op.autoUpdateExitCode = -1 + + for _, opt := range opts { + opt(op) + } + + if !op.enableAutoUpdate && op.autoUpdateExitCode != -1 { + return ErrAutoUpdateDisabledButExitCodeSet + } + + return nil +} + +func WithMachineID(machineID string) OpOption { + return func(op *Op) { + op.machineID = machineID + } +} + +func WithPipeInterval(t time.Duration) OpOption { + return func(op *Op) { + op.pipeInterval = t + } +} + +func WithEnableAutoUpdate(enableAutoUpdate bool) OpOption { + return func(op *Op) { + op.enableAutoUpdate = enableAutoUpdate + } +} + +// WithAutoUpdateExitCode sets the exit code GPUd exits with to complete an auto update. +// Useful when GPUd is managed by a Kubernetes daemonset and we want to +// trigger an auto update when the daemonset restarts the exited GPUd container. 
+func WithAutoUpdateExitCode(autoUpdateExitCode int) OpOption { + return func(op *Op) { + op.autoUpdateExitCode = autoUpdateExitCode + } +} + type Session struct { ctx context.Context cancel context.CancelFunc @@ -31,10 +83,16 @@ type Session struct { readerCloseCh chan bool readerClosedCh chan bool - enableAutoUpdate bool + enableAutoUpdate bool + autoUpdateExitCode int } -func NewSession(ctx context.Context, endpoint string, machineID string, pipeInterval time.Duration, enableAutoUpdate bool) *Session { +func NewSession(ctx context.Context, endpoint string, opts ...OpOption) (*Session, error) { + op := &Op{} + if err := op.applyOpts(opts); err != nil { + return nil, err + } + cps := make([]string, 0) allComponents := components.GetAllComponents() for key := range allComponents { @@ -46,14 +104,15 @@ func NewSession(ctx context.Context, endpoint string, machineID string, pipeInte ctx: cctx, cancel: ccancel, - pipeInterval: pipeInterval, + pipeInterval: op.pipeInterval, endpoint: endpoint, - machineID: machineID, + machineID: op.machineID, components: cps, - enableAutoUpdate: enableAutoUpdate, + enableAutoUpdate: op.enableAutoUpdate, + autoUpdateExitCode: op.autoUpdateExitCode, } s.reader = make(chan Body, 20) @@ -65,7 +124,7 @@ func NewSession(ctx context.Context, endpoint string, machineID string, pipeInte s.keepAlive() go s.serve() - return s + return s, nil } type Body struct { diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 2a16373b..330dad5d 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -9,6 +9,63 @@ import ( "time" ) +func TestApplyOpts(t *testing.T) { + tests := []struct { + name string + opts []OpOption + wantErr bool + }{ + { + name: "Default options", + opts: []OpOption{}, + wantErr: false, + }, + { + name: "Enable auto update", + opts: []OpOption{ + WithEnableAutoUpdate(true), + }, + wantErr: false, + }, + { + name: "Disable auto update", + opts: []OpOption{ + 
WithEnableAutoUpdate(false), + }, + wantErr: false, + }, + { + name: "Set auto update by exit code with auto update enabled", + opts: []OpOption{ + WithEnableAutoUpdate(true), + WithAutoUpdateExitCode(1), + }, + wantErr: false, + }, + { + name: "Set auto update by exit code with auto update disabled", + opts: []OpOption{ + WithEnableAutoUpdate(false), + WithAutoUpdateExitCode(1), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op := &Op{} + err := op.applyOpts(tt.opts) + if (err != nil) != tt.wantErr { + t.Errorf("applyOpts() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && err != ErrAutoUpdateDisabledButExitCodeSet { + t.Errorf("applyOpts() expected error %v, got %v", ErrAutoUpdateDisabledButExitCodeSet, err) + } + }) + } +} + func TestNewSession(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -16,7 +73,10 @@ func TestNewSession(t *testing.T) { endpoint := "test-endpoint.com" machineID := "test-machine-id" - session := NewSession(ctx, endpoint, machineID, time.Second, true) + session, err := NewSession(ctx, endpoint, WithMachineID(machineID), WithPipeInterval(time.Second), WithEnableAutoUpdate(true)) + if err != nil { + t.Fatalf("error creating session: %v", err) + } defer session.Stop() if session == nil { diff --git a/update/update.go b/update/update.go index 7a4c24a7..217e87ae 100644 --- a/update/update.go +++ b/update/update.go @@ -191,33 +191,53 @@ func StopSystemdUnit() error { } func Update(ver, url string) error { - if err := RequireRoot(); err != nil { - log.Logger.Errorf("this command needs to be run as root: %v", err) - return err + return update(ver, url, true, true) +} + +// UpdateOnlyBinary updates the gpud binary by only downloading the tarball and unpacking it, +// without restarting the service or requiring root. 
+func UpdateOnlyBinary(ver, url string) error { + return update(ver, url, false, false) +} + +func update(ver, url string, requireRoot bool, useSystemd bool) error { + log.Logger.Infow("starting gpud update", "version", ver, "url", url, "requireRoot", requireRoot, "useSystemd", useSystemd) + + if requireRoot { + if err := RequireRoot(); err != nil { + log.Logger.Errorf("this command needs to be run as root: %v", err) + return err + } } dlPath, err := downloadLinuxTarball(ver, url) if err != nil { return err } - fmt.Printf("Extracting %q", dlPath) + log.Logger.Infow("downloaded update tarball", "path", dlPath) + if err := unpackLinuxTarball(dlPath); err != nil { return err } + log.Logger.Infow("unpacked update tarball", "path", dlPath) + if err := os.Remove(dlPath); err != nil { - log.Logger.Errorf("failed to cleanup: %s", err) - } - if err := RestartSystemdUnit(); err != nil { - if strings.Contains(err.Error(), "signal: terminated") { - // an expected error - log.Logger.Infof("gpud binary updated successfully. Waiting complete of systemd restart.") - } else if errors.Is(err, errors.ErrUnsupported) { - log.Logger.Errorf("gpud binary updated successfully. Please restart gpud to finish the update.") + log.Logger.Errorw("failed to cleanup the downloaded update tarball", "error", err) + } + + if useSystemd { + if err := RestartSystemdUnit(); err != nil { + if strings.Contains(err.Error(), "signal: terminated") { + // an expected error + log.Logger.Infof("gpud binary updated successfully. Waiting complete of systemd restart.") + } else if errors.Is(err, errors.ErrUnsupported) { + log.Logger.Errorf("gpud binary updated successfully. Please restart gpud to finish the update.") + } else { + log.Logger.Errorf("gpud binary updated successfully, but failed to restart gpud: %s. Please restart gpud to finish the update.", err) + } } else { - log.Logger.Errorf("gpud binary updated successfully, but failed to restart gpud: %s. 
Please restart gpud to finish the update.", err) + log.Logger.Infow("completed gpud update", "version", ver) } - } else { - log.Logger.Infof("updating gpud to version %s completed", ver) } return nil