diff --git a/api/types/app.go b/api/types/app.go index 7806d54c2b89..153a8b1cfa60 100644 --- a/api/types/app.go +++ b/api/types/app.go @@ -88,6 +88,55 @@ type Application interface { GetCORS() *CORSPolicy } +// ReadOnlyApplication is a read only variant of Application. +type ReadOnlyApplication interface { + // ReadOnlyResourceWithLabels provides common resource methods. + ReadOnlyResourceWithLabels + // GetNamespace returns the app namespace. + GetNamespace() string + // GetStaticLabels returns the app static labels. + GetStaticLabels() map[string]string + // GetDynamicLabels returns the app dynamic labels. + GetDynamicLabels() map[string]CommandLabel + // String returns string representation of the app. + String() string + // GetDescription returns the app description. + GetDescription() string + // GetURI returns the app connection endpoint. + GetURI() string + // GetPublicAddr returns the app public address. + GetPublicAddr() string + // GetInsecureSkipVerify returns the app insecure setting. + GetInsecureSkipVerify() bool + // GetRewrite returns the app rewrite configuration. + GetRewrite() *Rewrite + // IsAWSConsole returns true if this app is AWS management console. + IsAWSConsole() bool + // IsAzureCloud returns true if this app represents Azure Cloud instance. + IsAzureCloud() bool + // IsGCP returns true if this app represents GCP instance. + IsGCP() bool + // IsTCP returns true if this app represents a TCP endpoint. + IsTCP() bool + // GetProtocol returns the application protocol. + GetProtocol() string + // GetAWSAccountID returns value of label containing AWS account ID on this app. + GetAWSAccountID() string + // GetAWSExternalID returns the AWS External ID configured for this app. + GetAWSExternalID() string + // GetUserGroups will get the list of user group IDs associated with the application. + GetUserGroups() []string + // Copy returns a copy of this app resource. + Copy() *AppV3 + // GetIntegration will return the Integration. + // If present, the Application must use the Integration's credentials instead of ambient credentials to access Cloud APIs. + GetIntegration() string + // GetRequiredAppNames will return a list of required apps names that should be authenticated during this apps authentication process. + GetRequiredAppNames() []string + // GetCORS returns the CORS configuration for the app. + GetCORS() *CORSPolicy +} + // NewAppV3 creates a new app resource. func NewAppV3(meta Metadata, spec AppSpecV3) (*AppV3, error) { app := &AppV3{ diff --git a/api/types/database.go b/api/types/database.go index 70a1844cce87..7224c103dcaa 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -147,6 +147,94 @@ type Database interface { IsUsernameCaseInsensitive() bool } +// ReadOnlyDatabase is a read only variant of Database. +type ReadOnlyDatabase interface { + // ReadOnlyResourceWithLabels provides common resource methods. + ReadOnlyResourceWithLabels + // GetNamespace returns the database namespace. + GetNamespace() string + // GetStaticLabels returns the database static labels. + GetStaticLabels() map[string]string + // GetDynamicLabels returns the database dynamic labels. + GetDynamicLabels() map[string]CommandLabel + // String returns string representation of the database. + String() string + // GetDescription returns the database description. + GetDescription() string + // GetProtocol returns the database protocol. + GetProtocol() string + // GetURI returns the database connection endpoint. + GetURI() string + // GetCA returns the database CA certificate. 
+ GetCA() string + // GetTLS returns the database TLS configuration. + GetTLS() DatabaseTLS + // GetStatusCA gets the database CA certificate in the status field. + GetStatusCA() string + // GetMySQL returns the database options from spec. + GetMySQL() MySQLOptions + // GetOracle returns the database options from spec. + GetOracle() OracleOptions + // GetMySQLServerVersion returns the MySQL server version either from configuration or + // reported by the database. + GetMySQLServerVersion() string + // GetAWS returns the database AWS metadata. + GetAWS() AWS + // GetGCP returns GCP information for Cloud SQL databases. + GetGCP() GCPCloudSQL + // GetAzure returns Azure database server metadata. + GetAzure() Azure + // GetAD returns Active Directory database configuration. + GetAD() AD + // GetType returns the database authentication type: self-hosted, RDS, Redshift or Cloud SQL. + GetType() string + // GetSecretStore returns secret store configurations. + GetSecretStore() SecretStore + // GetManagedUsers returns a list of database users that are managed by Teleport. + GetManagedUsers() []string + // GetMongoAtlas returns Mongo Atlas database metadata. + GetMongoAtlas() MongoAtlas + // IsRDS returns true if this is an RDS/Aurora database. + IsRDS() bool + // IsRDSProxy returns true if this is an RDS Proxy database. + IsRDSProxy() bool + // IsRedshift returns true if this is a Redshift database. + IsRedshift() bool + // IsCloudSQL returns true if this is a Cloud SQL database. + IsCloudSQL() bool + // IsAzure returns true if this is an Azure database. + IsAzure() bool + // IsElastiCache returns true if this is an AWS ElastiCache database. + IsElastiCache() bool + // IsMemoryDB returns true if this is an AWS MemoryDB database. + IsMemoryDB() bool + // IsAWSHosted returns true if database is hosted by AWS. + IsAWSHosted() bool + // IsCloudHosted returns true if database is hosted in the cloud (AWS, Azure or Cloud SQL). + IsCloudHosted() bool + // RequireAWSIAMRolesAsUsers returns true for database types that require + // AWS IAM roles as database users. + RequireAWSIAMRolesAsUsers() bool + // SupportAWSIAMRoleARNAsUsers returns true for database types that support + // AWS IAM roles as database users. + SupportAWSIAMRoleARNAsUsers() bool + // Copy returns a copy of this database resource. + Copy() *DatabaseV3 + // GetAdminUser returns database privileged user information. + GetAdminUser() DatabaseAdminUser + // SupportsAutoUsers returns true if this database supports automatic + // user provisioning. + SupportsAutoUsers() bool + // GetEndpointType returns the endpoint type of the database, if available. + GetEndpointType() string + // GetCloud gets the cloud this database is running on, or an empty string if it + // isn't running on a cloud provider. + GetCloud() string + // IsUsernameCaseInsensitive returns true if the database username is case + // insensitive. + IsUsernameCaseInsensitive() bool +} + // NewDatabaseV3 creates a new database resource. func NewDatabaseV3(meta Metadata, spec DatabaseSpecV3) (*DatabaseV3, error) { database := &DatabaseV3{ diff --git a/api/types/kubernetes.go b/api/types/kubernetes.go index 793ac9763b57..70ce8387b465 100644 --- a/api/types/kubernetes.go +++ b/api/types/kubernetes.go @@ -80,6 +80,43 @@ type KubeCluster interface { GetCloud() string } +// ReadOnlyKubeCluster is a read only variant of KubeCluster. +type ReadOnlyKubeCluster interface { + // ReadOnlyResourceWithLabels provides common resource methods. 
+	ReadOnlyResourceWithLabels
+	// GetNamespace returns the kube cluster namespace.
+	GetNamespace() string
+	// GetStaticLabels returns the kube cluster static labels.
+	GetStaticLabels() map[string]string
+	// GetDynamicLabels returns the kube cluster dynamic labels.
+	GetDynamicLabels() map[string]CommandLabel
+	// GetKubeconfig returns the kubeconfig payload.
+	GetKubeconfig() []byte
+	// String returns string representation of the kube cluster.
+	String() string
+	// GetDescription returns the kube cluster description.
+	GetDescription() string
+	// GetAzureConfig gets the Azure config.
+	GetAzureConfig() KubeAzure
+	// GetAWSConfig gets the AWS config.
+	GetAWSConfig() KubeAWS
+	// GetGCPConfig gets the GCP config.
+	GetGCPConfig() KubeGCP
+	// IsAzure identifies if the KubeCluster contains Azure details.
+	IsAzure() bool
+	// IsAWS identifies if the KubeCluster contains AWS details.
+	IsAWS() bool
+	// IsGCP identifies if the KubeCluster contains GCP details.
+	IsGCP() bool
+	// IsKubeconfig identifies if the KubeCluster contains kubeconfig data.
+	IsKubeconfig() bool
+	// Copy returns a copy of this kube cluster resource.
+	Copy() *KubernetesClusterV3
+	// GetCloud gets the cloud this kube cluster is running on, or an empty string if it
+	// isn't running on a cloud provider.
+	GetCloud() string
+}
+
 // DiscoveredEKSCluster represents a server discovered by EKS discovery fetchers.
 type DiscoveredEKSCluster interface {
 	// KubeCluster is base discovered cluster.
diff --git a/api/types/kubernetes_server.go b/api/types/kubernetes_server.go
index f4cee7ca3649..a4fd175b0b0f 100644
--- a/api/types/kubernetes_server.go
+++ b/api/types/kubernetes_server.go
@@ -60,6 +60,32 @@ type KubeServer interface {
 	ProxiedService
 }
 
+// ReadOnlyKubeServer is a read only variant of KubeServer.
+type ReadOnlyKubeServer interface {
+	// ReadOnlyResourceWithLabels provides common resource methods.
+	ReadOnlyResourceWithLabels
+	// GetNamespace returns server namespace.
+	GetNamespace() string
+	// GetTeleportVersion returns the teleport version the server is running on.
+	GetTeleportVersion() string
+	// GetHostname returns the server hostname.
+	GetHostname() string
+	// GetHostID returns ID of the host the server is running on.
+	GetHostID() string
+	// GetRotation gets the state of certificate authority rotation.
+	GetRotation() Rotation
+	// String returns string representation of the server.
+	String() string
+	// Copy returns a copy of this kube server object.
+	Copy() KubeServer
+	// CloneResource returns a copy of the KubeServer as a ResourceWithLabels.
+	CloneResource() ResourceWithLabels
+	// GetCluster returns the Kubernetes Cluster this kube server proxies.
+	GetCluster() KubeCluster
+	// GetProxyIDs returns a list of proxy ids this service is connected to.
+	GetProxyIDs() []string
+}
+
 // NewKubernetesServerV3 creates a new kube server instance.
 func NewKubernetesServerV3(meta Metadata, spec KubernetesServerSpecV3) (*KubernetesServerV3, error) {
 	s := &KubernetesServerV3{
diff --git a/api/types/resource.go b/api/types/resource.go
index ec87a72c97a8..fadc99716fc7 100644
--- a/api/types/resource.go
+++ b/api/types/resource.go
@@ -66,6 +66,24 @@ type Resource interface {
 	SetRevision(string)
 }
 
+// ReadOnlyResource is a read only variant of Resource.
+type ReadOnlyResource interface { + // GetKind returns resource kind + GetKind() string + // GetSubKind returns resource subkind + GetSubKind() string + // GetVersion returns resource version + GetVersion() string + // GetName returns the name of the resource + GetName() string + // Expiry returns object expiry setting + Expiry() time.Time + // GetMetadata returns object metadata + GetMetadata() Metadata + // GetRevision returns the revision + GetRevision() string +} + // IsSystemResource checks to see if the given resource is considered // part of the teleport system, as opposed to some user created resource // or preset. @@ -109,6 +127,13 @@ type ResourceWithOrigin interface { SetOrigin(string) } +// ReadOnlyResourceWithOrigin is a read only variant of ResourceWithOrigin. +type ReadOnlyResourceWithOrigin interface { + ReadOnlyResource + // Origin returns the origin value of the resource. + Origin() string +} + // ResourceWithLabels is a common interface for resources that have labels. type ResourceWithLabels interface { // ResourceWithOrigin is the base resource interface. @@ -126,6 +151,20 @@ type ResourceWithLabels interface { MatchSearch(searchValues []string) bool } +// ReadOnlyResourceWithLabels is a read only variant of ResourceWithLabels. +type ReadOnlyResourceWithLabels interface { + ReadOnlyResourceWithOrigin + // GetLabel retrieves the label with the provided key. + GetLabel(key string) (value string, ok bool) + // GetAllLabels returns all resource's labels. + GetAllLabels() map[string]string + // GetStaticLabels returns the resource's static labels. + GetStaticLabels() map[string]string + // MatchSearch goes through select field values of a resource + // and tries to match against the list of search values. + MatchSearch(searchValues []string) bool +} + // EnrichedResource is a [ResourceWithLabels] wrapped with // additional user-specific information. type EnrichedResource struct { diff --git a/api/types/server.go b/api/types/server.go index 47ec7c92ee1d..723c948947a5 100644 --- a/api/types/server.go +++ b/api/types/server.go @@ -103,6 +103,62 @@ type Server interface { GetAWSAccountID() string } +// ReadOnlyServer is a read only variant of Server. +type ReadOnlyServer interface { + // ReadOnlyResourceWithLabels provides common resource headers + ReadOnlyResourceWithLabels + // GetTeleportVersion returns the teleport version the server is running on + GetTeleportVersion() string + // GetAddr return server address + GetAddr() string + // GetHostname returns server hostname + GetHostname() string + // GetNamespace returns server namespace + GetNamespace() string + // GetLabels returns server's static label key pairs + GetLabels() map[string]string + // GetCmdLabels gets command labels + GetCmdLabels() map[string]CommandLabel + // GetPublicAddr returns a public address where this server can be reached. + GetPublicAddr() string + // GetPublicAddrs returns a list of public addresses where this server can be reached. + GetPublicAddrs() []string + // GetRotation gets the state of certificate authority rotation. + GetRotation() Rotation + // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. + GetUseTunnel() bool + // String returns string representation of the server + String() string + // GetPeerAddr returns the peer address of the server. + GetPeerAddr() string + // GetProxyIDs returns a list of proxy ids this service is connected to. 
+ GetProxyIDs() []string + // DeepCopy creates a clone of this server value + DeepCopy() Server + + // CloneResource is used to return a clone of the Server and match the CloneAny interface + // This is helpful when interfacing with multiple types at the same time in unified resources + CloneResource() ResourceWithLabels + + // GetCloudMetadata gets the cloud metadata for the server. + GetCloudMetadata() *CloudMetadata + // GetAWSInfo returns the AWSInfo for the server. + GetAWSInfo() *AWSInfo + + // IsOpenSSHNode returns whether the connection to this Server must use OpenSSH. + // This returns true for SubKindOpenSSHNode and SubKindOpenSSHEICENode. + IsOpenSSHNode() bool + + // IsEICE returns whether the Node is an EICE instance. + // Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels). + IsEICE() bool + + // GetAWSInstanceID returns the AWS Instance ID if this node comes from an EC2 instance. + GetAWSInstanceID() string + // GetAWSAccountID returns the AWS Account ID if this node comes from an EC2 instance. + GetAWSAccountID() string +} + // NewServer creates an instance of Server. func NewServer(name, kind string, spec ServerSpecV2) (Server, error) { return NewServerWithLabels(name, kind, spec, map[string]string{}) diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index 580abc957795..fce25155b66a 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -98,7 +98,7 @@ type TLSServerConfig struct { // kubernetes cluster name. Proxy uses this map to route requests to the correct // kubernetes_service. The servers are kept in memory to avoid making unnecessary // unmarshal calls followed by filtering and to improve memory usage. - KubernetesServersWatcher *services.KubeServerWatcher + KubernetesServersWatcher *services.GenericWatcher[types.KubeServer, types.ReadOnlyKubeServer] // PROXYProtocolMode controls behavior related to unsigned PROXY protocol headers. PROXYProtocolMode multiplexer.PROXYProtocolMode // InventoryHandle is used to send kube server heartbeats via the inventory control stream. @@ -170,7 +170,7 @@ type TLSServer struct { closeContext context.Context closeFunc context.CancelFunc // kubeClusterWatcher monitors changes to kube cluster resources. - kubeClusterWatcher *services.KubeClusterWatcher + kubeClusterWatcher *services.GenericWatcher[types.KubeCluster, types.ReadOnlyKubeCluster] // reconciler reconciles proxied kube clusters with kube_clusters resources. reconciler *services.Reconciler[types.KubeCluster] // monitoredKubeClusters contains all kube clusters the proxied kube_clusters are @@ -620,7 +620,9 @@ func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNa }, nil case ProxyService: return func(ctx context.Context, name string) ([]types.KubeServer, error) { - servers, err := t.KubernetesServersWatcher.GetKubeServersByClusterName(ctx, name) + servers, err := t.KubernetesServersWatcher.CurrentResourcesWithFilter(ctx, func(ks types.ReadOnlyKubeServer) bool { + return ks.GetCluster().GetName() == name + }) return servers, trace.Wrap(err) }, nil case LegacyProxyService: @@ -630,7 +632,9 @@ func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNa // and forward the request to the next proxy. 
kube, err := t.getKubeClusterWithServiceLabels(name) if err != nil { - servers, err := t.KubernetesServersWatcher.GetKubeServersByClusterName(ctx, name) + servers, err := t.KubernetesServersWatcher.CurrentResourcesWithFilter(ctx, func(ks types.ReadOnlyKubeServer) bool { + return ks.GetCluster().GetName() == name + }) return servers, trace.Wrap(err) } srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID) diff --git a/lib/kube/proxy/utils_testing.go b/lib/kube/proxy/utils_testing.go index 4621b7d51bec..462638df203c 100644 --- a/lib/kube/proxy/utils_testing.go +++ b/lib/kube/proxy/utils_testing.go @@ -294,6 +294,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo Component: teleport.ComponentKube, Client: client, }, + KubernetesServerGetter: client, }, ) require.NoError(t, err) @@ -387,7 +388,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo // Ensure watcher has the correct list of clusters. require.Eventually(t, func() bool { - kubeServers, err := kubeServersWatcher.GetKubernetesServers(ctx) + kubeServers, err := kubeServersWatcher.CurrentResources(ctx) return err == nil && len(kubeServers) == len(cfg.Clusters) }, 3*time.Second, time.Millisecond*100) diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go index 188d16822a38..365fc18dc82d 100644 --- a/lib/kube/proxy/watcher.go +++ b/lib/kube/proxy/watcher.go @@ -87,7 +87,7 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) { // startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and // registers/unregisters the proxied Kube Cluster accordingly. -func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) { +func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster, types.ReadOnlyKubeCluster], error) { if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService { s.log.Debug("Not initializing Kube Cluster resource watcher.") return nil, nil @@ -100,6 +100,7 @@ func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*servi // Logger: s.log, Client: s.AccessPoint, }, + KubernetesClusterGetter: s.AccessPoint, }) if err != nil { return nil, trace.Wrap(err) @@ -108,7 +109,7 @@ func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*servi defer watcher.Close() for { select { - case clusters := <-watcher.KubeClustersC: + case clusters := <-watcher.ResourcesC: s.monitoredKubeClusters.setResources(clusters) select { case s.reconcileCh <- struct{}{}: diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index c00b67006875..fe3659a92bf4 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -51,6 +51,7 @@ import ( // AccessPoint is the subset of the auth cache consumed by the [Client]. type AccessPoint interface { types.Events + services.ProxyGetter } // ClientConfig configures a Client instance. 
@@ -416,6 +417,7 @@ func (c *Client) sync() { Client: c.config.AccessPoint, Logger: c.config.Log, }, + ProxyGetter: c.config.AccessPoint, ProxyDiffer: func(old, new types.Server) bool { return old.GetPeerAddr() != new.GetPeerAddr() }, @@ -434,7 +436,7 @@ func (c *Client) sync() { case <-proxyWatcher.Done(): c.config.Log.DebugContext(c.ctx, "stopping peer proxy sync: proxy watcher done") return - case proxies := <-proxyWatcher.ProxiesC: + case proxies := <-proxyWatcher.ResourcesC: if err := c.updateConnections(proxies); err != nil { c.config.Log.ErrorContext(c.ctx, "error syncing peer proxies", "error", err) } diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 18e22adc798e..99b216c35d6c 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -383,7 +383,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check // site is the minimum interface needed to match servers // for a reversetunnelclient.RemoteSite. It makes testing easier. type site interface { - GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) + GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) } @@ -394,13 +394,13 @@ type remoteSite struct { } // GetNodes uses the wrapped sites NodeWatcher to filter nodes -func (r remoteSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) { +func (r remoteSite) GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error) { watcher, err := r.site.NodeWatcher() if err != nil { return nil, trace.Wrap(err) } - return watcher.GetNodes(ctx, fn), nil + return watcher.CurrentResourcesWithFilter(ctx, fn) } // GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig @@ -450,7 +450,7 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re var maxScore int scores := make(map[string]int) - matches, err := site.GetNodes(ctx, func(server services.Node) bool { + matches, err := site.GetNodes(ctx, func(server types.ReadOnlyServer) bool { score := routeMatcher.RouteToServerScore(server) if score < 1 { return false diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index 48268cf35596..338f4575aebf 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -37,7 +37,6 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" ) @@ -51,7 +50,7 @@ func (t testSite) GetClusterNetworkingConfig(ctx context.Context) (types.Cluster return t.cfg, nil } -func (t testSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) { +func (t testSite) GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error) { var out []types.Server for _, s := range t.nodes { if fn(s) { diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 54e9d5db3a68..57a37e57490f 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -180,7 +180,7 @@ func (s *localSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, err } // NodeWatcher returns a services.NodeWatcher for this cluster. 
-func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) { +func (s *localSite) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { return s.srv.NodeWatcher, nil } @@ -739,7 +739,11 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch return case <-proxyResyncTicker.Chan(): var req discoveryRequest - req.SetProxies(s.srv.proxyWatcher.GetCurrent()) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + req.SetProxies(proxies) if err := rconn.sendDiscoveryRequest(req); err != nil { logger.WithError(err).Debug("Marking connection invalid on error") @@ -764,9 +768,12 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - current := s.srv.proxyWatcher.GetCurrent() - if len(current) > 0 { - rconn.updateProxies(current) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + if len(proxies) > 0 { + rconn.updateProxies(proxies) } reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc() firstHeartbeat = false @@ -935,7 +942,7 @@ func (s *localSite) periodicFunctions() { // sshTunnelStats reports SSH tunnel statistics for the cluster. func (s *localSite) sshTunnelStats() error { - missing := s.srv.NodeWatcher.GetNodes(s.srv.ctx, func(server services.Node) bool { + missing, err := s.srv.NodeWatcher.CurrentResourcesWithFilter(s.srv.ctx, func(server types.ReadOnlyServer) bool { // Skip over any servers that have a TTL larger than announce TTL (10 // minutes) and are non-IoT SSH servers (they won't have tunnels). // @@ -967,6 +974,9 @@ func (s *localSite) sshTunnelStats() error { return err != nil }) + if err != nil { + return trace.Wrap(err) + } // Update Prometheus metrics and also log if any tunnels are missing. 
missingSSHTunnels.Set(float64(len(missing))) diff --git a/lib/reversetunnel/localsite_test.go b/lib/reversetunnel/localsite_test.go index 3397aed76368..195a1e76510c 100644 --- a/lib/reversetunnel/localsite_test.go +++ b/lib/reversetunnel/localsite_test.go @@ -58,14 +58,16 @@ func TestRemoteConnCleanup(t *testing.T) { clock := clockwork.NewFakeClock() + clt := &mockLocalSiteClient{} watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: "test", Logger: utils.NewSlogLoggerForTests(), Clock: clock, - Client: &mockLocalSiteClient{}, + Client: clt, }, - ProxiesC: make(chan []types.Server, 2), + ProxyGetter: clt, + ProxiesC: make(chan []types.Server, 2), }) require.NoError(t, err) require.NoError(t, watcher.WaitInitialization()) @@ -249,17 +251,19 @@ func TestProxyResync(t *testing.T) { proxy2, err := types.NewServer(uuid.NewString(), types.KindProxy, types.ServerSpecV2{}) require.NoError(t, err) + clt := &mockLocalSiteClient{ + proxies: []types.Server{proxy1, proxy2}, + } // set up the watcher and wait for it to be initialized watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: "test", Logger: utils.NewSlogLoggerForTests(), Clock: clock, - Client: &mockLocalSiteClient{ - proxies: []types.Server{proxy1, proxy2}, - }, + Client: clt, }, - ProxiesC: make(chan []types.Server, 2), + ProxyGetter: clt, + ProxiesC: make(chan []types.Server, 2), }) require.NoError(t, err) require.NoError(t, watcher.WaitInitialization()) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index fc16cbe11cef..2460538954ac 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -90,7 +90,7 @@ func (p *clusterPeers) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, return peer.CachingAccessPoint() } -func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) { +func (p *clusterPeers) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { peer, err := p.pickPeer() if err != nil { return nil, trace.Wrap(err) @@ -202,7 +202,7 @@ func (s *clusterPeer) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, e return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) } -func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) { +func (s *clusterPeer) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 8e8b7e4c3fe7..d2c6e669b944 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -85,7 +85,7 @@ type remoteSite struct { remoteAccessPoint authclient.RemoteProxyAccessPoint // nodeWatcher provides access the node set for the remote site - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation @@ -164,7 +164,7 @@ func (s *remoteSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, er } // NodeWatcher returns the services.NodeWatcher for the remote cluster. 
-func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) { +func (s *remoteSite) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { return s.nodeWatcher, nil } @@ -429,7 +429,11 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch return case <-proxyResyncTicker.Chan(): var req discoveryRequest - req.SetProxies(s.srv.proxyWatcher.GetCurrent()) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + req.SetProxies(proxies) if err := conn.sendDiscoveryRequest(req); err != nil { logger.WithError(err).Debug("Marking connection invalid on error") @@ -458,9 +462,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - current := s.srv.proxyWatcher.GetCurrent() - if len(current) > 0 { - conn.updateProxies(current) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + if len(proxies) > 0 { + conn.updateProxies(proxies) } firstHeartbeat = false } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 19dfd9e2d43c..be125e075f5e 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -114,7 +114,7 @@ type server struct { // proxyWatcher monitors changes to the proxies // and broadcasts updates - proxyWatcher *services.ProxyWatcher + proxyWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // offlineThreshold is how long to wait for a keep alive message before // marking a reverse tunnel connection as invalid. @@ -201,7 +201,7 @@ type Config struct { LockWatcher *services.LockWatcher // NodeWatcher is a node watcher. - NodeWatcher *services.NodeWatcher + NodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // CertAuthorityWatcher is a cert authority watcher. CertAuthorityWatcher *services.CertAuthorityWatcher @@ -307,9 +307,6 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) { }, ProxiesC: make(chan []types.Server, 10), ProxyGetter: cfg.LocalAccessPoint, - ProxyDiffer: func(_, _ types.Server) bool { - return true // we always want to store the most recently heartbeated proxy - }, }) if err != nil { cancel() @@ -401,7 +398,7 @@ func (s *server) periodicFunctions() { s.log.Debugf("Closing.") return // Proxies have been updated, notify connected agents about the update. 
- case proxies := <-s.proxyWatcher.ProxiesC: + case proxies := <-s.proxyWatcher.ResourcesC: s.fanOutProxies(proxies) case <-ticker.C: if err := s.fetchClusterPeers(); err != nil { diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index f7e8dfb47ef6..4334c05dcafa 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -123,7 +123,7 @@ type RemoteSite interface { // but is resilient to auth server crashes CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) // NodeWatcher returns the node watcher that maintains the node set for the site - NodeWatcher() (*services.NodeWatcher, error) + NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/service/service.go b/lib/service/service.go index 0870695cef8f..d552063d6a83 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -5025,6 +5025,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Logger: process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), Client: accessPoint, }, + KubernetesServerGetter: accessPoint, }) if err != nil { return trace.Wrap(err) diff --git a/lib/services/watcher.go b/lib/services/watcher.go index f8151d491583..3b97338dfe6d 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -87,21 +87,21 @@ func watchKindsString(kinds []types.WatchKind) string { // ResourceWatcherConfig configures resource watcher. type ResourceWatcherConfig struct { - // Component is a component used in logs. - Component string + // Clock is used to control time. + Clock clockwork.Clock + // Client is used to create new watchers + Client types.Events // Logger emits log messages. Logger *slog.Logger + // ResetC is a channel to notify of internal watcher reset (used in tests). + ResetC chan time.Duration + // Component is a component used in logs. + Component string // MaxRetryPeriod is the maximum retry period on failed watchers. MaxRetryPeriod time.Duration - // Clock is used to control time. - Clock clockwork.Clock - // Client is used to create new watchers. - Client types.Events // MaxStaleness is a maximum acceptable staleness for the locally maintained // resources, zero implies no staleness detection. MaxStaleness time.Duration - // ResetC is a channel to notify of internal watcher reset (used in tests). - ResetC chan time.Duration // QueueSize is an optional queue size QueueSize int } @@ -165,28 +165,23 @@ func newResourceWatcher(ctx context.Context, collector resourceCollector, cfg Re // resourceWatcher monitors additions, updates and deletions // to a set of resources. type resourceWatcher struct { - ResourceWatcherConfig - collector resourceCollector - - // ctx is a context controlling the lifetime of this resourceWatcher - // instance. - ctx context.Context - cancel context.CancelFunc - - // retry is used to manage backoff logic for watchers. - retry retryutils.Retry - // failureStartedAt records when the current sync failures were first // detected, zero if there are no failures present. failureStartedAt time.Time - + collector resourceCollector + // ctx is a context controlling the lifetime of this resourceWatcher + // instance. + ctx context.Context + // retry is used to manage backoff logic for watchers. 
+ retry retryutils.Retry + cancel context.CancelFunc // LoopC is a channel to check whether the watch loop is running // (used in tests). LoopC chan struct{} - // StaleC is a channel that can trigger the condition of resource staleness // (used in tests). StaleC chan struct{} + ResourceWatcherConfig } // Done returns a channel that signals resource watcher closure. @@ -380,107 +375,392 @@ func (p *resourceWatcher) watch() error { // ProxyWatcherConfig is a ProxyWatcher configuration. type ProxyWatcherConfig struct { - ResourceWatcherConfig // ProxyGetter is used to directly fetch the list of active proxies. ProxyGetter // ProxyDiffer is used to decide whether a put operation on an existing proxy should // trigger a event. ProxyDiffer func(old, new types.Server) bool // ProxiesC is a channel used to report the current proxy set. It receives - // a fresh list at startup and subsequently a list of all known proxies + // a fresh list at startup and subsequently a list of all known proxy // whenever an addition or deletion is detected. ProxiesC chan []types.Server + ResourceWatcherConfig +} + +// NewProxyWatcher returns a new instance of GenericWatcher that is configured +// to watch for changes. +func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*GenericWatcher[types.Server, types.ReadOnlyServer], error) { + if cfg.ProxyGetter == nil { + return nil, trace.BadParameter("ProxyGetter must be provided") + } + + if cfg.ProxyDiffer == nil { + cfg.ProxyDiffer = func(old, new types.Server) bool { return true } + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, types.ReadOnlyServer]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindProxy, + ResourceKey: types.Server.GetName, + ResourceGetter: func(ctx context.Context) ([]types.Server, error) { + return cfg.ProxyGetter.GetProxies() + }, + ResourcesC: cfg.ProxiesC, + ResourceDiffer: cfg.ProxyDiffer, + RequireResourcesForInitialBroadcast: true, + CloneFunc: types.Server.DeepCopy, + }) + return w, trace.Wrap(err) +} + +// DatabaseWatcherConfig is a DatabaseWatcher configuration. +type DatabaseWatcherConfig struct { + // DatabaseGetter is responsible for fetching database resources. + DatabaseGetter + // DatabasesC receives up-to-date list of all database resources. + DatabasesC chan []types.Database + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig +} + +// NewDatabaseWatcher returns a new instance of DatabaseWatcher. +func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*GenericWatcher[types.Database, types.ReadOnlyDatabase], error) { + if cfg.DatabaseGetter == nil { + return nil, trace.BadParameter("DatabaseGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Database, types.ReadOnlyDatabase]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindDatabase, + ResourceKey: types.Database.GetName, + ResourceGetter: func(ctx context.Context) ([]types.Database, error) { + return cfg.DatabaseGetter.GetDatabases(ctx) + }, + ResourcesC: cfg.DatabasesC, + CloneFunc: func(resource types.Database) types.Database { + return resource.Copy() + }, + }) + return w, trace.Wrap(err) +} + +// AppWatcherConfig is an AppWatcher configuration. +type AppWatcherConfig struct { + // AppGetter is responsible for fetching application resources. + AppGetter + // AppsC receives up-to-date list of all application resources. 
+	AppsC chan []types.Application
+	// ResourceWatcherConfig is the resource watcher configuration.
+	ResourceWatcherConfig
+}
+
+// NewAppWatcher returns a new instance of AppWatcher.
+func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*GenericWatcher[types.Application, types.ReadOnlyApplication], error) {
+	if cfg.AppGetter == nil {
+		return nil, trace.BadParameter("AppGetter must be provided")
+	}
+
+	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Application, types.ReadOnlyApplication]{
+		ResourceWatcherConfig: cfg.ResourceWatcherConfig,
+		ResourceKind: types.KindApp,
+		ResourceKey: types.Application.GetName,
+		ResourceGetter: func(ctx context.Context) ([]types.Application, error) {
+			return cfg.AppGetter.GetApps(ctx)
+		},
+		ResourcesC: cfg.AppsC,
+		CloneFunc: func(resource types.Application) types.Application {
+			return resource.Copy()
+		},
+	})
+
+	return w, trace.Wrap(err)
+}
+
+// KubeServerWatcherConfig is a KubeServerWatcher configuration.
+type KubeServerWatcherConfig struct {
+	// KubernetesServerGetter is responsible for fetching kube_server resources.
+	KubernetesServerGetter
+	// ResourceWatcherConfig is the resource watcher configuration.
+	ResourceWatcherConfig
+}
+
+// NewKubeServerWatcher returns a new instance of KubeServerWatcher.
+func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*GenericWatcher[types.KubeServer, types.ReadOnlyKubeServer], error) {
+	if cfg.KubernetesServerGetter == nil {
+		return nil, trace.BadParameter("KubernetesServerGetter must be provided")
+	}
+
+	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeServer, types.ReadOnlyKubeServer]{
+		ResourceWatcherConfig: cfg.ResourceWatcherConfig,
+		ResourceKind: types.KindKubeServer,
+		ResourceGetter: func(ctx context.Context) ([]types.KubeServer, error) {
+			return cfg.KubernetesServerGetter.GetKubernetesServers(ctx)
+		},
+		ResourceKey: func(resource types.KubeServer) string {
+			return resource.GetHostID() + resource.GetName()
+		},
+		DisableUpdateBroadcast: true,
+		CloneFunc: types.KubeServer.Copy,
+	})
+	return w, trace.Wrap(err)
+}
+
+// KubeClusterWatcherConfig is a KubeClusterWatcher configuration.
+type KubeClusterWatcherConfig struct {
+	// KubernetesClusterGetter is responsible for fetching kube_cluster resources.
+	KubernetesClusterGetter
+	// KubeClustersC receives up-to-date list of all kube_cluster resources.
+	KubeClustersC chan []types.KubeCluster
+	// ResourceWatcherConfig is the resource watcher configuration.
+	ResourceWatcherConfig
+}
+
+// NewKubeClusterWatcher returns a new instance of KubeClusterWatcher.
+func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*GenericWatcher[types.KubeCluster, types.ReadOnlyKubeCluster], error) {
+	if cfg.KubernetesClusterGetter == nil {
+		return nil, trace.BadParameter("KubernetesClusterGetter must be provided")
+	}
+
+	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeCluster, types.ReadOnlyKubeCluster]{
+		ResourceWatcherConfig: cfg.ResourceWatcherConfig,
+		ResourceKind: types.KindKubernetesCluster,
+		ResourceGetter: func(ctx context.Context) ([]types.KubeCluster, error) {
+			return cfg.KubernetesClusterGetter.GetKubernetesClusters(ctx)
+		},
+		ResourceKey: types.KubeCluster.GetName,
+		ResourcesC: cfg.KubeClustersC,
+		CloneFunc: func(resource types.KubeCluster) types.KubeCluster {
+			return resource.Copy()
+		},
+	})
+	return w, trace.Wrap(err)
+}
+
+// GenericWatcherConfig is a generic resource watcher configuration.
+type GenericWatcherConfig[T any, R any] struct {
+	// ResourceGetter is used to directly fetch the current set of resources.
+	ResourceGetter func(context.Context) ([]T, error)
+	// ResourceDiffer is used to decide whether a put operation on an existing resource should
+	// trigger an event.
+	ResourceDiffer func(old, new T) bool
+	// ResourceKey defines how the resources should be keyed.
+	ResourceKey func(resource T) string
+	// ResourcesC is a channel used to report the current resource set. It receives
+	// a fresh list at startup and subsequently a list of all known resources
+	// whenever an addition or deletion is detected.
+	ResourcesC chan []T
+	// CloneFunc defines how a resource is cloned. All resources provided via
+	// the broadcast mechanism, or retrieved via [GenericWatcher.CurrentResources]
+	// or [GenericWatcher.CurrentResourcesWithFilter] will be cloned by this
+	// mechanism before being provided to callers.
+	CloneFunc func(resource T) T
+	ResourceWatcherConfig
+	// ResourceKind specifies the kind of resource the watcher is monitoring.
+	ResourceKind string
+	// RequireResourcesForInitialBroadcast indicates whether an update should be
+	// performed if the initial set of resources is empty.
+	RequireResourcesForInitialBroadcast bool
+	// DisableUpdateBroadcast turns off emitting updates on changes. When this
+	// mode is opted into, users must invoke [GenericWatcher.CurrentResources] or
+	// [GenericWatcher.CurrentResourcesWithFilter] manually to retrieve the active
+	// resource set.
+	DisableUpdateBroadcast bool
 }
 
 // CheckAndSetDefaults checks parameters and sets default values.
-func (cfg *ProxyWatcherConfig) CheckAndSetDefaults() error {
+func (cfg *GenericWatcherConfig[T, R]) CheckAndSetDefaults() error {
 	if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil {
 		return trace.Wrap(err)
 	}
-	if cfg.ProxyGetter == nil {
-		getter, ok := cfg.Client.(ProxyGetter)
-		if !ok {
-			return trace.BadParameter("missing parameter ProxyGetter and Client not usable as ProxyGetter")
-		}
-		cfg.ProxyGetter = getter
+
+	if cfg.ResourceGetter == nil {
+		return trace.BadParameter("ResourceGetter not provided to generic resource watcher")
+	}
+
+	if cfg.ResourceKind == "" {
+		return trace.BadParameter("ResourceKind not provided to generic resource watcher")
 	}
-	if cfg.ProxiesC == nil {
-		cfg.ProxiesC = make(chan []types.Server)
+
+	if cfg.ResourceKey == nil {
+		return trace.BadParameter("ResourceKey not provided to generic resource watcher")
+	}
+
+	if cfg.ResourceDiffer == nil {
+		cfg.ResourceDiffer = func(T, T) bool { return true }
+	}
+
+	if cfg.ResourcesC == nil {
+		cfg.ResourcesC = make(chan []T)
 	}
 	return nil
 }
 
-// NewProxyWatcher returns a new instance of ProxyWatcher.
-func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*ProxyWatcher, error) {
+// NewGenericResourceWatcher returns a new instance of resource watcher.
+func NewGenericResourceWatcher[T any, R any](ctx context.Context, cfg GenericWatcherConfig[T, R]) (*GenericWatcher[T, R], error) {
 	if err := cfg.CheckAndSetDefaults(); err != nil {
 		return nil, trace.Wrap(err)
 	}
-	collector := &proxyCollector{
-		ProxyWatcherConfig: cfg,
-		initializationC: make(chan struct{}),
+
+	cache, err := utils.NewFnCache(utils.FnCacheConfig{
+		Context: ctx,
+		TTL: 3 * time.Second,
+		Clock: cfg.Clock,
+	})
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	collector := &genericCollector[T, R]{
+		GenericWatcherConfig: cfg,
+		initializationC: make(chan struct{}),
+		cache: cache,
 	}
+	collector.stale.Store(true)
 	watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	return &ProxyWatcher{watcher, collector}, nil
+	return &GenericWatcher[T, R]{watcher, collector}, nil
 }
 
-// ProxyWatcher is built on top of resourceWatcher to monitor additions
-// and deletions to the set of proxies.
-type ProxyWatcher struct {
+// GenericWatcher is built on top of resourceWatcher to monitor additions
+// and deletions to the set of resources.
+type GenericWatcher[T any, R any] struct {
 	*resourceWatcher
-	*proxyCollector
+	*genericCollector[T, R]
+}
+
+// ResourceCount returns the current number of resources known to the watcher.
+func (g *GenericWatcher[T, R]) ResourceCount() int {
+	g.rw.RLock()
+	defer g.rw.RUnlock()
+	return len(g.current)
+}
+
+// CurrentResources returns a copy of the resources known to the watcher.
+func (g *GenericWatcher[T, R]) CurrentResources(ctx context.Context) ([]T, error) {
+	if err := g.refreshStaleResources(ctx); err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	g.rw.RLock()
+	defer g.rw.RUnlock()
+
+	return resourcesToSlice(g.current, g.CloneFunc), nil
 }
 
-// proxyCollector accompanies resourceWatcher when monitoring proxies.
-type proxyCollector struct {
-	ProxyWatcherConfig
-	// current holds a map of the currently known proxies (keyed by server name,
+// CurrentResourcesWithFilter returns a copy of the resources known to the watcher
+// that match the provided filter.
+func (g *GenericWatcher[T, R]) CurrentResourcesWithFilter(ctx context.Context, filter func(R) bool) ([]T, error) {
+	if err := g.refreshStaleResources(ctx); err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	g.rw.RLock()
+	defer g.rw.RUnlock()
+
+	r := func(a any) R {
+		return a.(R)
+	}
+
+	var out []T
+	for _, resource := range g.current {
+		if filter(r(resource)) {
+			out = append(out, g.CloneFunc(resource))
+		}
+	}
+
+	return out, nil
+}
+
+// genericCollector accompanies resourceWatcher when monitoring resources.
+type genericCollector[T any, R any] struct {
+	GenericWatcherConfig[T, R]
+	// current holds a map of the currently known resources (keyed by resource key,
 	// RWMutex protected).
-	current map[string]types.Server
-	rw sync.RWMutex
+	current map[string]T
 	initializationC chan struct{}
-	once sync.Once
+	// cache is a helper for temporarily storing the results of CurrentResources.
+	// It's used to limit the number of calls to the backend.
+	cache *utils.FnCache
+	rw sync.RWMutex
+	once sync.Once
+	// stale is used to indicate that the watcher is stale and needs to be
+	// refreshed.
+	stale atomic.Bool
+}
+
+// resourceKinds specifies the resource kind to watch.
+func (p *genericCollector[T, R]) resourceKinds() []types.WatchKind {
+	return []types.WatchKind{{Kind: p.ResourceKind}}
 }
 
-// GetCurrent returns the currently stored proxies.
-func (p *proxyCollector) GetCurrent() []types.Server {
-	p.rw.RLock()
-	defer p.rw.RUnlock()
-	return serverMapValues(p.current)
+// getResources gets the list of current resources.
+func (g *genericCollector[T, R]) getResources(ctx context.Context) (map[string]T, error) {
+	resources, err := g.GenericWatcherConfig.ResourceGetter(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	current := make(map[string]T, len(resources))
+	for _, resource := range resources {
+		current[g.GenericWatcherConfig.ResourceKey(resource)] = resource
+	}
+	return current, nil
 }
 
-// resourceKinds specifies the resource kind to watch.
-func (p *proxyCollector) resourceKinds() []types.WatchKind {
-	return []types.WatchKind{{Kind: types.KindProxy}}
+func (g *genericCollector[T, R]) refreshStaleResources(ctx context.Context) error {
+	if !g.stale.Load() {
+		return nil
+	}
+
+	_, err := utils.FnCacheGet(ctx, g.cache, g.GenericWatcherConfig.ResourceKind, func(ctx context.Context) (any, error) {
+		current, err := g.getResources(ctx)
+		if err != nil {
+			return nil, trace.Wrap(err)
+		}
+
+		// There is a chance that the watcher was reinitialized while
+		// the resources were being fetched above. Check if we are still stale.
+		if g.stale.CompareAndSwap(true, false) {
+			g.rw.Lock()
+			g.current = current
+			g.rw.Unlock()
+		}
+
+		return nil, nil
+	})
+
+	return trace.Wrap(err)
 }
 
 // getResourcesAndUpdateCurrent is called when the resources should be
 // (re-)fetched directly.
-func (p *proxyCollector) getResourcesAndUpdateCurrent(ctx context.Context) error {
-	proxies, err := p.ProxyGetter.GetProxies()
+func (p *genericCollector[T, R]) getResourcesAndUpdateCurrent(ctx context.Context) error {
+	resources, err := p.ResourceGetter(ctx)
 	if err != nil {
 		return trace.Wrap(err)
 	}
-	newCurrent := make(map[string]types.Server, len(proxies))
-	for _, proxy := range proxies {
-		newCurrent[proxy.GetName()] = proxy
+	newCurrent := make(map[string]T, len(resources))
+	for _, resource := range resources {
+		newCurrent[p.ResourceKey(resource)] = resource
 	}
 	p.rw.Lock()
 	defer p.rw.Unlock()
 	p.current = newCurrent
-	// only emit an empty proxy list if the collector has already been initialized
-	// to prevent an empty slice being sent out on creation of the watcher
-	if len(proxies) > 0 || (len(proxies) == 0 && p.isInitialized()) {
+	p.stale.Store(false)
+	// Only emit an empty set of resources if the watcher is already initialized,
+	// or if explicitly opted into by the watcher.
+	if len(resources) > 0 || p.isInitialized() ||
+		(!p.RequireResourcesForInitialBroadcast && len(resources) == 0) {
 		p.broadcastUpdate(ctx)
 	}
 	p.defineCollectorAsInitialized()
 	return nil
 }
 
-func (p *proxyCollector) defineCollectorAsInitialized() {
+func (p *genericCollector[T, R]) defineCollectorAsInitialized() {
 	p.once.Do(func() {
 		// mark watcher as initialized.
 		close(p.initializationC)
@@ -488,34 +768,35 @@
 }
 
 // processEventsAndUpdateCurrent is called when a watcher event is received.
-func (p *proxyCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { +func (p *genericCollector[T, R]) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { p.rw.Lock() defer p.rw.Unlock() var updated bool for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindProxy { + if event.Resource == nil || event.Resource.GetKind() != p.ResourceKind { p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) continue } switch event.Type { case types.OpDelete: - delete(p.current, event.Resource.GetName()) - // Always broadcast when a proxy is deleted. + // On delete events, the server description is populated with the host ID. + delete(p.current, event.Resource.GetMetadata().Description+event.Resource.GetName()) + // Always broadcast when a resource is deleted. updated = true case types.OpPut: - server, ok := event.Resource.(types.Server) + resource, ok := event.Resource.(T) if !ok { p.Logger.WarnContext(ctx, "Received unexpected type", "resource", event.Resource.GetKind()) continue } - current, exists := p.current[server.GetName()] - p.current[server.GetName()] = server - if !exists || (p.ProxyDiffer != nil && p.ProxyDiffer(current, server)) { - updated = true - } + + key := p.ResourceKey(resource) + current := p.current[key] + p.current[key] = resource + updated = p.ResourceDiffer(current, resource) default: p.Logger.WarnContext(ctx, "Skipping unsupported event type", "event_type", event.Type) } @@ -526,27 +807,31 @@ func (p *proxyCollector) processEventsAndUpdateCurrent(ctx context.Context, even } } -// broadcastUpdate broadcasts information about updating the proxy set. -func (p *proxyCollector) broadcastUpdate(ctx context.Context) { +// broadcastUpdate broadcasts information about updating the resource set. +func (p *genericCollector[T, R]) broadcastUpdate(ctx context.Context) { + if p.DisableUpdateBroadcast { + return + } + names := make([]string, 0, len(p.current)) for k := range p.current { names = append(names, k) } - p.Logger.DebugContext(ctx, "List of known proxies updated", "proxies", names) + p.Logger.DebugContext(ctx, "List of known resources updated", "resources", names) select { - case p.ProxiesC <- serverMapValues(p.current): + case p.ResourcesC <- resourcesToSlice(p.current, p.CloneFunc): case <-ctx.Done(): } } // isInitialized is used to check that the cache has done its initial // sync -func (p *proxyCollector) initializationChan() <-chan struct{} { +func (p *genericCollector[T, R]) initializationChan() <-chan struct{} { return p.initializationC } -func (p *proxyCollector) isInitialized() bool { +func (p *genericCollector[T, R]) isInitialized() bool { select { case <-p.initializationC: return true @@ -555,20 +840,14 @@ func (p *proxyCollector) isInitialized() bool { } } -func (p *proxyCollector) notifyStale() {} - -func serverMapValues(serverMap map[string]types.Server) []types.Server { - servers := make([]types.Server, 0, len(serverMap)) - for _, server := range serverMap { - servers = append(servers, server) - } - return servers +func (p *genericCollector[T, R]) notifyStale() { + p.stale.Store(true) } // LockWatcherConfig is a LockWatcher configuration. type LockWatcherConfig struct { - ResourceWatcherConfig LockGetter + ResourceWatcherConfig } // CheckAndSetDefaults checks parameters and sets default values. @@ -622,15 +901,15 @@ type lockCollector struct { LockWatcherConfig // current holds a map of the currently known locks (keyed by lock name). 
current map[string]types.Lock - // isStale indicates whether the local lock view (current) is stale. - isStale bool - // currentRW is a mutex protecting both current and isStale. - currentRW sync.RWMutex // fanout provides support for multiple subscribers to the lock updates. fanout *FanoutV2 // initializationC is used to check whether the initial sync has completed initializationC chan struct{} - once sync.Once + // currentRW is a mutex protecting both current and isStale. + currentRW sync.RWMutex + once sync.Once + // isStale indicates whether the local lock view (current) is stale. + isStale bool } // IsStale is used to check whether the lock watcher is stale. @@ -817,756 +1096,89 @@ func lockMapValues(lockMap map[string]types.Lock) []types.Lock { return locks } -// DatabaseWatcherConfig is a DatabaseWatcher configuration. -type DatabaseWatcherConfig struct { +func resourcesToSlice[T any](resources map[string]T, cloneFunc func(T) T) (slice []T) { + for _, resource := range resources { + slice = append(slice, cloneFunc(resource)) + } + return slice +} + +// CertAuthorityWatcherConfig is a CertAuthorityWatcher configuration. +type CertAuthorityWatcherConfig struct { // ResourceWatcherConfig is the resource watcher configuration. ResourceWatcherConfig - // DatabaseGetter is responsible for fetching database resources. - DatabaseGetter - // DatabasesC receives up-to-date list of all database resources. - DatabasesC chan types.Databases + // AuthorityGetter is responsible for fetching cert authority resources. + AuthorityGetter + // Types restricts which cert authority types are retrieved via the AuthorityGetter. + Types []types.CertAuthType } // CheckAndSetDefaults checks parameters and sets default values. -func (cfg *DatabaseWatcherConfig) CheckAndSetDefaults() error { +func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } - if cfg.DatabaseGetter == nil { - getter, ok := cfg.Client.(DatabaseGetter) + if cfg.AuthorityGetter == nil { + getter, ok := cfg.Client.(AuthorityGetter) if !ok { - return trace.BadParameter("missing parameter DatabaseGetter and Client not usable as DatabaseGetter") + return trace.BadParameter("missing parameter AuthorityGetter and Client not usable as AuthorityGetter") } - cfg.DatabaseGetter = getter + cfg.AuthorityGetter = getter } - if cfg.DatabasesC == nil { - cfg.DatabasesC = make(chan types.Databases) + if len(cfg.Types) == 0 { + return trace.BadParameter("missing parameter Types") } return nil } -// NewDatabaseWatcher returns a new instance of DatabaseWatcher. -func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*DatabaseWatcher, error) { +// NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher. 
+func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } - collector := &databaseCollector{ - DatabaseWatcherConfig: cfg, - initializationC: make(chan struct{}), + + collector := &caCollector{ + CertAuthorityWatcherConfig: cfg, + fanout: NewFanoutV2(FanoutV2Config{ + Capacity: smallFanoutCapacity, + }), + cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)), + filter: make(types.CertAuthorityFilter, len(cfg.Types)), + initializationC: make(chan struct{}), + } + + for _, t := range cfg.Types { + collector.cas[t] = make(map[string]types.CertAuthority) + collector.filter[t] = types.Wildcard } + // Resource watcher require the fanout to be initialized before passing in. + // Otherwise, Emit() may fail due to a race condition mentioned in https://github.com/gravitational/teleport/issues/19289 + collector.fanout.SetInit(collector.resourceKinds()) watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) if err != nil { return nil, trace.Wrap(err) } - return &DatabaseWatcher{watcher, collector}, nil -} -// DatabaseWatcher is built on top of resourceWatcher to monitor database resources. -type DatabaseWatcher struct { - *resourceWatcher - *databaseCollector + return &CertAuthorityWatcher{watcher, collector}, nil } -// databaseCollector accompanies resourceWatcher when monitoring database resources. -type databaseCollector struct { - // DatabaseWatcherConfig is the watcher configuration. - DatabaseWatcherConfig - // current holds a map of the currently known database resources. - current map[string]types.Database - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check that the - initializationC chan struct{} - once sync.Once -} - -// resourceKinds specifies the resource kind to watch. -func (p *databaseCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindDatabase}} -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (p *databaseCollector) initializationChan() <-chan struct{} { - return p.initializationC -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. -func (p *databaseCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - databases, err := p.DatabaseGetter.GetDatabases(ctx) - if err != nil { - return trace.Wrap(err) - } - newCurrent := make(map[string]types.Database, len(databases)) - for _, database := range databases { - newCurrent[database.GetName()] = database - } - p.lock.Lock() - defer p.lock.Unlock() - p.current = newCurrent - p.defineCollectorAsInitialized() - - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case p.DatabasesC <- databases: - } - - return nil -} - -func (p *databaseCollector) defineCollectorAsInitialized() { - p.once.Do(func() { - // mark watcher as initialized. - close(p.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
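The per-resource collectors removed from here down all repeat the same loop: a map of the current resources keyed by name, folded over put and delete events, with a broadcast when anything changed. The generic collector earlier in this diff captures that fold once; the standalone sketch below shows its shape, where the key carried on each event plays the role of ResourceKey and differ the role of ResourceDiffer (neither is the actual configuration field).

// Standalone sketch of the keyed event fold the generic collector performs.
package sketch

type opKind int

const (
	opPut opKind = iota
	opDelete
)

type event[T any] struct {
	op  opKind
	key string
	val T
}

// fold applies a batch of events to the current map and reports whether
// anything changed enough to warrant a broadcast.
func fold[T any](current map[string]T, events []event[T], differ func(old, fresh T) bool) bool {
	updated := false
	for _, e := range events {
		switch e.op {
		case opDelete:
			delete(current, e.key)
			// Deletions always count as an update.
			updated = true
		case opPut:
			old, exists := current[e.key]
			current[e.key] = e.val
			if !exists || differ(old, e.val) {
				updated = true
			}
		}
	}
	return updated
}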
-func (p *databaseCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - p.lock.Lock() - defer p.lock.Unlock() - - var updated bool - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindDatabase { - p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(p.current, event.Resource.GetName()) - updated = true - case types.OpPut: - database, ok := event.Resource.(types.Database) - if !ok { - p.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - p.current[database.GetName()] = database - updated = true - default: - p.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } - - if updated { - select { - case <-ctx.Done(): - case p.DatabasesC <- resourcesToSlice(p.current): - } - } -} - -func (*databaseCollector) notifyStale() {} - -// AppWatcherConfig is an AppWatcher configuration. -type AppWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // AppGetter is responsible for fetching application resources. - AppGetter - // AppsC receives up-to-date list of all application resources. - AppsC chan types.Apps -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *AppWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.AppGetter == nil { - getter, ok := cfg.Client.(AppGetter) - if !ok { - return trace.BadParameter("missing parameter AppGetter and Client not usable as AppGetter") - } - cfg.AppGetter = getter - } - if cfg.AppsC == nil { - cfg.AppsC = make(chan types.Apps) - } - return nil -} - -// NewAppWatcher returns a new instance of AppWatcher. -func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*AppWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - collector := &appCollector{ - AppWatcherConfig: cfg, - initializationC: make(chan struct{}), - } - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &AppWatcher{watcher, collector}, nil -} - -// AppWatcher is built on top of resourceWatcher to monitor application resources. -type AppWatcher struct { - *resourceWatcher - *appCollector -} - -// appCollector accompanies resourceWatcher when monitoring application resources. -type appCollector struct { - // AppWatcherConfig is the watcher configuration. - AppWatcherConfig - // current holds a map of the currently known application resources. - current map[string]types.Application - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once -} - -// resourceKinds specifies the resource kind to watch. -func (p *appCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindApp}} -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (p *appCollector) initializationChan() <-chan struct{} { - return p.initializationC -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. 
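The other half of the shared pattern shows up in the removed getResourcesAndUpdateCurrent implementations that follow: fetch everything, rebuild the map wholesale, and close the initialization channel exactly once. A compressed sketch of that half, with fetch and keyOf as placeholders rather than real collector fields:

package sketch

import (
	"context"
	"sync"
)

type initialSyncer[T any] struct {
	mu      sync.Mutex
	current map[string]T
	initC   chan struct{}
	once    sync.Once
}

func (s *initialSyncer[T]) resync(ctx context.Context, fetch func(context.Context) ([]T, error), keyOf func(T) string) error {
	items, err := fetch(ctx)
	if err != nil {
		return err
	}
	next := make(map[string]T, len(items))
	for _, item := range items {
		next[keyOf(item)] = item
	}
	s.mu.Lock()
	s.current = next
	s.mu.Unlock()
	// Mark the watcher as initialized on the first successful sync only.
	s.once.Do(func() { close(s.initC) })
	return nil
}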
-func (p *appCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - apps, err := p.AppGetter.GetApps(ctx) - if err != nil { - return trace.Wrap(err) - } - newCurrent := make(map[string]types.Application, len(apps)) - for _, app := range apps { - newCurrent[app.GetName()] = app - } - p.lock.Lock() - defer p.lock.Unlock() - p.current = newCurrent - p.defineCollectorAsInitialized() - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case p.AppsC <- apps: - } - return nil -} - -func (p *appCollector) defineCollectorAsInitialized() { - p.once.Do(func() { - // mark watcher as initialized. - close(p.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. -func (p *appCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - p.lock.Lock() - defer p.lock.Unlock() - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindApp { - p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(p.current, event.Resource.GetName()) - p.AppsC <- resourcesToSlice(p.current) - - select { - case <-ctx.Done(): - case p.AppsC <- resourcesToSlice(p.current): - } - - case types.OpPut: - app, ok := event.Resource.(types.Application) - if !ok { - p.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - p.current[app.GetName()] = app - - select { - case <-ctx.Done(): - case p.AppsC <- resourcesToSlice(p.current): - } - default: - p.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } -} - -func (*appCollector) notifyStale() {} - -func resourcesToSlice[T any](resources map[string]T) (slice []T) { - for _, resource := range resources { - slice = append(slice, resource) - } - return slice -} - -// KubeClusterWatcherConfig is an KubeClusterWatcher configuration. -type KubeClusterWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // KubernetesGetter is responsible for fetching kube_cluster resources. - KubernetesClusterGetter - // KubeClustersC receives up-to-date list of all kube_cluster resources. - KubeClustersC chan types.KubeClusters -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *KubeClusterWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.KubernetesClusterGetter == nil { - getter, ok := cfg.Client.(KubernetesClusterGetter) - if !ok { - return trace.BadParameter("missing parameter KubernetesGetter and Client not usable as KubernetesGetter") - } - cfg.KubernetesClusterGetter = getter - } - if cfg.KubeClustersC == nil { - cfg.KubeClustersC = make(chan types.KubeClusters) - } - return nil -} - -// NewKubeClusterWatcher returns a new instance of KubeClusterWatcher. 
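Note the difference from the old resourcesToSlice removed just above: the replacement introduced earlier in this diff takes a clone function, so every snapshot handed to a subscriber is a copy rather than a reference into the collector's map. Callers pass the resource type's own copy method, as the access request and Okta assignment collectors now do; for example (current stands for the collector's map):

// Subscribers receive copies they may mutate freely; the collector's
// internal map stays untouched.
accessRequests := resourcesToSlice(current, types.AccessRequest.Copy)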
-func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*KubeClusterWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - collector := &kubeCollector{ - KubeClusterWatcherConfig: cfg, - initializationC: make(chan struct{}), - } - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &KubeClusterWatcher{watcher, collector}, nil -} - -// KubeClusterWatcher is built on top of resourceWatcher to monitor kube_cluster resources. -type KubeClusterWatcher struct { - *resourceWatcher - *kubeCollector -} - -// kubeCollector accompanies resourceWatcher when monitoring kube_cluster resources. -type kubeCollector struct { - // KubeClusterWatcherConfig is the watcher configuration. - KubeClusterWatcherConfig - // current holds a map of the currently known kube_cluster resources. - current map[string]types.KubeCluster - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (k *kubeCollector) initializationChan() <-chan struct{} { - return k.initializationC -} - -// resourceKinds specifies the resource kind to watch. -func (k *kubeCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindKubernetesCluster}} -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. -func (k *kubeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - clusters, err := k.KubernetesClusterGetter.GetKubernetesClusters(ctx) - if err != nil { - return trace.Wrap(err) - } - newCurrent := make(map[string]types.KubeCluster, len(clusters)) - for _, cluster := range clusters { - newCurrent[cluster.GetName()] = cluster - } - k.lock.Lock() - defer k.lock.Unlock() - k.current = newCurrent - - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case k.KubeClustersC <- clusters: - } - - k.defineCollectorAsInitialized() - - return nil -} - -func (k *kubeCollector) defineCollectorAsInitialized() { - k.once.Do(func() { - // mark watcher as initialized. - close(k.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
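Each of the watchers removed here pushed snapshots on its own channel (DatabasesC, AppsC, KubeClustersC); consumers now read the single ResourcesC exposed by the generic watcher, as the database and app agent changes later in this diff show. A sketch of the consuming side, assuming NewDatabaseWatcher keeps its name while returning the generic watcher, with accessPoint and reconcile as placeholders:

// accessPoint stands for the agent's cache client and reconcile for its
// reconciliation hook; both are placeholders.
watcher, err := services.NewDatabaseWatcher(ctx, services.DatabaseWatcherConfig{
	ResourceWatcherConfig: services.ResourceWatcherConfig{
		Component: "db:watcher", // illustrative component name
		Client:    accessPoint,
	},
	DatabaseGetter: accessPoint,
})
if err != nil {
	return trace.Wrap(err)
}
defer watcher.Close()

for {
	select {
	case databases := <-watcher.ResourcesC:
		// A cloned []types.Database snapshot of the monitored set.
		reconcile(databases)
	case <-ctx.Done():
		return ctx.Err()
	}
}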
-func (k *kubeCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - k.lock.Lock() - defer k.lock.Unlock() - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindKubernetesCluster { - k.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(k.current, event.Resource.GetName()) - k.KubeClustersC <- resourcesToSlice(k.current) - - select { - case <-ctx.Done(): - case k.KubeClustersC <- resourcesToSlice(k.current): - } - - case types.OpPut: - cluster, ok := event.Resource.(types.KubeCluster) - if !ok { - k.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - k.current[cluster.GetName()] = cluster - - select { - case <-ctx.Done(): - case k.KubeClustersC <- resourcesToSlice(k.current): - } - default: - k.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } -} - -func (*kubeCollector) notifyStale() {} - -// KubeServerWatcherConfig is an KubeServerWatcher configuration. -type KubeServerWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // KubernetesServerGetter is responsible for fetching kube_server resources. - KubernetesServerGetter -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *KubeServerWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.KubernetesServerGetter == nil { - getter, ok := cfg.Client.(KubernetesServerGetter) - if !ok { - return trace.BadParameter("missing parameter KubernetesServerGetter and Client not usable as KubernetesServerGetter") - } - cfg.KubernetesServerGetter = getter - } - return nil -} - -// NewKubeServerWatcher returns a new instance of KubeServerWatcher. -func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*KubeServerWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - cache, err := utils.NewFnCache(utils.FnCacheConfig{ - Context: ctx, - TTL: 3 * time.Second, - Clock: cfg.Clock, - }) - if err != nil { - return nil, trace.Wrap(err) - } - collector := &kubeServerCollector{ - KubeServerWatcherConfig: cfg, - initializationC: make(chan struct{}), - cache: cache, - } - // start the collector as staled. - collector.stale.Store(true) - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &KubeServerWatcher{watcher, collector}, nil -} - -// KubeServerWatcher is built on top of resourceWatcher to monitor kube_server resources. -type KubeServerWatcher struct { - *resourceWatcher - *kubeServerCollector -} - -// GetKubeServersByClusterName returns a list of kubernetes servers for the specified cluster. 
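The lookup removed below has no direct replacement; callers express it with CurrentResourcesWithFilter instead, matching on the kube server name, which carries the kube cluster name, exactly as the updated tests do. A sketch of the equivalent call, where watcher and clusterName are placeholders; note that an empty result is no longer an error, so the caller decides whether that is a NotFound condition:

// Filter-based equivalent of the removed GetKubeServersByClusterName.
servers, err := watcher.CurrentResourcesWithFilter(ctx, func(ks types.ReadOnlyKubeServer) bool {
	return ks.GetName() == clusterName
})
if err != nil {
	return nil, trace.Wrap(err)
}
if len(servers) == 0 {
	return nil, trace.NotFound("no kubernetes servers found for cluster %q", clusterName)
}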
-func (k *KubeServerWatcher) GetKubeServersByClusterName(ctx context.Context, clusterName string) ([]types.KubeServer, error) { - k.refreshStaleKubeServers(ctx) - - k.lock.RLock() - defer k.lock.RUnlock() - var servers []types.KubeServer - for _, server := range k.current { - if server.GetCluster().GetName() == clusterName { - servers = append(servers, server.Copy()) - } - } - if len(servers) == 0 { - return nil, trace.NotFound("no kubernetes servers found for cluster %q", clusterName) - } - - return servers, nil -} - -// GetKubernetesServers returns a list of kubernetes servers for all clusters. -func (k *KubeServerWatcher) GetKubernetesServers(ctx context.Context) ([]types.KubeServer, error) { - k.refreshStaleKubeServers(ctx) - - k.lock.RLock() - defer k.lock.RUnlock() - servers := make([]types.KubeServer, 0, len(k.current)) - for _, server := range k.current { - servers = append(servers, server.Copy()) - } - return servers, nil -} - -// kubeServerCollector accompanies resourceWatcher when monitoring kube_server resources. -type kubeServerCollector struct { - // KubeServerWatcherConfig is the watcher configuration. - KubeServerWatcherConfig - // current holds a map of the currently known kube_server resources. - current map[kubeServersKey]types.KubeServer - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once - // stale is used to indicate that the watcher is stale and needs to be - // refreshed. - stale atomic.Bool - // cache is a helper for temporarily storing the results of GetKubernetesServers. - // It's used to limit the amount of calls to the backend. - cache *utils.FnCache -} - -// kubeServersKey is used to uniquely identify a kube_server resource. -type kubeServersKey struct { - hostID string - resourceName string -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (k *kubeServerCollector) initializationChan() <-chan struct{} { - return k.initializationC -} - -// resourceKinds specifies the resource kind to watch. -func (k *kubeServerCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindKubeServer}} -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. -func (k *kubeServerCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - newCurrent, err := k.getResources(ctx) - if err != nil { - return trace.Wrap(err) - } - - k.lock.Lock() - k.current = newCurrent - k.lock.Unlock() - - k.stale.Store(false) - - k.defineCollectorAsInitialized() - return nil -} - -// getResourcesAndUpdateCurrent gets the list of current resources. -func (k *kubeServerCollector) getResources(ctx context.Context) (map[kubeServersKey]types.KubeServer, error) { - servers, err := k.KubernetesServerGetter.GetKubernetesServers(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - current := make(map[kubeServersKey]types.KubeServer, len(servers)) - for _, server := range servers { - key := kubeServersKey{ - hostID: server.GetHostID(), - resourceName: server.GetName(), - } - current[key] = server - } - return current, nil -} - -func (k *kubeServerCollector) defineCollectorAsInitialized() { - k.once.Do(func() { - // mark watcher as initialized. - close(k.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
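Kube servers are not unique by name alone, since several hosts can serve the same cluster, which is why the removed kubeServersKey above combined host ID and resource name. The generic collector gets the same behavior from its ResourceKey hook; a plausible key function is sketched below, and the delete path shown earlier keys on Description+Name because delete events carry only a resource header whose description holds the host ID.

// A plausible ResourceKey for kube servers, concatenating host ID and name
// the way the removed kubeServersKey struct did; not necessarily the exact
// function the kube server watcher uses.
func kubeServerKey(ks types.KubeServer) string {
	return ks.GetHostID() + ks.GetName()
}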
-func (k *kubeServerCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - k.lock.Lock() - defer k.lock.Unlock() - - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindKubeServer { - k.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - - switch event.Type { - case types.OpDelete: - key := kubeServersKey{ - // On delete events, the server description is populated with the host ID. - hostID: event.Resource.GetMetadata().Description, - resourceName: event.Resource.GetName(), - } - delete(k.current, key) - case types.OpPut: - server, ok := event.Resource.(types.KubeServer) - if !ok { - k.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - - key := kubeServersKey{ - hostID: server.GetHostID(), - resourceName: server.GetName(), - } - k.current[key] = server - default: - k.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } -} - -func (k *kubeServerCollector) notifyStale() { - k.stale.Store(true) -} - -// refreshStaleKubeServers attempts to reload kube servers from the cache if -// the collector is stale. This ensures that no matter the health of -// the collector callers will be returned the most up to date node -// set as possible. -func (k *kubeServerCollector) refreshStaleKubeServers(ctx context.Context) error { - if !k.stale.Load() { - return nil - } - - _, err := utils.FnCacheGet(ctx, k.cache, "kube_servers", func(ctx context.Context) (any, error) { - current, err := k.getResources(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - // There is a chance that the watcher reinitialized while - // getting kube servers happened above. Check if we are still stale - if k.stale.CompareAndSwap(true, false) { - k.lock.Lock() - k.current = current - k.lock.Unlock() - } - - return nil, nil - }) - - return trace.Wrap(err) -} - -// CertAuthorityWatcherConfig is a CertAuthorityWatcher configuration. -type CertAuthorityWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // AuthorityGetter is responsible for fetching cert authority resources. - AuthorityGetter - // Types restricts which cert authority types are retrieved via the AuthorityGetter. - Types []types.CertAuthType -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.AuthorityGetter == nil { - getter, ok := cfg.Client.(AuthorityGetter) - if !ok { - return trace.BadParameter("missing parameter AuthorityGetter and Client not usable as AuthorityGetter") - } - cfg.AuthorityGetter = getter - } - if len(cfg.Types) == 0 { - return trace.BadParameter("missing parameter Types") - } - return nil -} - -// NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher. 
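refreshStaleKubeServers above and refreshStaleNodes further down share the fallback that the generic collector keeps now that notifyStale records staleness instead of being a no-op: when the event stream drops, readers fetch through a short-TTL FnCache so a burst of concurrent callers costs a single backend read. A sketch of that shape, with fetch and the struct fields as placeholders:

package sketch

import (
	"context"
	"sync"
	"sync/atomic"

	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/lib/utils"
)

// staleRefresher sketches the fallback: fetch stands in for the watcher's
// resource getter, and the FnCache bounds how often stale reads hit the backend.
type staleRefresher[T any] struct {
	stale   atomic.Bool
	cache   *utils.FnCache
	mu      sync.Mutex
	current map[string]T
	fetch   func(context.Context) (map[string]T, error)
}

func (r *staleRefresher[T]) refreshIfStale(ctx context.Context) error {
	if !r.stale.Load() {
		return nil
	}
	_, err := utils.FnCacheGet(ctx, r.cache, "resources", func(ctx context.Context) (any, error) {
		fresh, err := r.fetch(ctx)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		// The watcher may have re-initialized while the fetch was in flight;
		// only install the result if this call is the one clearing staleness.
		if r.stale.CompareAndSwap(true, false) {
			r.mu.Lock()
			r.current = fresh
			r.mu.Unlock()
		}
		return nil, nil
	})
	return trace.Wrap(err)
}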
-func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - collector := &caCollector{ - CertAuthorityWatcherConfig: cfg, - fanout: NewFanoutV2(FanoutV2Config{ - Capacity: smallFanoutCapacity, - }), - cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)), - filter: make(types.CertAuthorityFilter, len(cfg.Types)), - initializationC: make(chan struct{}), - } - - for _, t := range cfg.Types { - collector.cas[t] = make(map[string]types.CertAuthority) - collector.filter[t] = types.Wildcard - } - // Resource watcher require the fanout to be initialized before passing in. - // Otherwise, Emit() may fail due to a race condition mentioned in https://github.com/gravitational/teleport/issues/19289 - collector.fanout.SetInit(collector.resourceKinds()) - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - return &CertAuthorityWatcher{watcher, collector}, nil -} - -// CertAuthorityWatcher is built on top of resourceWatcher to monitor cert authority resources. -type CertAuthorityWatcher struct { - *resourceWatcher - *caCollector +// CertAuthorityWatcher is built on top of resourceWatcher to monitor cert authority resources. +type CertAuthorityWatcher struct { + *resourceWatcher + *caCollector } // caCollector accompanies resourceWatcher when monitoring cert authority resources. type caCollector struct { - CertAuthorityWatcherConfig fanout *FanoutV2 - - // lock protects concurrent access to cas - lock sync.RWMutex - // cas maps ca type -> cluster -> ca - cas map[types.CertAuthType]map[string]types.CertAuthority + cas map[types.CertAuthType]map[string]types.CertAuthority // initializationC is used to check whether the initial sync has completed initializationC chan struct{} - once sync.Once filter types.CertAuthorityFilter + CertAuthorityWatcherConfig + // lock protects concurrent access to cas + lock sync.RWMutex + once sync.Once } // Subscribe is used to subscribe to the lock updates. @@ -1703,285 +1315,40 @@ func (c *caCollector) notifyStale() {} // NodeWatcherConfig is a NodeWatcher configuration. type NodeWatcherConfig struct { - ResourceWatcherConfig // NodesGetter is used to directly fetch the list of active nodes. NodesGetter -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *NodeWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.NodesGetter == nil { - getter, ok := cfg.Client.(NodesGetter) - if !ok { - return trace.BadParameter("missing parameter NodesGetter and Client not usable as NodesGetter") - } - cfg.NodesGetter = getter - } - return nil + ResourceWatcherConfig } // NewNodeWatcher returns a new instance of NodeWatcher. 
-func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*NodeWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - cache, err := utils.NewFnCache(utils.FnCacheConfig{ - Context: ctx, - TTL: 3 * time.Second, - Clock: cfg.Clock, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - collector := &nodeCollector{ - NodeWatcherConfig: cfg, - current: map[string]types.Server{}, - initializationC: make(chan struct{}), - cache: cache, - stale: true, - } - - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - return &NodeWatcher{resourceWatcher: watcher, nodeCollector: collector}, nil -} - -// NodeWatcher is built on top of resourceWatcher to monitor additions -// and deletions to the set of nodes. -type NodeWatcher struct { - *resourceWatcher - *nodeCollector -} - -// nodeCollector accompanies resourceWatcher when monitoring nodes. -type nodeCollector struct { - NodeWatcherConfig - - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once - - cache *utils.FnCache - - rw sync.RWMutex - // current holds a map of the currently known nodes keyed by server name - current map[string]types.Server - stale bool -} - -// Node is a readonly subset of the types.Server interface which -// users may filter by in GetNodes. -type Node interface { - // ResourceWithLabels provides common resource headers - types.ResourceWithLabels - // GetTeleportVersion returns the teleport version the server is running on - GetTeleportVersion() string - // GetAddr return server address - GetAddr() string - // GetPublicAddrs returns all public addresses where this server can be reached. - GetPublicAddrs() []string - // GetHostname returns server hostname - GetHostname() string - // GetNamespace returns server namespace - GetNamespace() string - // GetCmdLabels gets command labels - GetCmdLabels() map[string]types.CommandLabel - // GetRotation gets the state of certificate authority rotation. - GetRotation() types.Rotation - // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. - GetUseTunnel() bool - // GetProxyIDs returns a list of proxy ids this server is connected to. - GetProxyIDs() []string - // IsEICE returns whether the Node is an EICE instance. - // Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels). - IsEICE() bool -} - -// GetNodes allows callers to retrieve a subset of nodes that match the filter provided. The -// returned servers are a copy and can be safely modified. It is intentionally hard to retrieve -// the full set of nodes to reduce the number of copies needed since the number of nodes can get -// quite large and doing so can be expensive. -func (n *nodeCollector) GetNodes(ctx context.Context, fn func(n Node) bool) []types.Server { - // Attempt to freshen our data first. - n.refreshStaleNodes(ctx) - - n.rw.RLock() - defer n.rw.RUnlock() - - var matched []types.Server - for _, server := range n.current { - if fn(server) { - matched = append(matched, server.DeepCopy()) - } - } - - return matched -} - -// GetNode allows callers to retrieve a node based on its name. The -// returned server are a copy and can be safely modified. -func (n *nodeCollector) GetNode(ctx context.Context, name string) (types.Server, error) { - // Attempt to freshen our data first. 
- n.refreshStaleNodes(ctx) - - n.rw.RLock() - defer n.rw.RUnlock() - - server, found := n.current[name] - if !found { - return nil, trace.NotFound("server does not exist") - } - return server.DeepCopy(), nil -} - -// refreshStaleNodes attempts to reload nodes from the NodeGetter if -// the collecter is stale. This ensures that no matter the health of -// the collecter callers will be returned the most up to date node -// set as possible. -func (n *nodeCollector) refreshStaleNodes(ctx context.Context) error { - n.rw.RLock() - if !n.stale { - n.rw.RUnlock() - return nil - } - n.rw.RUnlock() - - _, err := utils.FnCacheGet(ctx, n.cache, "nodes", func(ctx context.Context) (any, error) { - current, err := n.getNodes(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - n.rw.Lock() - defer n.rw.Unlock() - - // There is a chance that the watcher reinitialized while - // getting nodes happened above. Check if we are still stale - // now that the lock is held to ensure that the refresh is - // still necessary. - if !n.stale { - return nil, nil - } - - n.current = current - return nil, trace.Wrap(err) - }) - - return trace.Wrap(err) -} - -func (n *nodeCollector) NodeCount() int { - n.rw.RLock() - defer n.rw.RUnlock() - return len(n.current) -} - -// resourceKinds specifies the resource kind to watch. -func (n *nodeCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindNode}} -} - -// getResourcesAndUpdateCurrent is called when the resources should be -// (re-)fetched directly. -func (n *nodeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - newCurrent, err := n.getNodes(ctx) - if err != nil { - return trace.Wrap(err) - } - defer n.defineCollectorAsInitialized() - - if len(newCurrent) == 0 { - return nil - } - - n.rw.Lock() - defer n.rw.Unlock() - n.current = newCurrent - n.stale = false - return nil -} - -func (n *nodeCollector) getNodes(ctx context.Context) (map[string]types.Server, error) { - nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) - if err != nil { - return nil, trace.Wrap(err) - } - - if len(nodes) == 0 { - return map[string]types.Server{}, nil - } - - current := make(map[string]types.Server, len(nodes)) - for _, node := range nodes { - current[node.GetName()] = node +func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*GenericWatcher[types.Server, types.ReadOnlyServer], error) { + if cfg.NodesGetter == nil { + return nil, trace.BadParameter("NodesGetter must be provided") } - return current, nil -} - -func (n *nodeCollector) defineCollectorAsInitialized() { - n.once.Do(func() { - // mark watcher as initialized. - close(n.initializationC) + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, types.ReadOnlyServer]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindNode, + ResourceGetter: func(ctx context.Context) ([]types.Server, error) { + return cfg.NodesGetter.GetNodes(ctx, apidefaults.Namespace) + }, + ResourceKey: types.Server.GetName, + DisableUpdateBroadcast: true, + CloneFunc: types.Server.DeepCopy, }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
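The replacement NewNodeWatcher above sets DisableUpdateBroadcast, so nothing is pushed on ResourcesC for nodes; callers pull the current set on demand with CurrentResources, CurrentResourcesWithFilter, and ResourceCount, as the updated tests and discovery code do. A sketch of typical reads against the returned watcher:

// Pull-style reads against the node watcher; both calls may hit the backend
// through the watcher's fallback path when the event stream is stale, which
// is why they now return an error. Returned servers are deep copies.
func nodeSnapshot(ctx context.Context, w *services.GenericWatcher[types.Server, types.ReadOnlyServer]) (all, tunneled []types.Server, err error) {
	all, err = w.CurrentResources(ctx)
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}
	tunneled, err = w.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool {
		return n.GetUseTunnel()
	})
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}
	return all, tunneled, nil
}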
-func (n *nodeCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - n.rw.Lock() - defer n.rw.Unlock() - - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindNode { - n.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - - switch event.Type { - case types.OpDelete: - delete(n.current, event.Resource.GetName()) - case types.OpPut: - server, ok := event.Resource.(types.Server) - if !ok { - n.Logger.WarnContext(ctx, "Received unexpected type", "resource", event.Resource.GetKind()) - continue - } - - n.current[server.GetName()] = server - default: - n.Logger.WarnContext(ctx, "Skipping unsupported event type", "event_type", event.Type) - } - } -} - -func (n *nodeCollector) initializationChan() <-chan struct{} { - return n.initializationC -} - -func (n *nodeCollector) notifyStale() { - n.rw.Lock() - defer n.rw.Unlock() - n.stale = true + return w, trace.Wrap(err) } // AccessRequestWatcherConfig is a AccessRequestWatcher configuration. type AccessRequestWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig // AccessRequestGetter is responsible for fetching access request resources. AccessRequestGetter - // Filter is the filter to use to monitor access requests. - Filter types.AccessRequestFilter // AccessRequestsC receives up-to-date list of all access request resources. AccessRequestsC chan types.AccessRequests + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig + // Filter is the filter to use to monitor access requests. + Filter types.AccessRequestFilter } // CheckAndSetDefaults checks parameters and sets default values. @@ -2030,11 +1397,11 @@ type accessRequestCollector struct { AccessRequestWatcherConfig // current holds a map of the currently known access request resources. current map[string]types.AccessRequest - // lock protects the "current" map. - lock sync.RWMutex // initializationC is used to check that the watcher has been initialized properly. initializationC chan struct{} - once sync.Once + // lock protects the "current" map. + lock sync.RWMutex + once sync.Once } // resourceKinds specifies the resource kind to watch. @@ -2094,7 +1461,7 @@ func (p *accessRequestCollector) processEventsAndUpdateCurrent(ctx context.Conte delete(p.current, event.Resource.GetName()) select { case <-ctx.Done(): - case p.AccessRequestsC <- resourcesToSlice(p.current): + case p.AccessRequestsC <- resourcesToSlice(p.current, types.AccessRequest.Copy): } case types.OpPut: accessRequest, ok := event.Resource.(types.AccessRequest) @@ -2105,7 +1472,7 @@ func (p *accessRequestCollector) processEventsAndUpdateCurrent(ctx context.Conte p.current[accessRequest.GetName()] = accessRequest select { case <-ctx.Done(): - case p.AccessRequestsC <- resourcesToSlice(p.current): + case p.AccessRequestsC <- resourcesToSlice(p.current, types.AccessRequest.Copy): } default: @@ -2118,14 +1485,14 @@ func (*accessRequestCollector) notifyStale() {} // OktaAssignmentWatcherConfig is a OktaAssignmentWatcher configuration. type OktaAssignmentWatcherConfig struct { - // RWCfg is the resource watcher configuration. - RWCfg ResourceWatcherConfig // OktaAssignments is responsible for fetching Okta assignments. OktaAssignments OktaAssignmentsGetter - // PageSize is the number of Okta assignments to list at a time. - PageSize int // OktaAssignmentsC receives up-to-date list of all Okta assignment resources. 
OktaAssignmentsC chan types.OktaAssignments + // RWCfg is the resource watcher configuration. + RWCfg ResourceWatcherConfig + // PageSize is the number of Okta assignments to list at a time. + PageSize int } // CheckAndSetDefaults checks parameters and sets default values. @@ -2190,16 +1557,16 @@ func (o *OktaAssignmentWatcher) Done() <-chan struct{} { // oktaAssignmentCollector accompanies resourceWatcher when monitoring Okta assignment resources. type oktaAssignmentCollector struct { - logger *slog.Logger // OktaAssignmentWatcherConfig is the watcher configuration. - cfg OktaAssignmentWatcherConfig - // mu guards "current" - mu sync.RWMutex + cfg OktaAssignmentWatcherConfig + logger *slog.Logger // current holds a map of the currently known Okta assignment resources. current map[string]types.OktaAssignment // initializationC is used to check that the watcher has been initialized properly. initializationC chan struct{} - once sync.Once + // mu guards "current" + mu sync.RWMutex + once sync.Once } // resourceKinds specifies the resource kind to watch. @@ -2267,7 +1634,7 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont switch event.Type { case types.OpDelete: delete(c.current, event.Resource.GetName()) - resources := resourcesToSlice(c.current) + resources := resourcesToSlice(c.current, types.OktaAssignment.Copy) select { case <-ctx.Done(): case c.cfg.OktaAssignmentsC <- resources: @@ -2279,7 +1646,7 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont continue } c.current[oktaAssignment.GetName()] = oktaAssignment - resources := resourcesToSlice(c.current) + resources := resourcesToSlice(c.current, types.OktaAssignment.Copy) select { case <-ctx.Done(): diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 3ffe202bb708..64c3dc38f536 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -131,7 +131,8 @@ func TestProxyWatcher(t *testing.T) { Events: local.NewEventsService(bk), }, }, - ProxiesC: make(chan []types.Server, 10), + ProxyGetter: presence, + ProxiesC: make(chan []types.Server, 10), }) require.NoError(t, err) t.Cleanup(w.Close) @@ -143,7 +144,7 @@ func TestProxyWatcher(t *testing.T) { // The first event is always the current list of proxies. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], proxy)) case <-w.Done(): @@ -158,7 +159,7 @@ func TestProxyWatcher(t *testing.T) { // Watcher should detect the proxy list change. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 2) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -171,7 +172,7 @@ func TestProxyWatcher(t *testing.T) { // Watcher should detect the proxy list change. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], proxy2)) case <-w.Done(): @@ -185,7 +186,7 @@ func TestProxyWatcher(t *testing.T) { // Watcher should detect the proxy list change. 
select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Empty(t, changeset) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -562,14 +563,15 @@ func TestDatabaseWatcher(t *testing.T) { Events: local.NewEventsService(bk), }, }, - DatabasesC: make(chan types.Databases, 10), + DatabaseGetter: databasesService, + DatabasesC: make(chan []types.Database, 10), }) require.NoError(t, err) t.Cleanup(w.Close) // Initially there are no databases so watcher should send an empty list. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Empty(t, changeset) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -583,7 +585,7 @@ func TestDatabaseWatcher(t *testing.T) { // The first event is always the current list of databases. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], database1)) case <-w.Done(): @@ -598,7 +600,7 @@ func TestDatabaseWatcher(t *testing.T) { // Watcher should detect the database list change. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 2) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -611,7 +613,7 @@ func TestDatabaseWatcher(t *testing.T) { // Watcher should detect the database list change. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], database2)) case <-w.Done(): @@ -661,14 +663,15 @@ func TestAppWatcher(t *testing.T) { Events: local.NewEventsService(bk), }, }, - AppsC: make(chan types.Apps, 10), + AppGetter: appService, + AppsC: make(chan []types.Application, 10), }) require.NoError(t, err) t.Cleanup(w.Close) // Initially there are no apps so watcher should send an empty list. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Empty(t, changeset) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -682,7 +685,7 @@ func TestAppWatcher(t *testing.T) { // The first event is always the current list of apps. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], app1)) case <-w.Done(): @@ -697,7 +700,7 @@ func TestAppWatcher(t *testing.T) { // Watcher should detect the app list change. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 2) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -710,7 +713,7 @@ func TestAppWatcher(t *testing.T) { // Watcher should detect the database list change. 
select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], app2)) case <-w.Done(): @@ -909,6 +912,7 @@ func TestNodeWatcherFallback(t *testing.T) { }, MaxStaleness: time.Minute, }, + NodesGetter: presence, }) require.NoError(t, err) t.Cleanup(w.Close) @@ -922,15 +926,14 @@ func TestNodeWatcherFallback(t *testing.T) { nodes = append(nodes, node) } - require.Empty(t, w.NodeCount()) + require.Empty(t, w.ResourceCount()) require.False(t, w.IsInitialized()) - got := w.GetNodes(ctx, func(n services.Node) bool { - return true - }) + got, err := w.CurrentResources(ctx) + require.NoError(t, err) require.Len(t, nodes, len(got)) - require.Len(t, nodes, w.NodeCount()) + require.Len(t, nodes, w.ResourceCount()) require.False(t, w.IsInitialized()) } @@ -961,6 +964,7 @@ func TestNodeWatcher(t *testing.T) { }, MaxStaleness: time.Minute, }, + NodesGetter: presence, }) require.NoError(t, err) t.Cleanup(w.Close) @@ -974,25 +978,27 @@ func TestNodeWatcher(t *testing.T) { nodes = append(nodes, node) } - require.Eventually(t, func() bool { - filtered := w.GetNodes(ctx, func(n services.Node) bool { - return true - }) - return len(filtered) == len(nodes) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(ctx) + assert.NoError(t, err) + assert.Len(t, filtered, len(nodes)) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") - require.Len(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetUseTunnel() }), 3) + filtered, err := w.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool { return n.GetUseTunnel() }) + require.NoError(t, err) + require.Len(t, filtered, 3) require.NoError(t, presence.DeleteNode(ctx, apidefaults.Namespace, nodes[0].GetName())) - require.Eventually(t, func() bool { - filtered := w.GetNodes(ctx, func(n services.Node) bool { - return true - }) - return len(filtered) == len(nodes)-1 + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(ctx) + assert.NoError(t, err) + assert.Len(t, filtered, len(nodes)-1) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") - require.Empty(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetName() == nodes[0].GetName() })) + filtered, err = w.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool { return n.GetName() == nodes[0].GetName() }) + require.NoError(t, err) + require.Empty(t, filtered) } func newNodeServer(t *testing.T, name, hostname, addr string, tunnel bool) types.Server { @@ -1032,6 +1038,7 @@ func TestKubeServerWatcher(t *testing.T) { }, MaxStaleness: time.Minute, }, + KubernetesServerGetter: presence, }) require.NoError(t, err) t.Cleanup(w.Close) @@ -1057,55 +1064,66 @@ func TestKubeServerWatcher(t *testing.T) { kubeServers = append(kubeServers, kubeServer) } - require.Eventually(t, func() bool { - filtered, err := w.GetKubernetesServers(context.Background()) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(context.Background()) assert.NoError(t, err) - return len(filtered) == len(kubeServers) + assert.Len(t, filtered, len(kubeServers)) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive kube servers.") // Test filtering by cluster name. 
- filtered, err := w.GetKubeServersByClusterName(context.Background(), kubeServers[0].GetName()) + filtered, err := w.CurrentResourcesWithFilter(context.Background(), func(ks types.ReadOnlyKubeServer) bool { + return ks.GetName() == kubeServers[0].GetName() + }) require.NoError(t, err) require.Len(t, filtered, 1) // Test Deleting a kube server. require.NoError(t, presence.DeleteKubernetesServer(ctx, kubeServers[0].GetHostID(), kubeServers[0].GetName())) - require.Eventually(t, func() bool { - kube, err := w.GetKubernetesServers(context.Background()) + require.EventuallyWithT(t, func(t *assert.CollectT) { + kube, err := w.CurrentResources(context.Background()) assert.NoError(t, err) - return len(kube) == len(kubeServers)-1 + assert.Len(t, kube, len(kubeServers)-1) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive the delete event.") - filtered, err = w.GetKubeServersByClusterName(context.Background(), kubeServers[0].GetName()) - require.Error(t, err) + filtered, err = w.CurrentResourcesWithFilter(context.Background(), func(ks types.ReadOnlyKubeServer) bool { + return ks.GetName() == kubeServers[0].GetName() + }) + require.NoError(t, err) require.Empty(t, filtered) // Test adding a kube server with the same name as an existing one. kubeServer := newKubeServer(t, kubeServers[1].GetName(), "addr", uuid.NewString()) _, err = presence.UpsertKubernetesServer(ctx, kubeServer) require.NoError(t, err) - require.Eventually(t, func() bool { - filtered, err := w.GetKubeServersByClusterName(context.Background(), kubeServers[1].GetName()) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResourcesWithFilter(context.Background(), func(ks types.ReadOnlyKubeServer) bool { + return ks.GetName() == kubeServers[1].GetName() + }) assert.NoError(t, err) - return len(filtered) == 2 - }, time.Second, time.Millisecond, "Timeout waiting for watcher to the new registered kube server.") + assert.Len(t, filtered, 2) + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive the newly registered kube server.") // Test deleting all kube servers with the same name.
- filtered, err = w.GetKubeServersByClusterName(context.Background(), kubeServers[1].GetName()) + filtered, err = w.CurrentResourcesWithFilter(context.Background(), func(ks types.ReadOnlyKubeServer) bool { + return ks.GetName() == kubeServers[1].GetName() + }) assert.NoError(t, err) for _, server := range filtered { require.NoError(t, presence.DeleteKubernetesServer(ctx, server.GetHostID(), server.GetName())) } - require.Eventually(t, func() bool { - filtered, err := w.GetKubeServersByClusterName(context.Background(), kubeServers[1].GetName()) - return len(filtered) == 0 && err != nil + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResourcesWithFilter(context.Background(), func(ks types.ReadOnlyKubeServer) bool { + return ks.GetName() == kubeServers[1].GetName() + }) + assert.NoError(t, err) + assert.Empty(t, filtered) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive the two delete events.") require.NoError(t, presence.DeleteAllKubernetesServers(ctx)) - require.Eventually(t, func() bool { - filtered, err := w.GetKubernetesServers(context.Background()) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(context.Background()) assert.NoError(t, err) - return len(filtered) == 0 + assert.Empty(t, filtered) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive all delete events.") } diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index 0a248bc94101..c1bcbe3a52e5 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -151,7 +151,7 @@ type Server struct { reconcileCh chan struct{} // watcher monitors changes to application resources. - watcher *services.AppWatcher + watcher *services.GenericWatcher[types.Application, types.ReadOnlyApplication] } // monitoredApps is a collection of applications from different sources diff --git a/lib/srv/app/watcher.go b/lib/srv/app/watcher.go index fb0acc2bfad7..3d91a68b8a7a 100644 --- a/lib/srv/app/watcher.go +++ b/lib/srv/app/watcher.go @@ -65,7 +65,7 @@ func (s *Server) startReconciler(ctx context.Context) error { // startResourceWatcher starts watching changes to application resources and // registers/unregisters the proxied applications accordingly. -func (s *Server) startResourceWatcher(ctx context.Context) (*services.AppWatcher, error) { +func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Application, types.ReadOnlyApplication], error) { if len(s.c.ResourceMatchers) == 0 { s.log.Debug("Not initializing application resource watcher.") return nil, nil @@ -78,6 +78,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.AppWatcher // Log: s.log, Client: s.c.AccessPoint, }, + AppGetter: s.c.AccessPoint, }) if err != nil { return nil, trace.Wrap(err) @@ -86,7 +87,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.AppWatcher defer watcher.Close() for { select { - case apps := <-watcher.AppsC: + case apps := <-watcher.ResourcesC: appsWithAddr := make(types.Apps, 0, len(apps)) for _, app := range apps { appsWithAddr = append(appsWithAddr, s.guessPublicAddr(app)) diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 90b93a688ad5..62c5eb98c067 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -314,7 +314,7 @@ type Server struct { // heartbeats holds heartbeats for database servers. heartbeats map[string]srv.HeartbeatI // watcher monitors changes to database resources. 
- watcher *services.DatabaseWatcher + watcher *services.GenericWatcher[types.Database, types.ReadOnlyDatabase] // proxiedDatabases contains databases this server currently is proxying. // Proxied databases are reconciled against monitoredDatabases below. proxiedDatabases map[string]types.Database diff --git a/lib/srv/db/watcher.go b/lib/srv/db/watcher.go index a3313a90792e..7267472a8926 100644 --- a/lib/srv/db/watcher.go +++ b/lib/srv/db/watcher.go @@ -69,7 +69,7 @@ func (s *Server) startReconciler(ctx context.Context) error { // startResourceWatcher starts watching changes to database resources and // registers/unregisters the proxied databases accordingly. -func (s *Server) startResourceWatcher(ctx context.Context) (*services.DatabaseWatcher, error) { +func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Database, types.ReadOnlyDatabase], error) { if len(s.cfg.ResourceMatchers) == 0 { s.log.DebugContext(ctx, "Not starting database resource watcher.") return nil, nil @@ -81,6 +81,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.DatabaseWa Logger: s.log, Client: s.cfg.AccessPoint, }, + DatabaseGetter: s.cfg.AccessPoint, }) if err != nil { return nil, trace.Wrap(err) @@ -90,7 +91,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.DatabaseWa defer watcher.Close() for { select { - case databases := <-watcher.DatabasesC: + case databases := <-watcher.ResourcesC: s.monitoredDatabases.setResources(databases) select { case s.reconcileCh <- struct{}{}: diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index d1b490bbe083..86aadd4592e1 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -267,7 +267,7 @@ type Server struct { // cancelfn is used with ctx when stopping the discovery server cancelfn context.CancelFunc // nodeWatcher is a node watcher. - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // ec2Watcher periodically retrieves EC2 instances. ec2Watcher *server.Watcher @@ -777,13 +777,16 @@ func (s *Server) initGCPWatchers(ctx context.Context, matchers []types.GCPMatche return nil } -func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) { - nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool { +func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) error { + nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n types.ReadOnlyServer) bool { labels := n.GetAllLabels() _, accountOK := labels[types.AWSAccountIDLabel] _, instanceOK := labels[types.AWSInstanceIDLabel] return accountOK && instanceOK }) + if err != nil { + return trace.Wrap(err) + } var filtered []server.EC2Instance outer: @@ -800,6 +803,7 @@ outer: filtered = append(filtered, inst) } instances.Instances = filtered + return nil } func genEC2InstancesLogStr(instances []server.EC2Instance) string { @@ -850,7 +854,9 @@ func (s *Server) handleEC2Instances(instances *server.EC2Instances) error { // EICE Nodes must never be filtered, so that we can extend their expiration and sync labels. 
 totalInstancesFound := len(instances.Instances)
 if !instances.Rotation && instances.EnrollMode != types.InstallParamEnrollMode_INSTALL_PARAM_ENROLL_MODE_EICE {
- s.filterExistingEC2Nodes(instances)
+ if err := s.filterExistingEC2Nodes(instances); err != nil {
+ return trace.Wrap(err)
+ }
 }
 instancesAlreadyEnrolled := totalInstancesFound - len(instances.Instances)
@@ -904,12 +910,24 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) {
 continue
 }
- existingNode, err := s.nodeWatcher.GetNode(s.ctx, eiceNode.GetName())
+ existingNodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(s types.ReadOnlyServer) bool {
+ return s.GetName() == eiceNode.GetName()
+ })
 if err != nil && !trace.IsNotFound(err) {
 s.Log.Warnf("Error finding the existing node with name %q: %v", eiceNode.GetName(), err)
 continue
 }
+ var existingNode types.Server
+ switch len(existingNodes) {
+ case 0:
+ case 1:
+ existingNode = existingNodes[0]
+ default:
+ s.Log.Warnf("Found multiple matching nodes with name %q", eiceNode.GetName())
+ continue
+ }
+
 // EICE Node's Name are deterministic (based on the Account and Instance ID).
 //
 // To reduce load, nodes are skipped if
@@ -1064,7 +1082,7 @@ func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, err
 if err != nil {
 return nil, trace.Wrap(err)
 }
- found := s.nodeWatcher.GetNodes(ctx, func(n services.Node) bool {
+ found, err := s.nodeWatcher.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool {
 if n.GetSubKind() != types.SubKindOpenSSHNode {
 return false
 }
@@ -1077,6 +1095,9 @@ func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, err
 return mostRecentCertRotation.After(n.GetRotation().LastRotated)
 })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
 if len(found) == 0 {
 return nil, trace.NotFound("no unrotated nodes found")
@@ -1118,13 +1139,18 @@ func (s *Server) handleEC2Discovery() {
 }
 }
-func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) {
- nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool {
+func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) error {
+ nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n types.ReadOnlyServer) bool {
 labels := n.GetAllLabels()
 _, subscriptionOK := labels[types.SubscriptionIDLabel]
 _, vmOK := labels[types.VMIDLabel]
 return subscriptionOK && vmOK
 })
+
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
 var filtered []*armcompute.VirtualMachine
 outer:
 for _, inst := range instances.Instances {
@@ -1144,6 +1170,7 @@ outer:
 filtered = append(filtered, inst)
 }
 instances.Instances = filtered
+ return nil
 }
 func (s *Server) handleAzureInstances(instances *server.AzureInstances) error {
@@ -1151,7 +1178,9 @@ func (s *Server) handleAzureInstances(instances *server.AzureInstances) error {
 if err != nil {
 return trace.Wrap(err)
 }
- s.filterExistingAzureNodes(instances)
+ if err := s.filterExistingAzureNodes(instances); err != nil {
+ return trace.Wrap(err)
+ }
 if len(instances.Instances) == 0 {
 return trace.Wrap(errNoInstances)
 }
@@ -1206,14 +1235,19 @@ func (s *Server) handleAzureDiscovery() {
 }
 }
-func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) {
- nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool {
+func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) error {
+ nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n types.ReadOnlyServer) bool {
 labels := n.GetAllLabels()
 _, projectIDOK := labels[types.ProjectIDLabelDiscovery]
 _, zoneOK := labels[types.ZoneLabelDiscovery]
 _, nameOK := labels[types.NameLabelDiscovery]
 return projectIDOK && zoneOK && nameOK
 })
+
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
 var filtered []*gcpimds.Instance
 outer:
 for _, inst := range instances.Instances {
@@ -1230,6 +1264,7 @@ outer:
 filtered = append(filtered, inst)
 }
 instances.Instances = filtered
+ return nil
 }
 func (s *Server) handleGCPInstances(instances *server.GCPInstances) error {
@@ -1237,7 +1272,9 @@ func (s *Server) handleGCPInstances(instances *server.GCPInstances) error {
 if err != nil {
 return trace.Wrap(err)
 }
- s.filterExistingGCPNodes(instances)
+ if err := s.filterExistingGCPNodes(instances); err != nil {
+ return trace.Wrap(err)
+ }
 if len(instances.Instances) == 0 {
 return trace.Wrap(errNoInstances)
 }
@@ -1730,6 +1767,7 @@ func (s *Server) initTeleportNodeWatcher() (err error) {
 Client: s.AccessPoint,
 MaxStaleness: time.Minute,
 },
+ NodesGetter: s.AccessPoint,
 })
 return trace.Wrap(err)
diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go
index 974dd1bfcdad..ab384e65d74f 100644
--- a/lib/srv/discovery/discovery_test.go
+++ b/lib/srv/discovery/discovery_test.go
@@ -852,11 +852,10 @@ func TestDiscoveryServerConcurrency(t *testing.T) {
 // We must get only one EC2 EICE Node.
 // Even when two servers are discovering the same EC2 Instance, they will use the same name when converting to EICE Node.
- require.Eventually(t, func() bool {
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
 allNodes, err := tlsServer.Auth().GetNodes(ctx, "default")
- require.NoError(t, err)
-
- return len(allNodes) == 1
+ assert.NoError(t, err)
+ assert.Len(t, allNodes, 1)
 }, 1*time.Second, 50*time.Millisecond)
 // We should never get a duplicate instance.
diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go
index 42df4c9d4017..4d9be00f6244 100644
--- a/lib/srv/regular/sshserver_test.go
+++ b/lib/srv/regular/sshserver_test.go
@@ -2865,12 +2865,13 @@ func newLockWatcher(ctx context.Context, t testing.TB, client types.Events) *ser
 return lockWatcher
 }
-func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.NodeWatcher {
+func newNodeWatcher(ctx context.Context, t *testing.T, client *authclient.Client) *services.GenericWatcher[types.Server, types.ReadOnlyServer] {
 nodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{
 ResourceWatcherConfig: services.ResourceWatcherConfig{
 Component: "test",
 Client: client,
 },
+ NodesGetter: client,
 })
 require.NoError(t, err)
 t.Cleanup(nodeWatcher.Close)
diff --git a/lib/utils/fncache.go b/lib/utils/fncache.go
index e45a8b3a2d82..84f5be17478b 100644
--- a/lib/utils/fncache.go
+++ b/lib/utils/fncache.go
@@ -245,6 +245,8 @@ func FnCacheGetWithTTL[T any](ctx context.Context, cache *FnCache, key any, ttl
 switch {
 case err != nil:
 return ret, err
+ case t == nil:
+ return ret, nil
 case !ok:
 return ret, trace.BadParameter("value retrieved was %T, expected %T", t, ret)
 }
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index 8e1e4ba43f69..155e911a41cf 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -162,7 +162,7 @@ type Handler struct {
 // nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from
 // the proxy's cache and get nodes in real time.
- nodeWatcher *services.NodeWatcher
+ nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]
 // tracer is used to create spans.
 tracer oteltrace.Tracer
@@ -298,7 +298,7 @@ type Config struct {
 // NodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from
 // the proxy's cache and get nodes in real time.
- NodeWatcher *services.NodeWatcher
+ NodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]
 // PresenceChecker periodically runs the mfa ceremony for moderated
 // sessions.
@@ -3530,9 +3530,12 @@ func (h *Handler) siteNodeConnect(
 WebsocketConn: ws,
 SSHDialTimeout: dialTimeout,
 HostNameResolver: func(serverID string) (string, error) {
- matches := nw.GetNodes(r.Context(), func(n services.Node) bool {
+ matches, err := nw.CurrentResourcesWithFilter(r.Context(), func(n types.ReadOnlyServer) bool {
 return n.GetName() == serverID
 })
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
 if len(matches) != 1 {
 return "", trace.NotFound("unable to resolve hostname for server %s", serverID)
diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go
index ad16a9c658dc..7199ba708174 100644
--- a/lib/web/apiserver_test.go
+++ b/lib/web/apiserver_test.go
@@ -382,6 +382,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
 Component: teleport.ComponentProxy,
 Client: s.proxyClient,
 },
+ NodesGetter: s.proxyClient,
 })
 require.NoError(t, err)
@@ -8186,6 +8187,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
 Component: teleport.ComponentProxy,
 Client: client,
 },
+ NodesGetter: client,
 })
 require.NoError(t, err)
 t.Cleanup(proxyNodeWatcher.Close)
@@ -9076,6 +9078,7 @@ func startKubeWithoutCleanup(ctx context.Context, t *testing.T, cfg startKubeOpt
 Client: client,
 Clock: clock,
 },
+ KubernetesServerGetter: client,
 })
 require.NoError(t, err)
diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go
index 609c8a827033..bbd8529a0336 100644
--- a/tool/tsh/common/tsh_test.go
+++ b/tool/tsh/common/tsh_test.go
@@ -2862,34 +2862,48 @@ func TestSSHHeadless(t *testing.T) {
 bob.SetRoles([]string{"requester"})
 sshHostname := "test-ssh-host"
- rootAuth, rootProxy := makeTestServers(t, withBootstrap(nodeAccess, alice, requester, bob), withConfig(func(cfg *servicecfg.Config) {
- cfg.Hostname = sshHostname
- cfg.SSH.Enabled = true
- cfg.SSH.Addr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
- }))
+ server := testserver.MakeTestServer(t,
+ testserver.WithConfig(func(cfg *servicecfg.Config) {
+ cfg.Hostname = sshHostname
+ cfg.Auth.Enabled = true
+ cfg.Proxy.Enabled = true
+ cfg.SSH.Enabled = true
+ cfg.SSH.DisableCreateHostUser = true
- proxyAddr, err := rootProxy.ProxyWebAddr()
- require.NoError(t, err)
+ cfg.Auth.BootstrapResources = []types.Resource{nodeAccess, alice, requester, bob}
+ cfg.Auth.Preference = &types.AuthPreferenceV2{
+ Metadata: types.Metadata{
+ Labels: map[string]string{types.OriginLabel: types.OriginConfigFile},
+ },
+ Spec: types.AuthPreferenceSpecV2{
+ Type: constants.Local,
+ SecondFactor: constants.SecondFactorOptional,
+ Webauthn: &types.Webauthn{
+ RPID: "127.0.0.1",
+ },
+ AllowHeadless: types.NewBoolOption(true),
+ },
+ }
+ }),
+ )
- _, err = rootAuth.GetAuthServer().UpsertAuthPreference(ctx, &types.AuthPreferenceV2{
- Spec: types.AuthPreferenceSpecV2{
- Type: constants.Local,
- SecondFactor: constants.SecondFactorOptional,
- Webauthn: &types.Webauthn{
- RPID: "127.0.0.1",
- },
- },
- })
- require.NoError(t, err)
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ found, err := server.GetAuthServer().GetNodes(ctx, apidefaults.Namespace)
+ assert.NoError(t, err)
+ assert.Len(t, found, 1)
+ }, 10*time.Second, 100*time.Millisecond)
 go func() {
- if err := approveAllAccessRequests(ctx, rootAuth.GetAuthServer()); err != nil {
+ // Ensure the context is canceled, so that Run calls don't block
+ defer cancel()
+ if err := approveAllAccessRequests(ctx, server.GetAuthServer()); err != nil {
 assert.ErrorIs(t, err, context.Canceled, "unexpected error from approveAllAccessRequests")
 }
- // Cancel the context, so Run calls don't block
- cancel()
 }()
+ proxyAddr, err := server.ProxyWebAddr()
+ require.NoError(t, err)
+
 for _, tc := range []struct {
 name string
 args []string
@@ -2930,10 +2944,10 @@ func TestSSHHeadless(t *testing.T) {
 "echo", "test",
 )
- err := Run(ctx, args, CliOption(func(cf *CLIConf) error {
- cf.MockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice)
+ err := Run(ctx, args, func(cf *CLIConf) error {
+ cf.MockHeadlessLogin = mockHeadlessLogin(t, server.GetAuthServer(), alice)
 return nil
- }))
+ })
 tc.assertErr(t, err)
 })
 }
@@ -2966,32 +2980,45 @@ func TestHeadlessDoesNotAddKeysToAgent(t *testing.T) {
 alice.SetRoles([]string{"node-access"})
 sshHostname := "test-ssh-host"
- rootAuth, rootProxy := makeTestServers(t, withBootstrap(nodeAccess, alice), withConfig(func(cfg *servicecfg.Config) {
- cfg.Hostname = sshHostname
- cfg.SSH.Enabled = true
- cfg.SSH.Addr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
- }))
- proxyAddr, err := rootProxy.ProxyWebAddr()
- require.NoError(t, err)
+ server := testserver.MakeTestServer(t,
+ testserver.WithConfig(func(cfg *servicecfg.Config) {
+ cfg.Hostname = sshHostname
+ cfg.Auth.Enabled = true
+ cfg.Proxy.Enabled = true
+ cfg.SSH.Enabled = true
+ cfg.SSH.DisableCreateHostUser = true
+ cfg.Auth.BootstrapResources = []types.Resource{nodeAccess, alice}
+ cfg.Auth.Preference = &types.AuthPreferenceV2{
+ Metadata: types.Metadata{
+ Labels: map[string]string{types.OriginLabel: types.OriginConfigFile},
+ },
+ Spec: types.AuthPreferenceSpecV2{
+ Type: constants.Local,
+ SecondFactor: constants.SecondFactorOptional,
+ Webauthn: &types.Webauthn{
+ RPID: "127.0.0.1",
+ },
+ AllowHeadless: types.NewBoolOption(true),
+ },
+ }
+ }))
- _, err = rootAuth.GetAuthServer().UpsertAuthPreference(ctx, &types.AuthPreferenceV2{
- Spec: types.AuthPreferenceSpecV2{
- Type: constants.Local,
- SecondFactor: constants.SecondFactorOptional,
- Webauthn: &types.Webauthn{
- RPID: "127.0.0.1",
- },
- },
- })
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ found, err := server.GetAuthServer().GetNodes(ctx, apidefaults.Namespace)
+ assert.NoError(t, err)
+ assert.Len(t, found, 1)
+ }, 10*time.Second, 100*time.Millisecond)
+
+ proxyAddr, err := server.ProxyWebAddr()
 require.NoError(t, err)
 go func() {
- if err := approveAllAccessRequests(ctx, rootAuth.GetAuthServer()); err != nil {
+ // Ensure the context is canceled, so that Run calls don't block
+ defer cancel()
+ if err := approveAllAccessRequests(ctx, server.GetAuthServer()); err != nil {
 assert.ErrorIs(t, err, context.Canceled, "unexpected error from approveAllAccessRequests")
 }
- // Cancel the context, so Run calls don't block
- cancel()
 }()
 err = Run(ctx, []string{
@@ -3004,10 +3031,10 @@ func TestHeadlessDoesNotAddKeysToAgent(t *testing.T) {
 "--add-keys-to-agent=yes",
 fmt.Sprintf("%s@%s", user.Username, sshHostname),
 "echo", "test",
- }, CliOption(func(cf *CLIConf) error {
- cf.MockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice)
+ }, func(cf *CLIConf) error {
+ cf.MockHeadlessLogin = mockHeadlessLogin(t, server.GetAuthServer(), alice)
 return nil
- }))
+ })
 require.NoError(t, err)
 keys, err := agentKeyring.List()