From 6a3f71d1915fcffb26c1e47d38e78e0859da071f Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Fri, 18 Oct 2024 13:05:08 -0400 Subject: [PATCH] Automatically enable tunnel for "tsh proxy db" --- tool/tsh/common/db.go | 23 +++++++++++++++++ tool/tsh/common/db_test.go | 53 ++++++++++++++++++++++++++++++++++++++ tool/tsh/common/proxy.go | 20 +++++++++----- 3 files changed, 90 insertions(+), 6 deletions(-) diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index cf107f20dc8c..862b9b8fbbe2 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -1783,6 +1783,18 @@ func formatAmbiguousDB(cf *CLIConf, selectors resourceSelectors, matchedDBs type return formatAmbiguityErrTemplate(cf, selectors, listCommand, sb.String(), fullNameExample) } +// formatDbProxyAutoTunnel makes the message printed when "tsh proxy db" +// automatically enables the "--tunnel" flag. +func formatDbProxyAutoTunnel(reasons ...string) string { + templateData := map[string]any{ + "reasons": reasons, + } + + buf := bytes.NewBuffer(nil) + _ = dbProxyAutoTunnelTemplate.Execute(buf, templateData) + return buf.String() +} + // resourceSelectors is a helper struct for gathering up the selectors for a // resource, as an aggregate of name, labels, and predicate query. type resourceSelectors struct { @@ -1870,6 +1882,17 @@ Please use one of the following commands to connect to the database: {{.}}{{end -}} {{- end}}`)) + // dbProxyAutoTunnelTemplate is the message printed when "tsh proxy db" + // automatically enables the "--tunnel" flag. + dbProxyAutoTunnelTemplate = template.Must(template.New("").Parse(`Note: "--tunnel" flag has been automatically enabled{{if .reasons}} when: +{{- range $reason := .reasons }} + - {{ $reason }}. +{{- end}} +{{- else}}. +{{- end}} +To avoid this note, please add the "--tunnel" flag to this "tsh proxy db" command. + +`)) // dbConnectTemplate is the message printed after a successful "tsh db login" on how to connect. dbConnectTemplate = template.Must(template.New("").Parse(`Connection information for database "{{ .name }}" has been saved. diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index 7ae5ba046f8b..6273030fd710 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -1972,3 +1972,56 @@ func Test_shouldRetryGetDatabaseUsingSearchAsRoles(t *testing.T) { }) } } + +func Test_maybeEnableDbProxyTunnel(t *testing.T) { + tests := []struct { + name string + dbRequiresTunnel bool + tunnelFlagProvided bool + wantTunnelFlagEnabled bool + wantMessage string + }{ + { + name: "tunnel already enabled", + dbRequiresTunnel: true, + tunnelFlagProvided: true, + wantTunnelFlagEnabled: true, + }, + { + name: "tunnel not required", + dbRequiresTunnel: false, + tunnelFlagProvided: false, + wantTunnelFlagEnabled: false, + }, + { + name: "tunnel enabled", + dbRequiresTunnel: true, + tunnelFlagProvided: false, + wantTunnelFlagEnabled: true, + wantMessage: `Note: "--tunnel" flag has been automatically enabled when: + - tunnel required by unit test. +To avoid this note, please add the "--tunnel" flag to this "tsh proxy db" command. + +`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var buf bytes.Buffer + cf := &CLIConf{ + LocalProxyTunnel: test.tunnelFlagProvided, + OverrideStdout: &buf, + } + requires := &dbLocalProxyRequirement{ + tunnel: test.dbRequiresTunnel, + tunnelReasons: []string{"tunnel required by unit test"}, + } + + maybeEnableDbProxyTunnel(cf, requires) + + require.Equal(t, test.wantTunnelFlagEnabled, cf.LocalProxyTunnel) + require.Equal(t, test.wantMessage, buf.String()) + }) + } +} diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 4f0f1fee9213..4c845f6084aa 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -172,12 +172,7 @@ func onProxyCommandDB(cf *CLIConf) error { // These steps are not needed with `--tunnel`, because the local proxy tunnel // will manage database certificates itself and reissue them as needed. requires := getDBLocalProxyRequirement(tc, dbInfo.RouteToDatabase) - if requires.tunnel && !cf.LocalProxyTunnel { - // Some scenarios require a local proxy tunnel, e.g.: - // - Snowflake, DynamoDB protocol - // - Hardware-backed private key policy - return trace.BadParameter(formatDbCmdUnsupported(cf, dbInfo.RouteToDatabase, requires.tunnelReasons...)) - } + maybeEnableDbProxyTunnel(cf, requires) if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil { return trace.Wrap(err) } @@ -291,6 +286,19 @@ func onProxyCommandDB(cf *CLIConf) error { return nil } +// maybeEnableDbProxyTunnel forces cf.LocalProxyTunnel to true for scenarios require +// a local proxy tunnel, e.g.: +// - Snowflake, DynamoDB protocol +// - Hardware-backed private key policy +func maybeEnableDbProxyTunnel(cf *CLIConf, requires *dbLocalProxyRequirement) { + if requires.tunnel && !cf.LocalProxyTunnel { + cf.LocalProxyTunnel = true + + msg := formatDbProxyAutoTunnel(requires.tunnelReasons...) + fmt.Fprintf(cf.Stdout(), msg) + } +} + func maybeAddDBUserPassword(cf *CLIConf, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { if dbInfo.Protocol == defaults.ProtocolCassandra { db, err := dbInfo.GetDatabase(cf.Context, tc)