Skip to content

Commit

Permalink
feat(gpud): "gpud run --auto-update-exit-code" for daemon set auto up…
Browse files Browse the repository at this point in the history
…date use case (optional) (#122)

* feat(gpud): "gpud run --auto-update-exit-code" for daemon set auto update use case (optional)

Signed-off-by: Gyuho Lee <[email protected]>

* exit code

Signed-off-by: Gyuho Lee <[email protected]>

---------

Signed-off-by: Gyuho Lee <[email protected]>
  • Loading branch information
gyuho authored Oct 16, 2024
1 parent c2b7b31 commit e63d5c7
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 45 deletions.
12 changes: 10 additions & 2 deletions cmd/gpud/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ var (
pollXidEvents bool
pollGPMEvents bool

enableAutoUpdate bool
filesToCheck cli.StringSlice
enableAutoUpdate bool
autoUpdateExitCode int

filesToCheck cli.StringSlice

dockerIgnoreConnectionErrors bool
kubeletIgnoreConnectionErrors bool
Expand Down Expand Up @@ -193,6 +195,12 @@ sudo rm /etc/systemd/system/gpud.service
Usage: "enable auto update of gpud (default: true)",
Destination: &enableAutoUpdate,
},
&cli.IntFlag{
Name: "auto-update-exit-code",
Usage: "specifies the exit code to exit with when auto updating (default: -1 to disable exit code)",
Destination: &autoUpdateExitCode,
Value: -1,
},
&cli.StringSliceFlag{
Name: "files-to-check",
Usage: "enable 'file' component that returns healthy if and only if all the files exist (default: [], use '--files-to-check=a --files-to-check=b' for multiple files)",
Expand Down
3 changes: 3 additions & 0 deletions cmd/gpud/command/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ func cmdRun(cliContext *cli.Context) error {
if webRefreshPeriod > 0 {
cfg.Web.RefreshPeriod = metav1.Duration{Duration: webRefreshPeriod}
}

cfg.EnableAutoUpdate = enableAutoUpdate
cfg.AutoUpdateExitCode = autoUpdateExitCode

if err := cfg.Validate(); err != nil {
return err
}
Expand Down
10 changes: 10 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ type Config struct {

// Set false to disable auto update
EnableAutoUpdate bool `json:"enable_auto_update"`

// Exit code to exit with when auto updating.
// Only valid when the auto update is enabled.
// Set -1 to disable the auto update by exit code.
AutoUpdateExitCode int `json:"auto_update_exit_code"`
}

// Configures the local web configuration.
Expand All @@ -62,6 +67,8 @@ type Web struct {
SincePeriod metav1.Duration `json:"since_period"`
}

var ErrInvalidAutoUpdateExitCode = errors.New("auto_update_exit_code is only valid when auto_update is enabled")

func (config *Config) Validate() error {
if config.Address == "" {
return errors.New("address is required")
Expand All @@ -78,6 +85,9 @@ func (config *Config) Validate() error {
if config.Web != nil && config.Web.SincePeriod.Duration < 10*time.Minute {
return fmt.Errorf("web_metrics_since_period must be at least 10 minutes, got %d", config.Web.SincePeriod.Duration)
}
if !config.EnableAutoUpdate && config.AutoUpdateExitCode != -1 {
return ErrInvalidAutoUpdateExitCode
}
return nil
}

Expand Down
53 changes: 53 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,61 @@ package config

import (
"testing"
"time"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestConfigValidate_AutoUpdateExitCode(t *testing.T) {
tests := []struct {
name string
enableAutoUpdate bool
autoUpdateExitCode int
wantErr bool
}{
{
name: "Valid: Auto update enabled with exit code",
enableAutoUpdate: true,
autoUpdateExitCode: 0,
wantErr: false,
},
{
name: "Valid: Auto update disabled with default exit code",
enableAutoUpdate: false,
autoUpdateExitCode: -1,
wantErr: false,
},
{
name: "Invalid: Auto update disabled with non-default exit code",
enableAutoUpdate: false,
autoUpdateExitCode: 0,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
RetentionPeriod: metav1.Duration{Duration: time.Hour},
RefreshComponentsInterval: metav1.Duration{Duration: time.Hour},
Address: "localhost:8080", // Add a valid address to pass other validations
EnableAutoUpdate: tt.enableAutoUpdate,
AutoUpdateExitCode: tt.autoUpdateExitCode,
}

err := cfg.Validate()

if (err != nil) != tt.wantErr {
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
}

if tt.wantErr && err != ErrInvalidAutoUpdateExitCode {
t.Errorf("Config.Validate() error = %v, want %v", err, ErrInvalidAutoUpdateExitCode)
}
})
}
}

func TestLoadConfigYAML(t *testing.T) {
t.Parallel()

Expand Down
38 changes: 32 additions & 6 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type Server struct {
fifo *goOS.File
session *session.Session
enableAutoUpdate bool
autoUpdateExitCode int
}

func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID string, opts ...gpud_config.OpOption) (_ *Server, retErr error) {
Expand All @@ -122,9 +123,10 @@ func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID
return nil, fmt.Errorf("failed to get fifo path: %w", err)
}
s := &Server{
db: db,
fifoPath: fifoPath,
enableAutoUpdate: config.EnableAutoUpdate,
db: db,
fifoPath: fifoPath,
enableAutoUpdate: config.EnableAutoUpdate,
autoUpdateExitCode: config.AutoUpdateExitCode,
}
defer func() {
if retErr != nil {
Expand Down Expand Up @@ -1168,9 +1170,22 @@ func (s *Server) updateToken(ctx context.Context, db *sql.DB, uid string, endpoi
if dbToken, err := state.GetLoginInfo(ctx, db, uid); err == nil {
userToken = dbToken
}

if userToken != "" {
s.session = session.NewSession(ctx, fmt.Sprintf("https://%s/api/v1/session", endpoint), uid, 3*time.Second, s.enableAutoUpdate)
var err error
s.session, err = session.NewSession(
ctx,
fmt.Sprintf("https://%s/api/v1/session", endpoint),
session.WithMachineID(uid),
session.WithPipeInterval(3*time.Second),
session.WithEnableAutoUpdate(s.enableAutoUpdate),
session.WithAutoUpdateExitCode(s.autoUpdateExitCode),
)
if err != nil {
log.Logger.Errorw("error creating session", "error", err)
}
}

if _, err := goOS.Stat(pipePath); err == nil {
if err = goOS.Remove(pipePath); err != nil {
log.Logger.Errorf("error creating pipe: %v", err)
Expand Down Expand Up @@ -1203,9 +1218,20 @@ func (s *Server) updateToken(ctx context.Context, db *sql.DB, uid string, endpoi
if s.session != nil {
s.session.Stop()
}
s.session = session.NewSession(ctx, fmt.Sprintf("https://%s/api/v1/session", endpoint), uid, 3*time.Second, s.enableAutoUpdate)
s.session, err = session.NewSession(
ctx,
fmt.Sprintf("https://%s/api/v1/session", endpoint),
session.WithMachineID(uid),
session.WithPipeInterval(3*time.Second),
session.WithEnableAutoUpdate(s.enableAutoUpdate),
session.WithAutoUpdateExitCode(s.autoUpdateExitCode),
)
if err != nil {
log.Logger.Errorw("error creating session", "error", err)
}
}
time.Sleep(1 * time.Second)

time.Sleep(time.Second)
}
}

Expand Down
51 changes: 36 additions & 15 deletions internal/session/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"time"

v1 "github.com/leptonai/gpud/api/v1"
Expand Down Expand Up @@ -53,7 +53,9 @@ func (s *Session) serve() {
continue
}

needExit := -1
response := &Response{}

switch payload.Method {
case "metrics":
metrics, err := s.getMetrics(ctx, payload)
Expand All @@ -71,32 +73,51 @@ func (s *Session) serve() {
response.Events = events

case "update":
if !s.enableAutoUpdate {
log.Logger.Warnw("auto update is disabled -- skipping update")
response.Error = errors.New("auto update is disabled")
break
}

systemdManaged, _ := systemd.IsActive("gpud.service")
if !systemdManaged {
log.Logger.Debugw("gpud is not managed with systemd")
if s.autoUpdateExitCode == -1 && !systemdManaged {
log.Logger.Warnw("gpud is not managed with systemd and auto update by exit code is not set -- skipping update")
response.Error = errors.New("gpud is not managed with systemd")
} else if !s.enableAutoUpdate {
log.Logger.Debugw("auto update is disabled")
response.Error = errors.New("auto update is disabled")
} else {
nextVersion := payload.UpdateVersion
if nextVersion == "" {
response.Error = fmt.Errorf("update_version is empty")
} else {
err := update.Update(nextVersion, update.DefaultUpdateURL)
if err != nil {
response.Error = err
}
break
}

nextVersion := payload.UpdateVersion
if nextVersion == "" {
log.Logger.Warnw("target update_version is empty -- skipping update")
response.Error = errors.New("update_version is empty")
break
}

if systemdManaged {
response.Error = update.Update(nextVersion, update.DefaultUpdateURL)
break
}

if s.autoUpdateExitCode != -1 {
response.Error = update.UpdateOnlyBinary(nextVersion, update.DefaultUpdateURL)
if response.Error == nil {
needExit = s.autoUpdateExitCode
}
}
}

cancel()

responseRaw, _ := json.Marshal(response)
s.writer <- Body{
Data: responseRaw,
ReqID: body.ReqID,
}

if needExit != -1 {
log.Logger.Infow("exiting with code for auto update", "code", needExit)
os.Exit(s.autoUpdateExitCode)
}
}
}

Expand Down
71 changes: 65 additions & 6 deletions internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,58 @@ import (
"github.com/leptonai/gpud/log"
)

type Op struct {
machineID string
pipeInterval time.Duration
enableAutoUpdate bool
autoUpdateExitCode int
}

type OpOption func(*Op)

var ErrAutoUpdateDisabledButExitCodeSet = errors.New("auto update is disabled but auto update by exit code is set")

func (op *Op) applyOpts(opts []OpOption) error {
op.autoUpdateExitCode = -1

for _, opt := range opts {
opt(op)
}

if !op.enableAutoUpdate && op.autoUpdateExitCode != -1 {
return ErrAutoUpdateDisabledButExitCodeSet
}

return nil
}

func WithMachineID(machineID string) OpOption {
return func(op *Op) {
op.machineID = machineID
}
}

func WithPipeInterval(t time.Duration) OpOption {
return func(op *Op) {
op.pipeInterval = t
}
}

func WithEnableAutoUpdate(enableAutoUpdate bool) OpOption {
return func(op *Op) {
op.enableAutoUpdate = enableAutoUpdate
}
}

// Triggers an auto update of GPUd itself by exiting the process with the given exit code.
// Useful when the machine is managed by the Kubernetes daemonset and we want to
// trigger an auto update when the daemonset restarts the machine.
func WithAutoUpdateExitCode(autoUpdateExitCode int) OpOption {
return func(op *Op) {
op.autoUpdateExitCode = autoUpdateExitCode
}
}

type Session struct {
ctx context.Context
cancel context.CancelFunc
Expand All @@ -31,10 +83,16 @@ type Session struct {
readerCloseCh chan bool
readerClosedCh chan bool

enableAutoUpdate bool
enableAutoUpdate bool
autoUpdateExitCode int
}

func NewSession(ctx context.Context, endpoint string, machineID string, pipeInterval time.Duration, enableAutoUpdate bool) *Session {
func NewSession(ctx context.Context, endpoint string, opts ...OpOption) (*Session, error) {
op := &Op{}
if err := op.applyOpts(opts); err != nil {
return nil, err
}

cps := make([]string, 0)
allComponents := components.GetAllComponents()
for key := range allComponents {
Expand All @@ -46,14 +104,15 @@ func NewSession(ctx context.Context, endpoint string, machineID string, pipeInte
ctx: cctx,
cancel: ccancel,

pipeInterval: pipeInterval,
pipeInterval: op.pipeInterval,

endpoint: endpoint,
machineID: machineID,
machineID: op.machineID,

components: cps,

enableAutoUpdate: enableAutoUpdate,
enableAutoUpdate: op.enableAutoUpdate,
autoUpdateExitCode: op.autoUpdateExitCode,
}

s.reader = make(chan Body, 20)
Expand All @@ -65,7 +124,7 @@ func NewSession(ctx context.Context, endpoint string, machineID string, pipeInte
s.keepAlive()
go s.serve()

return s
return s, nil
}

type Body struct {
Expand Down
Loading

0 comments on commit e63d5c7

Please sign in to comment.