diff --git a/lib/autoupdate/agent.go b/lib/autoupdate/agent.go
deleted file mode 100644
index 921562251f1f..000000000000
--- a/lib/autoupdate/agent.go
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2024 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package autoupdate
-
-import (
- "context"
- "errors"
- "io/fs"
- "log/slog"
- "os"
-
- "github.com/google/renameio/v2"
- "github.com/gravitational/trace"
- "gopkg.in/yaml.v3"
-)
-
-const (
- agentUpdateConfigVersion = "v1"
- agentUpdateConfigKind = "update_config"
-)
-
-// AgentUpdateConfig describes the update.yaml file schema.
-type AgentUpdateConfig struct {
- // Version of the configuration file
- Version string `yaml:"version"`
- // Kind of configuration file (always "update_config")
- Kind string `yaml:"kind"`
- // Spec contains user-specified configuration.
- Spec AgentUpdateSpec `yaml:"spec"`
- // Status contains state configuration.
- Status AgentUpdateStatus `yaml:"status"`
-}
-
-// AgentUpdateSpec describes the spec field in update.yaml.
-type AgentUpdateSpec struct {
- // Proxy address
- Proxy string `yaml:"proxy"`
- // Group update identifier
- Group string `yaml:"group"`
- // URLTemplate for the Teleport tgz download URL.
- URLTemplate string `yaml:"url_template"`
- // Enabled controls whether auto-updates are enabled.
- Enabled bool `yaml:"enabled"`
-}
-
-// AgentUpdateStatus describes the status field in update.yaml.
-type AgentUpdateStatus struct {
- // ActiveVersion is the currently active Teleport version.
- ActiveVersion string `yaml:"active_version"`
-}
-
-type AgentUpdater struct {
- Log *slog.Logger
-}
-
-// Disable disables agent updates.
-// updatePath must be a path to the update.yaml file.
-func (u AgentUpdater) Disable(ctx context.Context, updatePath string) error {
- cfg, err := u.readConfig(updatePath)
- if err != nil {
- return trace.Errorf("failed to read updates.yaml: %w", err)
- }
- if !cfg.Spec.Enabled {
- u.Log.InfoContext(ctx, "Automatic updates already disabled")
- return nil
- }
- cfg.Spec.Enabled = false
- if err := u.writeConfig(updatePath, cfg); err != nil {
- return trace.Errorf("failed to write updates.yaml: %w", err)
- }
- return nil
-}
-
-// readConfig reads update.yaml
-func (AgentUpdater) readConfig(path string) (*AgentUpdateConfig, error) {
- f, err := os.Open(path)
- if errors.Is(err, fs.ErrNotExist) {
- return &AgentUpdateConfig{
- Version: agentUpdateConfigVersion,
- Kind: agentUpdateConfigKind,
- }, nil
- }
- if err != nil {
- return nil, trace.Errorf("failed to open: %w", err)
- }
- defer f.Close()
- var cfg AgentUpdateConfig
- if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
- return nil, trace.Errorf("failed to parse: %w", err)
- }
- if k := cfg.Kind; k != agentUpdateConfigKind {
- return nil, trace.Errorf("invalid kind %q", k)
- }
- if v := cfg.Version; v != agentUpdateConfigVersion {
- return nil, trace.Errorf("invalid version %q", v)
- }
- return &cfg, nil
-}
-
-// writeConfig writes update.yaml atomically, ensuring the file cannot be corrupted.
-func (AgentUpdater) writeConfig(filename string, cfg *AgentUpdateConfig) error {
- opts := []renameio.Option{
- renameio.WithPermissions(0755),
- renameio.WithExistingPermissions(),
- }
- t, err := renameio.NewPendingFile(filename, opts...)
- if err != nil {
- return trace.Wrap(err)
- }
- defer t.Cleanup()
- err = yaml.NewEncoder(t).Encode(cfg)
- if err != nil {
- return trace.Wrap(err)
- }
- return trace.Wrap(t.CloseAtomicallyReplace())
-}
diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go
new file mode 100644
index 000000000000..e31813866eac
--- /dev/null
+++ b/lib/autoupdate/agent/installer.go
@@ -0,0 +1,317 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "io"
+ "log/slog"
+ "net/http"
+ "os"
+ "path/filepath"
+ "runtime"
+ "text/template"
+ "time"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+const (
+ checksumType = "sha256"
+ checksumHexLen = sha256.Size * 2 // bytes to hex
+)
+
+// LocalInstaller manages the creation and removal of installations
+// of Teleport.
+type LocalInstaller struct {
+ // InstallDir contains each installation, named by version.
+ InstallDir string
+ // HTTP is an HTTP client for downloading Teleport.
+ HTTP *http.Client
+ // Log contains a logger.
+ Log *slog.Logger
+ // ReservedFreeTmpDisk is the amount of disk that must remain free in /tmp
+ ReservedFreeTmpDisk uint64
+ // ReservedFreeInstallDisk is the amount of disk that must remain free in the install directory.
+ ReservedFreeInstallDisk uint64
+}
+
+// Remove a Teleport version directory from InstallDir.
+// This function is idempotent.
+func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
+ versionDir := filepath.Join(li.InstallDir, version)
+ sumPath := filepath.Join(versionDir, checksumType)
+
+ // invalidate checksum first, to protect against partially-removed
+ // directory with valid checksum.
+ err := os.Remove(sumPath)
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return trace.Wrap(err)
+ }
+ if err := os.RemoveAll(versionDir); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// Install a Teleport version directory in InstallDir.
+// This function is idempotent.
+func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
+ versionDir := filepath.Join(li.InstallDir, version)
+ sumPath := filepath.Join(versionDir, checksumType)
+
+ // generate download URI from template
+ uri, err := makeURL(template, version, flags)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Get new and old checksums. If they match, skip download.
+ // Otherwise, clear the old version directory and re-download.
+ checksumURI := uri + "." + checksumType
+ newSum, err := li.getChecksum(ctx, checksumURI)
+ if err != nil {
+ return trace.Errorf("failed to download checksum from %s: %w", checksumURI, err)
+ }
+ oldSum, err := readChecksum(sumPath)
+ if err == nil {
+ if bytes.Equal(oldSum, newSum) {
+ li.Log.InfoContext(ctx, "Version already present.", "version", version)
+ return nil
+ }
+ li.Log.WarnContext(ctx, "Removing version that does not match checksum.", "version", version)
+ if err := li.Remove(ctx, version); err != nil {
+ return trace.Wrap(err)
+ }
+ } else if !errors.Is(err, os.ErrNotExist) {
+ li.Log.WarnContext(ctx, "Removing version with unreadable checksum.", "version", version, "error", err)
+ if err := li.Remove(ctx, version); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
+ // Verify that we have enough free temp space, then download tgz
+ freeTmp, err := utils.FreeDiskWithReserve(os.TempDir(), li.ReservedFreeTmpDisk)
+ if err != nil {
+ return trace.Errorf("failed to calculate free disk: %w", err)
+ }
+ f, err := os.CreateTemp("", "teleport-update-")
+ if err != nil {
+ return trace.Errorf("failed to create temporary file: %w", err)
+ }
+ defer func() {
+ _ = f.Close() // data never read after close
+ if err := os.Remove(f.Name()); err != nil {
+ li.Log.WarnContext(ctx, "Failed to cleanup temporary download.", "error", err)
+ }
+ }()
+ pathSum, err := li.download(ctx, f, int64(freeTmp), uri)
+ if err != nil {
+ return trace.Errorf("failed to download teleport: %w", err)
+ }
+
+ // Seek to the start of the tgz file after writing
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ return trace.Errorf("failed seek to start of download: %w", err)
+ }
+ // Check integrity before decompression
+ if !bytes.Equal(newSum, pathSum) {
+ return trace.Errorf("mismatched checksum, download possibly corrupt")
+ }
+ // Get uncompressed size of the tgz
+ n, err := uncompressedSize(f)
+ if err != nil {
+ return trace.Errorf("failed to determine uncompressed size: %w", err)
+ }
+ // Seek to start of tgz after reading size
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ return trace.Errorf("failed seek to start: %w", err)
+ }
+ if err := li.extract(ctx, versionDir, f, n); err != nil {
+ return trace.Errorf("failed to extract teleport: %w", err)
+ }
+ // Write the checksum last. This marks the version directory as valid.
+ err = os.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), 0755)
+ if err != nil {
+ return trace.Errorf("failed to write checksum: %w", err)
+ }
+ return nil
+}
+
+// makeURL to download the Teleport tgz.
+func makeURL(uriTmpl, version string, flags InstallFlags) (string, error) {
+ tmpl, err := template.New("uri").Parse(uriTmpl)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ var uriBuf bytes.Buffer
+ params := struct {
+ OS, Version, Arch string
+ FIPS, Enterprise bool
+ }{
+ OS: runtime.GOOS,
+ Version: version,
+ Arch: runtime.GOARCH,
+ FIPS: flags&FlagFIPS != 0,
+ Enterprise: flags&(FlagEnterprise|FlagFIPS) != 0,
+ }
+ err = tmpl.Execute(&uriBuf, params)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ return uriBuf.String(), nil
+}
+
+// readChecksum from the version directory.
+func readChecksum(path string) ([]byte, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer f.Close()
+ var buf bytes.Buffer
+ _, err = io.CopyN(&buf, f, checksumHexLen)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ raw := buf.String()
+ sum, err := hex.DecodeString(raw)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return sum, nil
+}
+
+func (li *LocalInstaller) getChecksum(ctx context.Context, url string) ([]byte, error) {
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ resp, err := li.HTTP.Do(req)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, trace.Errorf("checksum not found: %s", url)
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, trace.Errorf("unexpected HTTP status code: %d", resp.StatusCode)
+ }
+
+ var buf bytes.Buffer
+ _, err = io.CopyN(&buf, resp.Body, checksumHexLen)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ sum, err := hex.DecodeString(buf.String())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return sum, nil
+}
+
+func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, url string) (sum []byte, err error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ resp, err := li.HTTP.Do(req)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, trace.Errorf("Teleport download not found: %s", url)
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, trace.Errorf("unexpected HTTP status code: %d", resp.StatusCode)
+ }
+ li.Log.InfoContext(ctx, "Downloading Teleport tarball.", "url", url, "size", resp.ContentLength)
+
+ // Ensure there's enough space in /tmp for the download.
+ size := resp.ContentLength
+ if size < 0 {
+ li.Log.WarnContext(ctx, "Content length missing from response, unable to verify Teleport download size.")
+ size = max
+ } else if size > max {
+ return nil, trace.Errorf("size of download (%d bytes) exceeds available disk space (%d bytes)", resp.ContentLength, max)
+ }
+ // Calculate checksum concurrently with download.
+ shaReader := sha256.New()
+ n, err := io.CopyN(w, io.TeeReader(resp.Body, shaReader), size)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if resp.ContentLength >= 0 && n != resp.ContentLength {
+ return nil, trace.Errorf("mismatch in Teleport download size")
+ }
+ return shaReader.Sum(nil), nil
+}
+
+func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64) error {
+ if err := os.MkdirAll(dstDir, 0755); err != nil {
+ return trace.Wrap(err)
+ }
+ free, err := utils.FreeDiskWithReserve(dstDir, li.ReservedFreeInstallDisk)
+ if err != nil {
+ return trace.Errorf("failed to calculate free disk in %q: %w", dstDir, err)
+ }
+ // Bail if there's not enough free disk space at the target
+ if d := int64(free) - max; d < 0 {
+ return trace.Errorf("%q needs %d additional bytes of disk space for decompression", dstDir, -d)
+ }
+ zr, err := gzip.NewReader(src)
+ if err != nil {
+ return trace.Errorf("requires gzip-compressed body: %v", err)
+ }
+ li.Log.InfoContext(ctx, "Extracting Teleport tarball.", "path", dstDir, "size", max)
+
+ // TODO(sclevine): add variadic arg to Extract to extract teleport/ subdir into bin/.
+ err = utils.Extract(zr, dstDir)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+func uncompressedSize(f io.Reader) (int64, error) {
+ // NOTE: The gzip length trailer is very unreliable,
+ // but we could optimize this in the future if
+ // we are willing to verify that all published
+ // Teleport tarballs have valid trailers.
+ r, err := gzip.NewReader(f)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ n, err := io.Copy(io.Discard, r)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ return n, nil
+}
diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go
new file mode 100644
index 000000000000..be778f7bcf16
--- /dev/null
+++ b/lib/autoupdate/agent/installer_test.go
@@ -0,0 +1,189 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "archive/tar"
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTeleportInstaller_Install(t *testing.T) {
+ t.Parallel()
+ const version = "new-version"
+
+ _, testSum := testTGZ(t, version)
+
+ tests := []struct {
+ name string
+ reservedTmp uint64
+ reservedInstall uint64
+ existingSum string
+ flags InstallFlags
+
+ errMatch string
+ }{
+ {
+ name: "not present",
+ },
+ {
+ name: "present",
+ existingSum: testSum,
+ },
+ {
+ name: "mismatched checksum",
+ existingSum: hex.EncodeToString(sha256.New().Sum(nil)),
+ },
+ {
+ name: "unreadable checksum",
+ existingSum: "bad",
+ },
+ {
+ name: "out of space in /tmp",
+ reservedTmp: reservedFreeDisk * 1_000_000_000,
+ errMatch: "no free space left",
+ },
+ {
+ name: "out of space in install dir",
+ reservedInstall: reservedFreeDisk * 1_000_000_000,
+ errMatch: "no free space left",
+ },
+ // TODO(sclevine): test flags
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+
+ if tt.existingSum != "" {
+ err := os.WriteFile(filepath.Join(dir, checksumType), []byte(tt.existingSum), os.ModePerm)
+ require.NoError(t, err)
+ }
+
+ // test parameters
+ var dlPath, shaPath, shasum string
+
+ // test server
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ tgz, sum := testTGZ(t, version)
+ shasum = sum
+ var out *bytes.Buffer
+ if strings.HasSuffix(r.URL.Path, "."+checksumType) { // checksum request
+ shaPath = r.URL.Path
+ out = bytes.NewBufferString(sum)
+ } else { // tgz request
+ dlPath = r.URL.Path
+ out = tgz
+ }
+ w.Header().Set("Content-Length", strconv.Itoa(out.Len()))
+ _, err := io.Copy(w, out)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }))
+ t.Cleanup(server.Close)
+
+ installer := &LocalInstaller{
+ InstallDir: dir,
+ HTTP: http.DefaultClient,
+ Log: slog.Default(),
+ ReservedFreeTmpDisk: tt.reservedTmp,
+ ReservedFreeInstallDisk: tt.reservedInstall,
+ }
+ ctx := context.Background()
+ err := installer.Install(ctx, version, server.URL+"/{{.OS}}/{{.Arch}}/{{.Version}}", tt.flags)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+
+ const expectedPath = "/" + runtime.GOOS + "/" + runtime.GOARCH + "/" + version
+ require.Equal(t, expectedPath, dlPath)
+ require.Equal(t, expectedPath+"."+checksumType, shaPath)
+
+ teleportVersion, err := os.ReadFile(filepath.Join(dir, version, "teleport"))
+ require.NoError(t, err)
+ require.Equal(t, version, string(teleportVersion))
+
+ tshVersion, err := os.ReadFile(filepath.Join(dir, version, "tsh"))
+ require.NoError(t, err)
+ require.Equal(t, version, string(tshVersion))
+
+ sum, err := os.ReadFile(filepath.Join(dir, version, checksumType))
+ require.NoError(t, err)
+ require.Equal(t, string(sum), shasum)
+ })
+ }
+}
+
+func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) {
+ t.Helper()
+
+ var buf bytes.Buffer
+
+ sha := sha256.New()
+ gz := gzip.NewWriter(io.MultiWriter(&buf, sha))
+ tw := tar.NewWriter(gz)
+
+ var files = []struct {
+ Name, Body string
+ }{
+ {"teleport", version},
+ {"tsh", version},
+ }
+ for _, file := range files {
+ hdr := &tar.Header{
+ Name: file.Name,
+ Mode: 0600,
+ Size: int64(len(file.Body)),
+ }
+ if err := tw.WriteHeader(hdr); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := tw.Write([]byte(file.Body)); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err := tw.Close(); err != nil {
+ t.Fatal(err)
+ }
+ if err := gz.Close(); err != nil {
+ t.Fatal(err)
+ }
+ return &buf, hex.EncodeToString(sha.Sum(nil))
+}
diff --git a/lib/autoupdate/testdata/TestAgentUpdater_Disable/already_disabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden
similarity index 100%
rename from lib/autoupdate/testdata/TestAgentUpdater_Disable/already_disabled.golden
rename to lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden
diff --git a/lib/autoupdate/testdata/TestAgentUpdater_Disable/enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden
similarity index 100%
rename from lib/autoupdate/testdata/TestAgentUpdater_Disable/enabled.golden
rename to lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden
new file mode 100644
index 000000000000..e03f369eb101
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden
@@ -0,0 +1,9 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ group: ""
+ url_template: ""
+ enabled: true
+status:
+ active_version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden
new file mode 100644
index 000000000000..e03f369eb101
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden
@@ -0,0 +1,9 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ group: ""
+ url_template: ""
+ enabled: true
+status:
+ active_version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden
new file mode 100644
index 000000000000..b172d858bc55
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden
@@ -0,0 +1,9 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ group: group
+ url_template: https://example.com
+ enabled: true
+status:
+ active_version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden
new file mode 100644
index 000000000000..bb9ce8b9d8fa
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden
@@ -0,0 +1,9 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ group: new-group
+ url_template: https://example.com/new
+ enabled: true
+status:
+ active_version: new-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden
new file mode 100644
index 000000000000..e03f369eb101
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden
@@ -0,0 +1,9 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ group: ""
+ url_template: ""
+ enabled: true
+status:
+ active_version: 16.3.0
diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go
new file mode 100644
index 000000000000..59df5f0b3ba8
--- /dev/null
+++ b/lib/autoupdate/agent/updater.go
@@ -0,0 +1,341 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "io/fs"
+ "log/slog"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/google/renameio/v2"
+ "github.com/gravitational/trace"
+ "gopkg.in/yaml.v3"
+
+ "github.com/gravitational/teleport/api/client/webclient"
+ libdefaults "github.com/gravitational/teleport/lib/defaults"
+ libutils "github.com/gravitational/teleport/lib/utils"
+)
+
+const (
+ // cdnURITemplate is the default template for the Teleport tgz download.
+ cdnURITemplate = "https://cdn.teleport.dev/teleport{{if .Enterprise}}-ent{{end}}-v{{.Version}}-{{.OS}}-{{.Arch}}{{if .FIPS}}-fips{{end}}-bin.tar.gz"
+ // reservedFreeDisk is the minimum required free space left on disk during downloads.
+ // TODO(sclevine): This value is arbitrary and could be replaced by, e.g., min(1%, 200mb) in the future
+ // to account for a range of disk sizes.
+ reservedFreeDisk = 10_000_000 // 10 MB
+)
+
+const (
+ // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update.
+ updateConfigName = "update.yaml"
+
+ // UpdateConfig metadata
+ updateConfigVersion = "v1"
+ updateConfigKind = "update_config"
+)
+
+// UpdateConfig describes the update.yaml file schema.
+type UpdateConfig struct {
+ // Version of the configuration file
+ Version string `yaml:"version"`
+ // Kind of configuration file (always "update_config")
+ Kind string `yaml:"kind"`
+ // Spec contains user-specified configuration.
+ Spec UpdateSpec `yaml:"spec"`
+ // Status contains state configuration.
+ Status UpdateStatus `yaml:"status"`
+}
+
+// UpdateSpec describes the spec field in update.yaml.
+type UpdateSpec struct {
+ // Proxy address
+ Proxy string `yaml:"proxy"`
+ // Group specifies the update group identifier for the agent.
+ Group string `yaml:"group"`
+ // URLTemplate for the Teleport tgz download URL.
+ URLTemplate string `yaml:"url_template"`
+ // Enabled controls whether auto-updates are enabled.
+ Enabled bool `yaml:"enabled"`
+}
+
+// UpdateStatus describes the status field in update.yaml.
+type UpdateStatus struct {
+ // ActiveVersion is the currently active Teleport version.
+ ActiveVersion string `yaml:"active_version"`
+}
+
+// NewLocalUpdater returns a new Updater that auto-updates local
+// installations of the Teleport agent.
+// The AutoUpdater uses an HTTP client with sane defaults for downloads, and
+// will not fill disk to within 10 MB of available capacity.
+func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) {
+ certPool, err := x509.SystemCertPool()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ tr, err := libdefaults.Transport()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ tr.TLSClientConfig = &tls.Config{
+ InsecureSkipVerify: cfg.InsecureSkipVerify,
+ RootCAs: certPool,
+ }
+ client := &http.Client{
+ Transport: tr,
+ Timeout: cfg.DownloadTimeout,
+ }
+ if cfg.Log == nil {
+ cfg.Log = slog.Default()
+ }
+ return &Updater{
+ Log: cfg.Log,
+ Pool: certPool,
+ InsecureSkipVerify: cfg.InsecureSkipVerify,
+ ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName),
+ Installer: &LocalInstaller{
+ InstallDir: cfg.VersionsDir,
+ HTTP: client,
+ Log: cfg.Log,
+
+ ReservedFreeTmpDisk: reservedFreeDisk,
+ ReservedFreeInstallDisk: reservedFreeDisk,
+ },
+ }, nil
+}
+
+// LocalUpdaterConfig specifies configuration for managing local agent auto-updates.
+type LocalUpdaterConfig struct {
+ // Log contains a slog logger.
+ // Defaults to slog.Default() if nil.
+ Log *slog.Logger
+ // InsecureSkipVerify turns off TLS certificate verification.
+ InsecureSkipVerify bool
+ // DownloadTimeout is a timeout for file download requests.
+ // Defaults to no timeout.
+ DownloadTimeout time.Duration
+ // VersionsDir for installing Teleport (usually /var/lib/teleport/versions).
+ VersionsDir string
+}
+
+// Updater implements the agent-local logic for Teleport agent auto-updates.
+type Updater struct {
+ // Log contains a logger.
+ Log *slog.Logger
+ // Pool used for requests to the Teleport web API.
+ Pool *x509.CertPool
+ // InsecureSkipVerify skips TLS verification.
+ InsecureSkipVerify bool
+ // ConfigPath contains the path to the agent auto-updates configuration.
+ ConfigPath string
+ // Installer manages installations of the Teleport agent.
+ Installer Installer
+}
+
+// Installer provides an API for installing Teleport agents.
+type Installer interface {
+ // Install the Teleport agent at version from the download template.
+ // This function must be idempotent.
+ Install(ctx context.Context, version, template string, flags InstallFlags) error
+ // Remove the Teleport agent at version.
+ // This function must be idempotent.
+ Remove(ctx context.Context, version string) error
+}
+
+// InstallFlags sets flags for the Teleport installation
+type InstallFlags int
+
+const (
+ // FlagEnterprise installs enterprise Teleport
+ FlagEnterprise InstallFlags = 1 << iota
+ // FlagFIPS installs FIPS Teleport
+ FlagFIPS
+)
+
+// OverrideConfig contains overrides for individual update operations.
+// If validated, these overrides may be persisted to disk.
+type OverrideConfig struct {
+ // Proxy address, scheme and port optional.
+ // Overrides existing value if specified.
+ Proxy string
+ // Group identifier for updates (e.g., staging)
+ // Overrides existing value if specified.
+ Group string
+ // URLTemplate for the Teleport tgz download URL
+ // Overrides existing value if specified.
+ URLTemplate string
+ // ForceVersion to the specified version.
+ ForceVersion string
+}
+
+// Enable enables agent updates and attempts an initial update.
+// If the initial update succeeds, auto-updates are enabled and the configuration is persisted.
+// Otherwise, the auto-updates configuration is not changed.
+// This function is idempotent.
+func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
+ // Read configuration from update.yaml and override any new values passed as flags.
+ cfg, err := u.readConfig(u.ConfigPath)
+ if err != nil {
+ return trace.Errorf("failed to read %s: %w", updateConfigName, err)
+ }
+ if override.Proxy != "" {
+ cfg.Spec.Proxy = override.Proxy
+ }
+ if override.Group != "" {
+ cfg.Spec.Group = override.Group
+ }
+ if override.URLTemplate != "" {
+ cfg.Spec.URLTemplate = override.URLTemplate
+ }
+ cfg.Spec.Enabled = true
+ if err := validateUpdatesSpec(&cfg.Spec); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Lookup target version from the proxy.
+ addr, err := libutils.ParseAddr(cfg.Spec.Proxy)
+ if err != nil {
+ return trace.Errorf("failed to parse proxy server address: %w", err)
+ }
+
+ desiredVersion := override.ForceVersion
+ if desiredVersion == "" {
+ resp, err := webclient.Find(&webclient.Config{
+ Context: ctx,
+ ProxyAddr: addr.Addr,
+ Insecure: u.InsecureSkipVerify,
+ Timeout: 30 * time.Second,
+ //Group: cfg.Spec.Group, // TODO(sclevine): add web API for verssion
+ Pool: u.Pool,
+ })
+ if err != nil {
+ return trace.Errorf("failed to request version from proxy: %w", err)
+ }
+ desiredVersion, _ = "16.3.0", resp // TODO(sclevine): add web API for version
+ //desiredVersion := resp.AutoUpdate.AgentVersion
+ }
+
+ if desiredVersion == "" {
+ return trace.Errorf("agent version not available from Teleport cluster")
+ }
+ // If the active version and target don't match, kick off upgrade.
+ template := cfg.Spec.URLTemplate
+ if template == "" {
+ template = cdnURITemplate
+ }
+ err = u.Installer.Install(ctx, desiredVersion, template, 0) // TODO(sclevine): add web API for flags
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if cfg.Status.ActiveVersion != desiredVersion {
+ u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion)
+ } else {
+ u.Log.InfoContext(ctx, "Target version successfully validated.", "version", desiredVersion)
+ }
+ cfg.Status.ActiveVersion = desiredVersion
+
+ // Always write the configuration file if enable succeeds.
+ if err := u.writeConfig(u.ConfigPath, cfg); err != nil {
+ return trace.Errorf("failed to write %s: %w", updateConfigName, err)
+ }
+ u.Log.InfoContext(ctx, "Configuration updated.")
+ return nil
+}
+
+func validateUpdatesSpec(spec *UpdateSpec) error {
+ if spec.URLTemplate != "" &&
+ !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") {
+ return trace.Errorf("Teleport download URL must use TLS (https://)")
+ }
+
+ if spec.Proxy == "" {
+ return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName)
+ }
+ return nil
+}
+
+// Disable disables agent auto-updates.
+// This function is idempotent.
+func (u *Updater) Disable(ctx context.Context) error {
+ cfg, err := u.readConfig(u.ConfigPath)
+ if err != nil {
+ return trace.Errorf("failed to read %s: %w", updateConfigName, err)
+ }
+ if !cfg.Spec.Enabled {
+ u.Log.InfoContext(ctx, "Automatic updates already disabled.")
+ return nil
+ }
+ cfg.Spec.Enabled = false
+ if err := u.writeConfig(u.ConfigPath, cfg); err != nil {
+ return trace.Errorf("failed to write %s: %w", updateConfigName, err)
+ }
+ return nil
+}
+
+// readConfig reads UpdateConfig from a file.
+func (*Updater) readConfig(path string) (*UpdateConfig, error) {
+ f, err := os.Open(path)
+ if errors.Is(err, fs.ErrNotExist) {
+ return &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ }, nil
+ }
+ if err != nil {
+ return nil, trace.Errorf("failed to open: %w", err)
+ }
+ defer f.Close()
+ var cfg UpdateConfig
+ if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
+ return nil, trace.Errorf("failed to parse: %w", err)
+ }
+ if k := cfg.Kind; k != updateConfigKind {
+ return nil, trace.Errorf("invalid kind %q", k)
+ }
+ if v := cfg.Version; v != updateConfigVersion {
+ return nil, trace.Errorf("invalid version %q", v)
+ }
+ return &cfg, nil
+}
+
+// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted.
+func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error {
+ opts := []renameio.Option{
+ renameio.WithPermissions(0755),
+ renameio.WithExistingPermissions(),
+ }
+ t, err := renameio.NewPendingFile(filename, opts...)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer t.Cleanup()
+ err = yaml.NewEncoder(t).Encode(cfg)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(t.CloseAtomicallyReplace())
+}
diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go
new file mode 100644
index 000000000000..6568fbaede9e
--- /dev/null
+++ b/lib/autoupdate/agent/updater_test.go
@@ -0,0 +1,315 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gopkg.in/yaml.v3"
+
+ "github.com/gravitational/teleport/lib/utils/golden"
+)
+
+func TestUpdater_Disable(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *UpdateConfig // nil -> file not present
+ errMatch string
+ }{
+ {
+ name: "enabled",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: true,
+ },
+ },
+ },
+ {
+ name: "already disabled",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: false,
+ },
+ },
+ },
+ {
+ name: "config does not exist",
+ },
+ {
+ name: "invalid metadata",
+ cfg: &UpdateConfig{
+ Spec: UpdateSpec{
+ Enabled: true,
+ },
+ },
+ errMatch: "invalid",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "update.yaml")
+
+ // Create config file only if provided in test case
+ if tt.cfg != nil {
+ b, err := yaml.Marshal(tt.cfg)
+ require.NoError(t, err)
+ err = os.WriteFile(cfgPath, b, 0600)
+ require.NoError(t, err)
+ }
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ VersionsDir: dir,
+ })
+ require.NoError(t, err)
+ err = updater.Disable(context.Background())
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+
+ data, err := os.ReadFile(cfgPath)
+
+ // If no config is present, disable should not create it
+ if tt.cfg == nil {
+ require.ErrorIs(t, err, os.ErrNotExist)
+ return
+ }
+ require.NoError(t, err)
+
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+}
+
+func TestUpdater_Enable(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *UpdateConfig // nil -> file not present
+ userCfg OverrideConfig
+ installErr error
+
+ installedVersion string
+ installedTemplate string
+ errMatch string
+ }{
+ {
+ name: "config from file",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Group: "group",
+ URLTemplate: "https://example.com",
+ },
+ Status: UpdateStatus{
+ ActiveVersion: "old-version",
+ },
+ },
+ installedVersion: "16.3.0",
+ installedTemplate: "https://example.com",
+ },
+ {
+ name: "config from user",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Group: "old-group",
+ URLTemplate: "https://example.com/old",
+ },
+ Status: UpdateStatus{
+ ActiveVersion: "old-version",
+ },
+ },
+ userCfg: OverrideConfig{
+ Group: "new-group",
+ URLTemplate: "https://example.com/new",
+ ForceVersion: "new-version",
+ },
+ installedVersion: "new-version",
+ installedTemplate: "https://example.com/new",
+ },
+ {
+ name: "already enabled",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ ActiveVersion: "old-version",
+ },
+ },
+ installedVersion: "16.3.0",
+ installedTemplate: cdnURITemplate,
+ },
+ {
+ name: "insecure URL",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ URLTemplate: "http://example.com",
+ },
+ },
+ errMatch: "URL must use TLS",
+ },
+ {
+ name: "install error",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ URLTemplate: "https://example.com",
+ },
+ },
+ installErr: errors.New("install error"),
+ errMatch: "install error",
+ },
+ {
+ name: "version already installed",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Status: UpdateStatus{
+ ActiveVersion: "16.3.0",
+ },
+ },
+ installedVersion: "16.3.0",
+ installedTemplate: cdnURITemplate,
+ },
+ {
+ name: "config does not exist",
+ installedVersion: "16.3.0",
+ installedTemplate: cdnURITemplate,
+ },
+ {
+ name: "invalid metadata",
+ cfg: &UpdateConfig{},
+ errMatch: "invalid",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "update.yaml")
+
+ // Create config file only if provided in test case
+ if tt.cfg != nil {
+ b, err := yaml.Marshal(tt.cfg)
+ require.NoError(t, err)
+ err = os.WriteFile(cfgPath, b, 0600)
+ require.NoError(t, err)
+ }
+
+ server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // TODO(sclevine): add web API test including group verification
+ w.Write([]byte(`{}`))
+ }))
+ t.Cleanup(server.Close)
+
+ if tt.userCfg.Proxy == "" {
+ tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://")
+ }
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ VersionsDir: dir,
+ })
+ require.NoError(t, err)
+
+ var installedVersion, installedTemplate string
+ updater.Installer = &testInstaller{
+ FuncInstall: func(_ context.Context, version, template string, _ InstallFlags) error {
+ installedVersion = version
+ installedTemplate = template
+ return tt.installErr
+ },
+ }
+
+ ctx := context.Background()
+ err = updater.Enable(ctx, tt.userCfg)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tt.installedVersion, installedVersion)
+ require.Equal(t, tt.installedTemplate, installedTemplate)
+
+ data, err := os.ReadFile(cfgPath)
+ require.NoError(t, err)
+ data = blankTestAddr(data)
+
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+}
+
+var serverRegexp = regexp.MustCompile("127.0.0.1:[0-9]+")
+
+func blankTestAddr(s []byte) []byte {
+ return serverRegexp.ReplaceAll(s, []byte("localhost"))
+}
+
+type testInstaller struct {
+ FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error
+ FuncRemove func(ctx context.Context, version string) error
+}
+
+func (ti *testInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
+ return ti.FuncInstall(ctx, version, template, flags)
+}
+
+func (ti *testInstaller) Remove(ctx context.Context, version string) error {
+ return ti.FuncRemove(ctx, version)
+}
diff --git a/lib/autoupdate/agent_test.go b/lib/autoupdate/agent_test.go
deleted file mode 100644
index 7ac4ad379b78..000000000000
--- a/lib/autoupdate/agent_test.go
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2024 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package autoupdate
-
-import (
- "context"
- "log/slog"
- "os"
- "path/filepath"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "gopkg.in/yaml.v3"
-
- "github.com/gravitational/teleport/lib/utils/golden"
-)
-
-func TestAgentUpdater_Disable(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- cfg *AgentUpdateConfig // nil -> file not present
- errMatch string
- }{
- {
- name: "enabled",
- cfg: &AgentUpdateConfig{
- Version: agentUpdateConfigVersion,
- Kind: agentUpdateConfigKind,
- Spec: AgentUpdateSpec{
- Enabled: true,
- },
- },
- },
- {
- name: "already disabled",
- cfg: &AgentUpdateConfig{
- Version: agentUpdateConfigVersion,
- Kind: agentUpdateConfigKind,
- Spec: AgentUpdateSpec{
- Enabled: false,
- },
- },
- },
- {
- name: "config does not exist",
- },
- {
- name: "invalid metadata",
- cfg: &AgentUpdateConfig{
- Spec: AgentUpdateSpec{
- Enabled: true,
- },
- },
- errMatch: "invalid",
- },
- }
-
- for _, tt := range tests {
- tt := tt
- t.Run(tt.name, func(t *testing.T) {
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
- err := os.MkdirAll(filepath.Dir(cfgPath), 0777)
- require.NoError(t, err)
-
- // Create config file only if provided in test case
- if tt.cfg != nil {
- b, err := yaml.Marshal(tt.cfg)
- require.NoError(t, err)
- err = os.WriteFile(cfgPath, b, 0600)
- require.NoError(t, err)
- }
-
- updater := AgentUpdater{
- Log: slog.Default(),
- }
- err = updater.Disable(context.Background(), cfgPath)
- if tt.errMatch != "" {
- require.Error(t, err)
- assert.Contains(t, err.Error(), tt.errMatch)
- return
- }
- require.NoError(t, err)
-
- data, err := os.ReadFile(cfgPath)
-
- // If no config is present, disable should not create it
- if tt.cfg == nil {
- require.ErrorIs(t, err, os.ErrNotExist)
- return
- }
- require.NoError(t, err)
-
- if golden.ShouldSet() {
- golden.Set(t, data)
- }
- require.Equal(t, string(golden.Get(t)), string(data))
- })
- }
-}
diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go
index a13bde9b45d0..11aee2aae390 100644
--- a/tool/teleport-update/main.go
+++ b/tool/teleport-update/main.go
@@ -20,20 +20,16 @@ package main
import (
"context"
- "crypto/tls"
- "crypto/x509"
"log/slog"
- "net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
- "time"
"github.com/gravitational/trace"
"github.com/gravitational/teleport"
- "github.com/gravitational/teleport/lib/autoupdate"
+ autoupdate "github.com/gravitational/teleport/lib/autoupdate/agent"
libdefaults "github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/modules"
libutils "github.com/gravitational/teleport/lib/utils"
@@ -56,13 +52,13 @@ const (
proxyServerEnvVar = "TELEPORT_PROXY"
// updateGroupEnvVar allows the update group to be specified via env var.
updateGroupEnvVar = "TELEPORT_UPDATE_GROUP"
+ // updateVersionEnvVar forces the version to specified value.
+ updateVersionEnvVar = "TELEPORT_UPDATE_VERSION"
)
const (
// versionsDirName specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions.
versionsDirName = "versions"
- // configFileName specifies the name of the file inside versionsDirName containing configuration for the teleport update
- configFileName = "update.yaml"
// lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution.
lockFileName = ".lock"
)
@@ -76,23 +72,15 @@ func main() {
}
type cliConfig struct {
+ autoupdate.OverrideConfig
+
// Debug logs enabled
Debug bool
- // DataDir for Teleport (usually /var/lib/teleport)
- DataDir string
// LogFormat controls the format of logging. Can be either `json` or `text`.
// By default, this is `text`.
LogFormat string
-
- // ProxyServer address, scheme and port optional.
- // Overrides existing value if specified.
- ProxyServer string
- // Group identifier for updates (e.g., staging)
- // Overrides existing value if specified.
- Group string
- // Template for the Teleport tgz download URL
- // Overrides existing value if specified.
- Template string
+ // DataDir for Teleport (usually /var/lib/teleport)
+ DataDir string
}
func (c *cliConfig) CheckAndSetDefaults() error {
@@ -122,17 +110,21 @@ func Run(args []string) error {
versionCmd := app.Command("version", "Print the version of your teleport-updater binary.")
- enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial updates.")
- enableCmd.Flag("proxy", "Address of the Teleport Proxy.").Short('p').
- Envar(proxyServerEnvVar).StringVar(&ccfg.ProxyServer)
- enableCmd.Flag("group", "Update group, for staged updates.").Short('g').
- Envar(updateGroupEnvVar).StringVar(&ccfg.Group)
- enableCmd.Flag("template", "Go template to override Teleport tgz download URL.").
- Short('t').Envar(templateEnvVar).StringVar(&ccfg.Template)
+ enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.")
+ enableCmd.Flag("proxy", "Address of the Teleport Proxy.").
+ Short('p').Envar(proxyServerEnvVar).StringVar(&ccfg.Proxy)
+ enableCmd.Flag("group", "Update group for this agent installation.").
+ Short('g').Envar(updateGroupEnvVar).StringVar(&ccfg.Group)
+ enableCmd.Flag("template", "Go template used to override Teleport download URL.").
+ Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate)
+ enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster.").
+ Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion)
disableCmd := app.Command("disable", "Disable agent auto-updates.")
updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.")
+ updateCmd.Flag("force-version", "Use the provided version instead of querying it from the Teleport cluster.").
+ Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion)
libutils.UpdateAppUsageTemplate(app, args)
command, err := app.Parse(args)
@@ -186,21 +178,28 @@ func setupLogger(debug bool, format string) error {
// cmdDisable disables updates.
func cmdDisable(ctx context.Context, ccfg *cliConfig) error {
- var (
- versionsDir = filepath.Join(ccfg.DataDir, versionsDirName)
- updateYAML = filepath.Join(versionsDir, configFileName)
- )
+ versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
+ if err := os.MkdirAll(versionsDir, 0755); err != nil {
+ return trace.Errorf("failed to create versions directory: %w", err)
+ }
+
unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
if err != nil {
- return trace.Wrap(err)
+ return trace.Errorf("failed to grab concurrent execution lock: %w", err)
}
defer func() {
if err := unlock(); err != nil {
plog.DebugContext(ctx, "Failed to close lock file", "error", err)
}
}()
- updater := autoupdate.AgentUpdater{Log: plog}
- if err := updater.Disable(ctx, updateYAML); err != nil {
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ VersionsDir: versionsDir,
+ Log: plog,
+ })
+ if err != nil {
+ return trace.Errorf("failed to setup updater: %w", err)
+ }
+ if err := updater.Disable(ctx); err != nil {
return trace.Wrap(err)
}
return nil
@@ -208,37 +207,36 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error {
// cmdEnable enables updates and triggers an initial update.
func cmdEnable(ctx context.Context, ccfg *cliConfig) error {
- return trace.NotImplemented("TODO")
-}
-
-// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address.
-func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
- return trace.NotImplemented("TODO")
-}
+ versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
+ if err := os.MkdirAll(versionsDir, 0755); err != nil {
+ return trace.Errorf("failed to create versions directory: %w", err)
+ }
-//nolint:unused // scaffolding used in upcoming PR
-type downloadConfig struct {
- // Insecure turns off TLS certificate verification when enabled.
- Insecure bool
- // Pool defines the set of root CAs to use when verifying server
- // certificates.
- Pool *x509.CertPool
- // Timeout is a timeout for requests.
- Timeout time.Duration
-}
+ // Ensure enable can't run concurrently.
+ unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
+ if err != nil {
+ return trace.Errorf("failed to grab concurrent execution lock: %w", err)
+ }
+ defer func() {
+ if err := unlock(); err != nil {
+ plog.DebugContext(ctx, "Failed to close lock file", "error", err)
+ }
+ }()
-//nolint:unused // scaffolding used in upcoming PR
-func newClient(cfg *downloadConfig) (*http.Client, error) {
- tr, err := libdefaults.Transport()
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ VersionsDir: versionsDir,
+ Log: plog,
+ })
if err != nil {
- return nil, trace.Wrap(err)
+ return trace.Errorf("failed to setup updater: %w", err)
}
- tr.TLSClientConfig = &tls.Config{
- InsecureSkipVerify: cfg.Insecure,
- RootCAs: cfg.Pool,
+ if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil {
+ return trace.Wrap(err)
}
- return &http.Client{
- Transport: tr,
- Timeout: cfg.Timeout,
- }, nil
+ return nil
+}
+
+// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address.
+func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
+ return trace.NotImplemented("TODO")
}