diff --git a/lib/autoupdate/agent.go b/lib/autoupdate/agent.go deleted file mode 100644 index 921562251f1f..000000000000 --- a/lib/autoupdate/agent.go +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Teleport - * Copyright (C) 2024 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package autoupdate - -import ( - "context" - "errors" - "io/fs" - "log/slog" - "os" - - "github.com/google/renameio/v2" - "github.com/gravitational/trace" - "gopkg.in/yaml.v3" -) - -const ( - agentUpdateConfigVersion = "v1" - agentUpdateConfigKind = "update_config" -) - -// AgentUpdateConfig describes the update.yaml file schema. -type AgentUpdateConfig struct { - // Version of the configuration file - Version string `yaml:"version"` - // Kind of configuration file (always "update_config") - Kind string `yaml:"kind"` - // Spec contains user-specified configuration. - Spec AgentUpdateSpec `yaml:"spec"` - // Status contains state configuration. - Status AgentUpdateStatus `yaml:"status"` -} - -// AgentUpdateSpec describes the spec field in update.yaml. -type AgentUpdateSpec struct { - // Proxy address - Proxy string `yaml:"proxy"` - // Group update identifier - Group string `yaml:"group"` - // URLTemplate for the Teleport tgz download URL. - URLTemplate string `yaml:"url_template"` - // Enabled controls whether auto-updates are enabled. - Enabled bool `yaml:"enabled"` -} - -// AgentUpdateStatus describes the status field in update.yaml. -type AgentUpdateStatus struct { - // ActiveVersion is the currently active Teleport version. - ActiveVersion string `yaml:"active_version"` -} - -type AgentUpdater struct { - Log *slog.Logger -} - -// Disable disables agent updates. -// updatePath must be a path to the update.yaml file. -func (u AgentUpdater) Disable(ctx context.Context, updatePath string) error { - cfg, err := u.readConfig(updatePath) - if err != nil { - return trace.Errorf("failed to read updates.yaml: %w", err) - } - if !cfg.Spec.Enabled { - u.Log.InfoContext(ctx, "Automatic updates already disabled") - return nil - } - cfg.Spec.Enabled = false - if err := u.writeConfig(updatePath, cfg); err != nil { - return trace.Errorf("failed to write updates.yaml: %w", err) - } - return nil -} - -// readConfig reads update.yaml -func (AgentUpdater) readConfig(path string) (*AgentUpdateConfig, error) { - f, err := os.Open(path) - if errors.Is(err, fs.ErrNotExist) { - return &AgentUpdateConfig{ - Version: agentUpdateConfigVersion, - Kind: agentUpdateConfigKind, - }, nil - } - if err != nil { - return nil, trace.Errorf("failed to open: %w", err) - } - defer f.Close() - var cfg AgentUpdateConfig - if err := yaml.NewDecoder(f).Decode(&cfg); err != nil { - return nil, trace.Errorf("failed to parse: %w", err) - } - if k := cfg.Kind; k != agentUpdateConfigKind { - return nil, trace.Errorf("invalid kind %q", k) - } - if v := cfg.Version; v != agentUpdateConfigVersion { - return nil, trace.Errorf("invalid version %q", v) - } - return &cfg, nil -} - -// writeConfig writes update.yaml atomically, ensuring the file cannot be corrupted. -func (AgentUpdater) writeConfig(filename string, cfg *AgentUpdateConfig) error { - opts := []renameio.Option{ - renameio.WithPermissions(0755), - renameio.WithExistingPermissions(), - } - t, err := renameio.NewPendingFile(filename, opts...) - if err != nil { - return trace.Wrap(err) - } - defer t.Cleanup() - err = yaml.NewEncoder(t).Encode(cfg) - if err != nil { - return trace.Wrap(err) - } - return trace.Wrap(t.CloseAtomicallyReplace()) -} diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go new file mode 100644 index 000000000000..e31813866eac --- /dev/null +++ b/lib/autoupdate/agent/installer.go @@ -0,0 +1,317 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "runtime" + "text/template" + "time" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/utils" +) + +const ( + checksumType = "sha256" + checksumHexLen = sha256.Size * 2 // bytes to hex +) + +// LocalInstaller manages the creation and removal of installations +// of Teleport. +type LocalInstaller struct { + // InstallDir contains each installation, named by version. + InstallDir string + // HTTP is an HTTP client for downloading Teleport. + HTTP *http.Client + // Log contains a logger. + Log *slog.Logger + // ReservedFreeTmpDisk is the amount of disk that must remain free in /tmp + ReservedFreeTmpDisk uint64 + // ReservedFreeInstallDisk is the amount of disk that must remain free in the install directory. + ReservedFreeInstallDisk uint64 +} + +// Remove a Teleport version directory from InstallDir. +// This function is idempotent. +func (li *LocalInstaller) Remove(ctx context.Context, version string) error { + versionDir := filepath.Join(li.InstallDir, version) + sumPath := filepath.Join(versionDir, checksumType) + + // invalidate checksum first, to protect against partially-removed + // directory with valid checksum. + err := os.Remove(sumPath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return trace.Wrap(err) + } + if err := os.RemoveAll(versionDir); err != nil { + return trace.Wrap(err) + } + return nil +} + +// Install a Teleport version directory in InstallDir. +// This function is idempotent. +func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error { + versionDir := filepath.Join(li.InstallDir, version) + sumPath := filepath.Join(versionDir, checksumType) + + // generate download URI from template + uri, err := makeURL(template, version, flags) + if err != nil { + return trace.Wrap(err) + } + + // Get new and old checksums. If they match, skip download. + // Otherwise, clear the old version directory and re-download. + checksumURI := uri + "." + checksumType + newSum, err := li.getChecksum(ctx, checksumURI) + if err != nil { + return trace.Errorf("failed to download checksum from %s: %w", checksumURI, err) + } + oldSum, err := readChecksum(sumPath) + if err == nil { + if bytes.Equal(oldSum, newSum) { + li.Log.InfoContext(ctx, "Version already present.", "version", version) + return nil + } + li.Log.WarnContext(ctx, "Removing version that does not match checksum.", "version", version) + if err := li.Remove(ctx, version); err != nil { + return trace.Wrap(err) + } + } else if !errors.Is(err, os.ErrNotExist) { + li.Log.WarnContext(ctx, "Removing version with unreadable checksum.", "version", version, "error", err) + if err := li.Remove(ctx, version); err != nil { + return trace.Wrap(err) + } + } + + // Verify that we have enough free temp space, then download tgz + freeTmp, err := utils.FreeDiskWithReserve(os.TempDir(), li.ReservedFreeTmpDisk) + if err != nil { + return trace.Errorf("failed to calculate free disk: %w", err) + } + f, err := os.CreateTemp("", "teleport-update-") + if err != nil { + return trace.Errorf("failed to create temporary file: %w", err) + } + defer func() { + _ = f.Close() // data never read after close + if err := os.Remove(f.Name()); err != nil { + li.Log.WarnContext(ctx, "Failed to cleanup temporary download.", "error", err) + } + }() + pathSum, err := li.download(ctx, f, int64(freeTmp), uri) + if err != nil { + return trace.Errorf("failed to download teleport: %w", err) + } + + // Seek to the start of the tgz file after writing + if _, err := f.Seek(0, io.SeekStart); err != nil { + return trace.Errorf("failed seek to start of download: %w", err) + } + // Check integrity before decompression + if !bytes.Equal(newSum, pathSum) { + return trace.Errorf("mismatched checksum, download possibly corrupt") + } + // Get uncompressed size of the tgz + n, err := uncompressedSize(f) + if err != nil { + return trace.Errorf("failed to determine uncompressed size: %w", err) + } + // Seek to start of tgz after reading size + if _, err := f.Seek(0, io.SeekStart); err != nil { + return trace.Errorf("failed seek to start: %w", err) + } + if err := li.extract(ctx, versionDir, f, n); err != nil { + return trace.Errorf("failed to extract teleport: %w", err) + } + // Write the checksum last. This marks the version directory as valid. + err = os.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), 0755) + if err != nil { + return trace.Errorf("failed to write checksum: %w", err) + } + return nil +} + +// makeURL to download the Teleport tgz. +func makeURL(uriTmpl, version string, flags InstallFlags) (string, error) { + tmpl, err := template.New("uri").Parse(uriTmpl) + if err != nil { + return "", trace.Wrap(err) + } + var uriBuf bytes.Buffer + params := struct { + OS, Version, Arch string + FIPS, Enterprise bool + }{ + OS: runtime.GOOS, + Version: version, + Arch: runtime.GOARCH, + FIPS: flags&FlagFIPS != 0, + Enterprise: flags&(FlagEnterprise|FlagFIPS) != 0, + } + err = tmpl.Execute(&uriBuf, params) + if err != nil { + return "", trace.Wrap(err) + } + return uriBuf.String(), nil +} + +// readChecksum from the version directory. +func readChecksum(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, trace.Wrap(err) + } + defer f.Close() + var buf bytes.Buffer + _, err = io.CopyN(&buf, f, checksumHexLen) + if err != nil { + return nil, trace.Wrap(err) + } + raw := buf.String() + sum, err := hex.DecodeString(raw) + if err != nil { + return nil, trace.Wrap(err) + } + return sum, nil +} + +func (li *LocalInstaller) getChecksum(ctx context.Context, url string) ([]byte, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, trace.Wrap(err) + } + resp, err := li.HTTP.Do(req) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return nil, trace.Errorf("checksum not found: %s", url) + } + if resp.StatusCode != http.StatusOK { + return nil, trace.Errorf("unexpected HTTP status code: %d", resp.StatusCode) + } + + var buf bytes.Buffer + _, err = io.CopyN(&buf, resp.Body, checksumHexLen) + if err != nil { + return nil, trace.Wrap(err) + } + sum, err := hex.DecodeString(buf.String()) + if err != nil { + return nil, trace.Wrap(err) + } + return sum, nil +} + +func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, url string) (sum []byte, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, trace.Wrap(err) + } + resp, err := li.HTTP.Do(req) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return nil, trace.Errorf("Teleport download not found: %s", url) + } + if resp.StatusCode != http.StatusOK { + return nil, trace.Errorf("unexpected HTTP status code: %d", resp.StatusCode) + } + li.Log.InfoContext(ctx, "Downloading Teleport tarball.", "url", url, "size", resp.ContentLength) + + // Ensure there's enough space in /tmp for the download. + size := resp.ContentLength + if size < 0 { + li.Log.WarnContext(ctx, "Content length missing from response, unable to verify Teleport download size.") + size = max + } else if size > max { + return nil, trace.Errorf("size of download (%d bytes) exceeds available disk space (%d bytes)", resp.ContentLength, max) + } + // Calculate checksum concurrently with download. + shaReader := sha256.New() + n, err := io.CopyN(w, io.TeeReader(resp.Body, shaReader), size) + if err != nil { + return nil, trace.Wrap(err) + } + if resp.ContentLength >= 0 && n != resp.ContentLength { + return nil, trace.Errorf("mismatch in Teleport download size") + } + return shaReader.Sum(nil), nil +} + +func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64) error { + if err := os.MkdirAll(dstDir, 0755); err != nil { + return trace.Wrap(err) + } + free, err := utils.FreeDiskWithReserve(dstDir, li.ReservedFreeInstallDisk) + if err != nil { + return trace.Errorf("failed to calculate free disk in %q: %w", dstDir, err) + } + // Bail if there's not enough free disk space at the target + if d := int64(free) - max; d < 0 { + return trace.Errorf("%q needs %d additional bytes of disk space for decompression", dstDir, -d) + } + zr, err := gzip.NewReader(src) + if err != nil { + return trace.Errorf("requires gzip-compressed body: %v", err) + } + li.Log.InfoContext(ctx, "Extracting Teleport tarball.", "path", dstDir, "size", max) + + // TODO(sclevine): add variadic arg to Extract to extract teleport/ subdir into bin/. + err = utils.Extract(zr, dstDir) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +func uncompressedSize(f io.Reader) (int64, error) { + // NOTE: The gzip length trailer is very unreliable, + // but we could optimize this in the future if + // we are willing to verify that all published + // Teleport tarballs have valid trailers. + r, err := gzip.NewReader(f) + if err != nil { + return 0, trace.Wrap(err) + } + n, err := io.Copy(io.Discard, r) + if err != nil { + return 0, trace.Wrap(err) + } + return n, nil +} diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go new file mode 100644 index 000000000000..be778f7bcf16 --- /dev/null +++ b/lib/autoupdate/agent/installer_test.go @@ -0,0 +1,189 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTeleportInstaller_Install(t *testing.T) { + t.Parallel() + const version = "new-version" + + _, testSum := testTGZ(t, version) + + tests := []struct { + name string + reservedTmp uint64 + reservedInstall uint64 + existingSum string + flags InstallFlags + + errMatch string + }{ + { + name: "not present", + }, + { + name: "present", + existingSum: testSum, + }, + { + name: "mismatched checksum", + existingSum: hex.EncodeToString(sha256.New().Sum(nil)), + }, + { + name: "unreadable checksum", + existingSum: "bad", + }, + { + name: "out of space in /tmp", + reservedTmp: reservedFreeDisk * 1_000_000_000, + errMatch: "no free space left", + }, + { + name: "out of space in install dir", + reservedInstall: reservedFreeDisk * 1_000_000_000, + errMatch: "no free space left", + }, + // TODO(sclevine): test flags + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + + if tt.existingSum != "" { + err := os.WriteFile(filepath.Join(dir, checksumType), []byte(tt.existingSum), os.ModePerm) + require.NoError(t, err) + } + + // test parameters + var dlPath, shaPath, shasum string + + // test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tgz, sum := testTGZ(t, version) + shasum = sum + var out *bytes.Buffer + if strings.HasSuffix(r.URL.Path, "."+checksumType) { // checksum request + shaPath = r.URL.Path + out = bytes.NewBufferString(sum) + } else { // tgz request + dlPath = r.URL.Path + out = tgz + } + w.Header().Set("Content-Length", strconv.Itoa(out.Len())) + _, err := io.Copy(w, out) + if err != nil { + t.Fatal(err) + } + })) + t.Cleanup(server.Close) + + installer := &LocalInstaller{ + InstallDir: dir, + HTTP: http.DefaultClient, + Log: slog.Default(), + ReservedFreeTmpDisk: tt.reservedTmp, + ReservedFreeInstallDisk: tt.reservedInstall, + } + ctx := context.Background() + err := installer.Install(ctx, version, server.URL+"/{{.OS}}/{{.Arch}}/{{.Version}}", tt.flags) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + + const expectedPath = "/" + runtime.GOOS + "/" + runtime.GOARCH + "/" + version + require.Equal(t, expectedPath, dlPath) + require.Equal(t, expectedPath+"."+checksumType, shaPath) + + teleportVersion, err := os.ReadFile(filepath.Join(dir, version, "teleport")) + require.NoError(t, err) + require.Equal(t, version, string(teleportVersion)) + + tshVersion, err := os.ReadFile(filepath.Join(dir, version, "tsh")) + require.NoError(t, err) + require.Equal(t, version, string(tshVersion)) + + sum, err := os.ReadFile(filepath.Join(dir, version, checksumType)) + require.NoError(t, err) + require.Equal(t, string(sum), shasum) + }) + } +} + +func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) { + t.Helper() + + var buf bytes.Buffer + + sha := sha256.New() + gz := gzip.NewWriter(io.MultiWriter(&buf, sha)) + tw := tar.NewWriter(gz) + + var files = []struct { + Name, Body string + }{ + {"teleport", version}, + {"tsh", version}, + } + for _, file := range files { + hdr := &tar.Header{ + Name: file.Name, + Mode: 0600, + Size: int64(len(file.Body)), + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte(file.Body)); err != nil { + t.Fatal(err) + } + } + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gz.Close(); err != nil { + t.Fatal(err) + } + return &buf, hex.EncodeToString(sha.Sum(nil)) +} diff --git a/lib/autoupdate/testdata/TestAgentUpdater_Disable/already_disabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden similarity index 100% rename from lib/autoupdate/testdata/TestAgentUpdater_Disable/already_disabled.golden rename to lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden diff --git a/lib/autoupdate/testdata/TestAgentUpdater_Disable/enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden similarity index 100% rename from lib/autoupdate/testdata/TestAgentUpdater_Disable/enabled.golden rename to lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden new file mode 100644 index 000000000000..e03f369eb101 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden @@ -0,0 +1,9 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: "" + enabled: true +status: + active_version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden new file mode 100644 index 000000000000..e03f369eb101 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden @@ -0,0 +1,9 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: "" + enabled: true +status: + active_version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden new file mode 100644 index 000000000000..b172d858bc55 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden @@ -0,0 +1,9 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: group + url_template: https://example.com + enabled: true +status: + active_version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden new file mode 100644 index 000000000000..bb9ce8b9d8fa --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden @@ -0,0 +1,9 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: new-group + url_template: https://example.com/new + enabled: true +status: + active_version: new-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden new file mode 100644 index 000000000000..e03f369eb101 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden @@ -0,0 +1,9 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: "" + enabled: true +status: + active_version: 16.3.0 diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go new file mode 100644 index 000000000000..59df5f0b3ba8 --- /dev/null +++ b/lib/autoupdate/agent/updater.go @@ -0,0 +1,341 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io/fs" + "log/slog" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/renameio/v2" + "github.com/gravitational/trace" + "gopkg.in/yaml.v3" + + "github.com/gravitational/teleport/api/client/webclient" + libdefaults "github.com/gravitational/teleport/lib/defaults" + libutils "github.com/gravitational/teleport/lib/utils" +) + +const ( + // cdnURITemplate is the default template for the Teleport tgz download. + cdnURITemplate = "https://cdn.teleport.dev/teleport{{if .Enterprise}}-ent{{end}}-v{{.Version}}-{{.OS}}-{{.Arch}}{{if .FIPS}}-fips{{end}}-bin.tar.gz" + // reservedFreeDisk is the minimum required free space left on disk during downloads. + // TODO(sclevine): This value is arbitrary and could be replaced by, e.g., min(1%, 200mb) in the future + // to account for a range of disk sizes. + reservedFreeDisk = 10_000_000 // 10 MB +) + +const ( + // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update. + updateConfigName = "update.yaml" + + // UpdateConfig metadata + updateConfigVersion = "v1" + updateConfigKind = "update_config" +) + +// UpdateConfig describes the update.yaml file schema. +type UpdateConfig struct { + // Version of the configuration file + Version string `yaml:"version"` + // Kind of configuration file (always "update_config") + Kind string `yaml:"kind"` + // Spec contains user-specified configuration. + Spec UpdateSpec `yaml:"spec"` + // Status contains state configuration. + Status UpdateStatus `yaml:"status"` +} + +// UpdateSpec describes the spec field in update.yaml. +type UpdateSpec struct { + // Proxy address + Proxy string `yaml:"proxy"` + // Group specifies the update group identifier for the agent. + Group string `yaml:"group"` + // URLTemplate for the Teleport tgz download URL. + URLTemplate string `yaml:"url_template"` + // Enabled controls whether auto-updates are enabled. + Enabled bool `yaml:"enabled"` +} + +// UpdateStatus describes the status field in update.yaml. +type UpdateStatus struct { + // ActiveVersion is the currently active Teleport version. + ActiveVersion string `yaml:"active_version"` +} + +// NewLocalUpdater returns a new Updater that auto-updates local +// installations of the Teleport agent. +// The AutoUpdater uses an HTTP client with sane defaults for downloads, and +// will not fill disk to within 10 MB of available capacity. +func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { + certPool, err := x509.SystemCertPool() + if err != nil { + return nil, trace.Wrap(err) + } + tr, err := libdefaults.Transport() + if err != nil { + return nil, trace.Wrap(err) + } + tr.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: cfg.InsecureSkipVerify, + RootCAs: certPool, + } + client := &http.Client{ + Transport: tr, + Timeout: cfg.DownloadTimeout, + } + if cfg.Log == nil { + cfg.Log = slog.Default() + } + return &Updater{ + Log: cfg.Log, + Pool: certPool, + InsecureSkipVerify: cfg.InsecureSkipVerify, + ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName), + Installer: &LocalInstaller{ + InstallDir: cfg.VersionsDir, + HTTP: client, + Log: cfg.Log, + + ReservedFreeTmpDisk: reservedFreeDisk, + ReservedFreeInstallDisk: reservedFreeDisk, + }, + }, nil +} + +// LocalUpdaterConfig specifies configuration for managing local agent auto-updates. +type LocalUpdaterConfig struct { + // Log contains a slog logger. + // Defaults to slog.Default() if nil. + Log *slog.Logger + // InsecureSkipVerify turns off TLS certificate verification. + InsecureSkipVerify bool + // DownloadTimeout is a timeout for file download requests. + // Defaults to no timeout. + DownloadTimeout time.Duration + // VersionsDir for installing Teleport (usually /var/lib/teleport/versions). + VersionsDir string +} + +// Updater implements the agent-local logic for Teleport agent auto-updates. +type Updater struct { + // Log contains a logger. + Log *slog.Logger + // Pool used for requests to the Teleport web API. + Pool *x509.CertPool + // InsecureSkipVerify skips TLS verification. + InsecureSkipVerify bool + // ConfigPath contains the path to the agent auto-updates configuration. + ConfigPath string + // Installer manages installations of the Teleport agent. + Installer Installer +} + +// Installer provides an API for installing Teleport agents. +type Installer interface { + // Install the Teleport agent at version from the download template. + // This function must be idempotent. + Install(ctx context.Context, version, template string, flags InstallFlags) error + // Remove the Teleport agent at version. + // This function must be idempotent. + Remove(ctx context.Context, version string) error +} + +// InstallFlags sets flags for the Teleport installation +type InstallFlags int + +const ( + // FlagEnterprise installs enterprise Teleport + FlagEnterprise InstallFlags = 1 << iota + // FlagFIPS installs FIPS Teleport + FlagFIPS +) + +// OverrideConfig contains overrides for individual update operations. +// If validated, these overrides may be persisted to disk. +type OverrideConfig struct { + // Proxy address, scheme and port optional. + // Overrides existing value if specified. + Proxy string + // Group identifier for updates (e.g., staging) + // Overrides existing value if specified. + Group string + // URLTemplate for the Teleport tgz download URL + // Overrides existing value if specified. + URLTemplate string + // ForceVersion to the specified version. + ForceVersion string +} + +// Enable enables agent updates and attempts an initial update. +// If the initial update succeeds, auto-updates are enabled and the configuration is persisted. +// Otherwise, the auto-updates configuration is not changed. +// This function is idempotent. +func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { + // Read configuration from update.yaml and override any new values passed as flags. + cfg, err := u.readConfig(u.ConfigPath) + if err != nil { + return trace.Errorf("failed to read %s: %w", updateConfigName, err) + } + if override.Proxy != "" { + cfg.Spec.Proxy = override.Proxy + } + if override.Group != "" { + cfg.Spec.Group = override.Group + } + if override.URLTemplate != "" { + cfg.Spec.URLTemplate = override.URLTemplate + } + cfg.Spec.Enabled = true + if err := validateUpdatesSpec(&cfg.Spec); err != nil { + return trace.Wrap(err) + } + + // Lookup target version from the proxy. + addr, err := libutils.ParseAddr(cfg.Spec.Proxy) + if err != nil { + return trace.Errorf("failed to parse proxy server address: %w", err) + } + + desiredVersion := override.ForceVersion + if desiredVersion == "" { + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: addr.Addr, + Insecure: u.InsecureSkipVerify, + Timeout: 30 * time.Second, + //Group: cfg.Spec.Group, // TODO(sclevine): add web API for verssion + Pool: u.Pool, + }) + if err != nil { + return trace.Errorf("failed to request version from proxy: %w", err) + } + desiredVersion, _ = "16.3.0", resp // TODO(sclevine): add web API for version + //desiredVersion := resp.AutoUpdate.AgentVersion + } + + if desiredVersion == "" { + return trace.Errorf("agent version not available from Teleport cluster") + } + // If the active version and target don't match, kick off upgrade. + template := cfg.Spec.URLTemplate + if template == "" { + template = cdnURITemplate + } + err = u.Installer.Install(ctx, desiredVersion, template, 0) // TODO(sclevine): add web API for flags + if err != nil { + return trace.Wrap(err) + } + if cfg.Status.ActiveVersion != desiredVersion { + u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion) + } else { + u.Log.InfoContext(ctx, "Target version successfully validated.", "version", desiredVersion) + } + cfg.Status.ActiveVersion = desiredVersion + + // Always write the configuration file if enable succeeds. + if err := u.writeConfig(u.ConfigPath, cfg); err != nil { + return trace.Errorf("failed to write %s: %w", updateConfigName, err) + } + u.Log.InfoContext(ctx, "Configuration updated.") + return nil +} + +func validateUpdatesSpec(spec *UpdateSpec) error { + if spec.URLTemplate != "" && + !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") { + return trace.Errorf("Teleport download URL must use TLS (https://)") + } + + if spec.Proxy == "" { + return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName) + } + return nil +} + +// Disable disables agent auto-updates. +// This function is idempotent. +func (u *Updater) Disable(ctx context.Context) error { + cfg, err := u.readConfig(u.ConfigPath) + if err != nil { + return trace.Errorf("failed to read %s: %w", updateConfigName, err) + } + if !cfg.Spec.Enabled { + u.Log.InfoContext(ctx, "Automatic updates already disabled.") + return nil + } + cfg.Spec.Enabled = false + if err := u.writeConfig(u.ConfigPath, cfg); err != nil { + return trace.Errorf("failed to write %s: %w", updateConfigName, err) + } + return nil +} + +// readConfig reads UpdateConfig from a file. +func (*Updater) readConfig(path string) (*UpdateConfig, error) { + f, err := os.Open(path) + if errors.Is(err, fs.ErrNotExist) { + return &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + }, nil + } + if err != nil { + return nil, trace.Errorf("failed to open: %w", err) + } + defer f.Close() + var cfg UpdateConfig + if err := yaml.NewDecoder(f).Decode(&cfg); err != nil { + return nil, trace.Errorf("failed to parse: %w", err) + } + if k := cfg.Kind; k != updateConfigKind { + return nil, trace.Errorf("invalid kind %q", k) + } + if v := cfg.Version; v != updateConfigVersion { + return nil, trace.Errorf("invalid version %q", v) + } + return &cfg, nil +} + +// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted. +func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error { + opts := []renameio.Option{ + renameio.WithPermissions(0755), + renameio.WithExistingPermissions(), + } + t, err := renameio.NewPendingFile(filename, opts...) + if err != nil { + return trace.Wrap(err) + } + defer t.Cleanup() + err = yaml.NewEncoder(t).Encode(cfg) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(t.CloseAtomicallyReplace()) +} diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go new file mode 100644 index 000000000000..6568fbaede9e --- /dev/null +++ b/lib/autoupdate/agent/updater_test.go @@ -0,0 +1,315 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" + + "github.com/gravitational/teleport/lib/utils/golden" +) + +func TestUpdater_Disable(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + errMatch string + }{ + { + name: "enabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: true, + }, + }, + }, + { + name: "already disabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + }, + { + name: "config does not exist", + }, + { + name: "invalid metadata", + cfg: &UpdateConfig{ + Spec: UpdateSpec{ + Enabled: true, + }, + }, + errMatch: "invalid", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "update.yaml") + + // Create config file only if provided in test case + if tt.cfg != nil { + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + VersionsDir: dir, + }) + require.NoError(t, err) + err = updater.Disable(context.Background()) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + + data, err := os.ReadFile(cfgPath) + + // If no config is present, disable should not create it + if tt.cfg == nil { + require.ErrorIs(t, err, os.ErrNotExist) + return + } + require.NoError(t, err) + + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } +} + +func TestUpdater_Enable(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + userCfg OverrideConfig + installErr error + + installedVersion string + installedTemplate string + errMatch string + }{ + { + name: "config from file", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Group: "group", + URLTemplate: "https://example.com", + }, + Status: UpdateStatus{ + ActiveVersion: "old-version", + }, + }, + installedVersion: "16.3.0", + installedTemplate: "https://example.com", + }, + { + name: "config from user", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Group: "old-group", + URLTemplate: "https://example.com/old", + }, + Status: UpdateStatus{ + ActiveVersion: "old-version", + }, + }, + userCfg: OverrideConfig{ + Group: "new-group", + URLTemplate: "https://example.com/new", + ForceVersion: "new-version", + }, + installedVersion: "new-version", + installedTemplate: "https://example.com/new", + }, + { + name: "already enabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: true, + }, + Status: UpdateStatus{ + ActiveVersion: "old-version", + }, + }, + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + }, + { + name: "insecure URL", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + URLTemplate: "http://example.com", + }, + }, + errMatch: "URL must use TLS", + }, + { + name: "install error", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + URLTemplate: "https://example.com", + }, + }, + installErr: errors.New("install error"), + errMatch: "install error", + }, + { + name: "version already installed", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Status: UpdateStatus{ + ActiveVersion: "16.3.0", + }, + }, + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + }, + { + name: "config does not exist", + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + }, + { + name: "invalid metadata", + cfg: &UpdateConfig{}, + errMatch: "invalid", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "update.yaml") + + // Create config file only if provided in test case + if tt.cfg != nil { + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TODO(sclevine): add web API test including group verification + w.Write([]byte(`{}`)) + })) + t.Cleanup(server.Close) + + if tt.userCfg.Proxy == "" { + tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://") + } + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + VersionsDir: dir, + }) + require.NoError(t, err) + + var installedVersion, installedTemplate string + updater.Installer = &testInstaller{ + FuncInstall: func(_ context.Context, version, template string, _ InstallFlags) error { + installedVersion = version + installedTemplate = template + return tt.installErr + }, + } + + ctx := context.Background() + err = updater.Enable(ctx, tt.userCfg) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + require.Equal(t, tt.installedVersion, installedVersion) + require.Equal(t, tt.installedTemplate, installedTemplate) + + data, err := os.ReadFile(cfgPath) + require.NoError(t, err) + data = blankTestAddr(data) + + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } +} + +var serverRegexp = regexp.MustCompile("127.0.0.1:[0-9]+") + +func blankTestAddr(s []byte) []byte { + return serverRegexp.ReplaceAll(s, []byte("localhost")) +} + +type testInstaller struct { + FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error + FuncRemove func(ctx context.Context, version string) error +} + +func (ti *testInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error { + return ti.FuncInstall(ctx, version, template, flags) +} + +func (ti *testInstaller) Remove(ctx context.Context, version string) error { + return ti.FuncRemove(ctx, version) +} diff --git a/lib/autoupdate/agent_test.go b/lib/autoupdate/agent_test.go deleted file mode 100644 index 7ac4ad379b78..000000000000 --- a/lib/autoupdate/agent_test.go +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Teleport - * Copyright (C) 2024 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package autoupdate - -import ( - "context" - "log/slog" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" - - "github.com/gravitational/teleport/lib/utils/golden" -) - -func TestAgentUpdater_Disable(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - cfg *AgentUpdateConfig // nil -> file not present - errMatch string - }{ - { - name: "enabled", - cfg: &AgentUpdateConfig{ - Version: agentUpdateConfigVersion, - Kind: agentUpdateConfigKind, - Spec: AgentUpdateSpec{ - Enabled: true, - }, - }, - }, - { - name: "already disabled", - cfg: &AgentUpdateConfig{ - Version: agentUpdateConfigVersion, - Kind: agentUpdateConfigKind, - Spec: AgentUpdateSpec{ - Enabled: false, - }, - }, - }, - { - name: "config does not exist", - }, - { - name: "invalid metadata", - cfg: &AgentUpdateConfig{ - Spec: AgentUpdateSpec{ - Enabled: true, - }, - }, - errMatch: "invalid", - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") - err := os.MkdirAll(filepath.Dir(cfgPath), 0777) - require.NoError(t, err) - - // Create config file only if provided in test case - if tt.cfg != nil { - b, err := yaml.Marshal(tt.cfg) - require.NoError(t, err) - err = os.WriteFile(cfgPath, b, 0600) - require.NoError(t, err) - } - - updater := AgentUpdater{ - Log: slog.Default(), - } - err = updater.Disable(context.Background(), cfgPath) - if tt.errMatch != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errMatch) - return - } - require.NoError(t, err) - - data, err := os.ReadFile(cfgPath) - - // If no config is present, disable should not create it - if tt.cfg == nil { - require.ErrorIs(t, err, os.ErrNotExist) - return - } - require.NoError(t, err) - - if golden.ShouldSet() { - golden.Set(t, data) - } - require.Equal(t, string(golden.Get(t)), string(data)) - }) - } -} diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index a13bde9b45d0..11aee2aae390 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -20,20 +20,16 @@ package main import ( "context" - "crypto/tls" - "crypto/x509" "log/slog" - "net/http" "os" "os/signal" "path/filepath" "syscall" - "time" "github.com/gravitational/trace" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/autoupdate" + autoupdate "github.com/gravitational/teleport/lib/autoupdate/agent" libdefaults "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" libutils "github.com/gravitational/teleport/lib/utils" @@ -56,13 +52,13 @@ const ( proxyServerEnvVar = "TELEPORT_PROXY" // updateGroupEnvVar allows the update group to be specified via env var. updateGroupEnvVar = "TELEPORT_UPDATE_GROUP" + // updateVersionEnvVar forces the version to specified value. + updateVersionEnvVar = "TELEPORT_UPDATE_VERSION" ) const ( // versionsDirName specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions. versionsDirName = "versions" - // configFileName specifies the name of the file inside versionsDirName containing configuration for the teleport update - configFileName = "update.yaml" // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution. lockFileName = ".lock" ) @@ -76,23 +72,15 @@ func main() { } type cliConfig struct { + autoupdate.OverrideConfig + // Debug logs enabled Debug bool - // DataDir for Teleport (usually /var/lib/teleport) - DataDir string // LogFormat controls the format of logging. Can be either `json` or `text`. // By default, this is `text`. LogFormat string - - // ProxyServer address, scheme and port optional. - // Overrides existing value if specified. - ProxyServer string - // Group identifier for updates (e.g., staging) - // Overrides existing value if specified. - Group string - // Template for the Teleport tgz download URL - // Overrides existing value if specified. - Template string + // DataDir for Teleport (usually /var/lib/teleport) + DataDir string } func (c *cliConfig) CheckAndSetDefaults() error { @@ -122,17 +110,21 @@ func Run(args []string) error { versionCmd := app.Command("version", "Print the version of your teleport-updater binary.") - enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial updates.") - enableCmd.Flag("proxy", "Address of the Teleport Proxy.").Short('p'). - Envar(proxyServerEnvVar).StringVar(&ccfg.ProxyServer) - enableCmd.Flag("group", "Update group, for staged updates.").Short('g'). - Envar(updateGroupEnvVar).StringVar(&ccfg.Group) - enableCmd.Flag("template", "Go template to override Teleport tgz download URL."). - Short('t').Envar(templateEnvVar).StringVar(&ccfg.Template) + enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.") + enableCmd.Flag("proxy", "Address of the Teleport Proxy."). + Short('p').Envar(proxyServerEnvVar).StringVar(&ccfg.Proxy) + enableCmd.Flag("group", "Update group for this agent installation."). + Short('g').Envar(updateGroupEnvVar).StringVar(&ccfg.Group) + enableCmd.Flag("template", "Go template used to override Teleport download URL."). + Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate) + enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster."). + Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion) disableCmd := app.Command("disable", "Disable agent auto-updates.") updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.") + updateCmd.Flag("force-version", "Use the provided version instead of querying it from the Teleport cluster."). + Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion) libutils.UpdateAppUsageTemplate(app, args) command, err := app.Parse(args) @@ -186,21 +178,28 @@ func setupLogger(debug bool, format string) error { // cmdDisable disables updates. func cmdDisable(ctx context.Context, ccfg *cliConfig) error { - var ( - versionsDir = filepath.Join(ccfg.DataDir, versionsDirName) - updateYAML = filepath.Join(versionsDir, configFileName) - ) + versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) + if err := os.MkdirAll(versionsDir, 0755); err != nil { + return trace.Errorf("failed to create versions directory: %w", err) + } + unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) if err != nil { - return trace.Wrap(err) + return trace.Errorf("failed to grab concurrent execution lock: %w", err) } defer func() { if err := unlock(); err != nil { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - updater := autoupdate.AgentUpdater{Log: plog} - if err := updater.Disable(ctx, updateYAML); err != nil { + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + VersionsDir: versionsDir, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) + } + if err := updater.Disable(ctx); err != nil { return trace.Wrap(err) } return nil @@ -208,37 +207,36 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { // cmdEnable enables updates and triggers an initial update. func cmdEnable(ctx context.Context, ccfg *cliConfig) error { - return trace.NotImplemented("TODO") -} - -// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address. -func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { - return trace.NotImplemented("TODO") -} + versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) + if err := os.MkdirAll(versionsDir, 0755); err != nil { + return trace.Errorf("failed to create versions directory: %w", err) + } -//nolint:unused // scaffolding used in upcoming PR -type downloadConfig struct { - // Insecure turns off TLS certificate verification when enabled. - Insecure bool - // Pool defines the set of root CAs to use when verifying server - // certificates. - Pool *x509.CertPool - // Timeout is a timeout for requests. - Timeout time.Duration -} + // Ensure enable can't run concurrently. + unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + if err != nil { + return trace.Errorf("failed to grab concurrent execution lock: %w", err) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() -//nolint:unused // scaffolding used in upcoming PR -func newClient(cfg *downloadConfig) (*http.Client, error) { - tr, err := libdefaults.Transport() + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + VersionsDir: versionsDir, + Log: plog, + }) if err != nil { - return nil, trace.Wrap(err) + return trace.Errorf("failed to setup updater: %w", err) } - tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: cfg.Insecure, - RootCAs: cfg.Pool, + if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil { + return trace.Wrap(err) } - return &http.Client{ - Transport: tr, - Timeout: cfg.Timeout, - }, nil + return nil +} + +// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address. +func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { + return trace.NotImplemented("TODO") }