diff --git a/pkg/component/controller/etcd.go b/pkg/component/controller/etcd.go index 60b4eb2ab2b6..ab2357ca9c39 100644 --- a/pkg/component/controller/etcd.go +++ b/pkg/component/controller/etcd.go @@ -94,7 +94,7 @@ func (e *Etcd) Init(_ context.Context) error { return assets.Stage(e.K0sVars.BinDir, "etcd", constant.BinDirMode) } -func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) { +func (e *Etcd) syncEtcdConfig(ctx context.Context, etcdRequest v1beta1.EtcdRequest, etcdCaCert, etcdCaCertKey string) ([]string, error) { logrus.Info("Synchronizing etcd config with existing cluster via ", e.JoinClient.Address()) var etcdResponse v1beta1.EtcdResponse @@ -103,7 +103,7 @@ func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCe func() error { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - etcdResponse, err = e.JoinClient.JoinEtcd(ctx, peerURL) + etcdResponse, err = e.JoinClient.JoinEtcd(ctx, etcdRequest) return err }, retry.Context(ctx), @@ -190,7 +190,11 @@ func (e *Etcd) Start(ctx context.Context) error { if file.Exists(filepath.Join(e.K0sVars.EtcdDataDir, "member", "snap", "db")) { logrus.Warnf("etcd db file(s) already exist, not gonna run join process") } else if e.JoinClient != nil { - initialCluster, err := e.syncEtcdConfig(ctx, peerURL, etcdCaCert, etcdCaCertKey) + etcdRequest := v1beta1.EtcdRequest{ + Node: name, + PeerAddress: peerURL, + } + initialCluster, err := e.syncEtcdConfig(ctx, etcdRequest, etcdCaCert, etcdCaCertKey) if err != nil { return fmt.Errorf("failed to sync etcd config: %w", err) } diff --git a/pkg/token/joinclient.go b/pkg/token/joinclient.go index b8e4f3dc03cf..37b5a8db924f 100644 --- a/pkg/token/joinclient.go +++ b/pkg/token/joinclient.go @@ -19,24 +19,23 @@ package token import ( "bytes" "context" - "crypto/tls" - "crypto/x509" "encoding/json" "fmt" - "io" - "net/http" - "os" "github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1" + "github.com/k0sproject/k0s/pkg/kubernetes" + + "k8s.io/client-go/dynamic" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" + "k8s.io/client-go/tools/clientcmd/api" ) // JoinClient is the client we can use to call k0s join APIs type JoinClient struct { - joinAddress string - httpClient http.Client - bearerToken string joinTokenType string + joinAddress string + restClient *rest.RESTClient } // JoinClientFromToken creates a new join api client from a token @@ -46,35 +45,27 @@ func JoinClientFromToken(encodedToken string) (*JoinClient, error) { return nil, fmt.Errorf("failed to decode token: %w", err) } - clientConfig, err := clientcmd.NewClientConfigFromBytes(tokenBytes) + kubeconfig, err := clientcmd.Load(tokenBytes) if err != nil { return nil, err } - config, err := clientConfig.ClientConfig() + + restConfig, err := kubernetes.ClientConfig(func() (*api.Config, error) { return kubeconfig, nil }) if err != nil { return nil, err } - raw, err := clientConfig.RawConfig() + restConfig = dynamic.ConfigFor(restConfig) + restClient, err := rest.UnversionedRESTClientFor(restConfig) if err != nil { return nil, err } - ca := x509.NewCertPool() - ca.AppendCertsFromPEM(config.CAData) - tlsConfig := &tls.Config{ - InsecureSkipVerify: false, - RootCAs: ca, - } - tr := &http.Transport{TLSClientConfig: tlsConfig} - c := &JoinClient{ - httpClient: http.Client{Transport: tr}, - bearerToken: config.BearerToken, - } - c.joinAddress = config.Host - c.joinTokenType = GetTokenType(&raw) - - return c, nil + return &JoinClient{ + joinAddress: restConfig.Host, + joinTokenType: GetTokenType(kubeconfig), + restClient: restClient, + }, nil } func (j *JoinClient) Address() string { @@ -88,71 +79,28 @@ func (j *JoinClient) JoinTokenType() string { // GetCA calls the CA sync API func (j *JoinClient) GetCA(ctx context.Context) (v1beta1.CaResponse, error) { var caData v1beta1.CaResponse - req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.joinAddress+"/v1beta1/ca", nil) - if err != nil { - return caData, fmt.Errorf("failed to create join request: %w", err) - } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken)) - resp, err := j.httpClient.Do(req) - if err != nil { - return caData, err + b, err := j.restClient.Get().AbsPath("v1beta1", "ca").Do(ctx).Raw() + if err == nil { + err = json.Unmarshal(b, &caData) } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return caData, fmt.Errorf("unexpected response status: %s", resp.Status) - } - b, err := io.ReadAll(resp.Body) - if err != nil { - return caData, err - } - err = json.Unmarshal(b, &caData) - if err != nil { - return caData, err - } - return caData, nil + return caData, err } // JoinEtcd calls the etcd join API -func (j *JoinClient) JoinEtcd(ctx context.Context, peerAddress string) (v1beta1.EtcdResponse, error) { +func (j *JoinClient) JoinEtcd(ctx context.Context, etcdRequest v1beta1.EtcdRequest) (v1beta1.EtcdResponse, error) { var etcdResponse v1beta1.EtcdResponse - etcdRequest := v1beta1.EtcdRequest{ - PeerAddress: peerAddress, - } - name, err := os.Hostname() - if err != nil { - return etcdResponse, err - } - etcdRequest.Node = name buf := new(bytes.Buffer) if err := json.NewEncoder(buf).Encode(etcdRequest); err != nil { return etcdResponse, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf) - if err != nil { - return etcdResponse, fmt.Errorf("failed to create join request: %w", err) - } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken)) - resp, err := j.httpClient.Do(req) - if err != nil { - return etcdResponse, err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return etcdResponse, fmt.Errorf("unexpected response status when trying to join etcd cluster: %s", resp.Status) - } - - b, err := io.ReadAll(resp.Body) - if err != nil { - return etcdResponse, err - } - err = json.Unmarshal(b, &etcdResponse) - if err != nil { - return etcdResponse, err + b, err := j.restClient.Post().AbsPath("v1beta1", "etcd", "members").Body(buf).Do(ctx).Raw() + if err == nil { + err = json.Unmarshal(b, &etcdResponse) } - return etcdResponse, nil + return etcdResponse, err } diff --git a/pkg/token/joinclient_test.go b/pkg/token/joinclient_test.go index 19e542c71e6b..a977a3a5eac7 100644 --- a/pkg/token/joinclient_test.go +++ b/pkg/token/joinclient_test.go @@ -19,17 +19,86 @@ package token_test import ( "bytes" "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" "net" "net/http" "net/url" "testing" + k0sv1beta1 "github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1" "github.com/k0sproject/k0s/pkg/token" + "github.com/cloudflare/cfssl/csr" + "github.com/cloudflare/cfssl/initca" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestJoinClient_GetCA(t *testing.T) { + t.Parallel() + + joinURL, certData := startFakeJoinServer(t, func(res http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/some/sub/path/v1beta1/ca", req.RequestURI) + assert.Equal(t, []string{"Bearer the-token"}, req.Header["Authorization"]) + _, err := res.Write([]byte("{}")) + assert.NoError(t, err) + }) + + joinURL.Path = "/some/sub/path" + kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), certData, t.Name(), "the-token") + require.NoError(t, err) + tok, err := token.JoinEncode(bytes.NewReader(kubeconfig)) + require.NoError(t, err) + + underTest, err := token.JoinClientFromToken(tok) + require.NoError(t, err) + + response, err := underTest.GetCA(context.TODO()) + assert.NoError(t, err) + assert.Zero(t, response) +} + +func TestJoinClient_JoinEtcd(t *testing.T) { + t.Parallel() + + joinURL, certData := startFakeJoinServer(t, func(res http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/some/sub/path/v1beta1/etcd/members", req.RequestURI) + assert.Equal(t, []string{"Bearer the-token"}, req.Header["Authorization"]) + + if body, err := io.ReadAll(req.Body); assert.NoError(t, err) { + var data map[string]string + if assert.NoError(t, json.Unmarshal(body, &data)) { + assert.Equal(t, map[string]string{ + "node": "the-node", + "peerAddress": "the-peer-address", + }, data) + } + } + + _, err := res.Write([]byte("{}")) + assert.NoError(t, err) + }) + + joinURL.Path = "/some/sub/path" + kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), certData, t.Name(), "the-token") + require.NoError(t, err) + tok, err := token.JoinEncode(bytes.NewReader(kubeconfig)) + require.NoError(t, err) + + underTest, err := token.JoinClientFromToken(tok) + require.NoError(t, err) + + response, err := underTest.JoinEtcd(context.TODO(), k0sv1beta1.EtcdRequest{ + Node: "the-node", + PeerAddress: "the-peer-address", + }) + assert.NoError(t, err) + assert.Zero(t, response) +} + func TestJoinClient_Cancellation(t *testing.T) { t.Parallel() @@ -42,7 +111,7 @@ func TestJoinClient_Cancellation(t *testing.T) { return err }}, {"JoinEtcd", func(ctx context.Context, c *token.JoinClient) error { - _, err := c.JoinEtcd(ctx, "") + _, err := c.JoinEtcd(ctx, k0sv1beta1.EtcdRequest{}) return err }}, } { @@ -51,12 +120,12 @@ func TestJoinClient_Cancellation(t *testing.T) { t.Parallel() clientContext, cancelClientContext := context.WithCancelCause(context.Background()) - joinURL := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) { + joinURL, certData := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) { cancelClientContext(assert.AnError) // cancel the client's context <-req.Context().Done() // block forever }) - kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), nil, "", "") + kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), certData, "", "") require.NoError(t, err) tok, err := token.JoinEncode(bytes.NewReader(kubeconfig)) require.NoError(t, err) @@ -71,23 +140,44 @@ func TestJoinClient_Cancellation(t *testing.T) { } } -func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *url.URL { +func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*url.URL, []byte) { requestCtx, cancelRequests := context.WithCancel(context.Background()) + var ok bool listener, err := net.Listen("tcp", "localhost:0") if err != nil { require.NoError(t, err) } + defer func() { + if !ok { + assert.NoError(t, listener.Close()) + } + }() + + addr := listener.Addr().(*net.TCPAddr) + certData, _, keyData, err := initca.New(&csr.CertificateRequest{ + KeyRequest: csr.NewKeyRequest(), + CN: fmt.Sprintf("localhost:%d", addr.Port), + Hosts: []string{addr.IP.String()}, + }) + if !assert.NoError(t, err) { + assert.NoError(t, listener.Close()) + t.FailNow() + } + cert, err := tls.X509KeyPair(certData, keyData) + require.NoError(t, err) server := &http.Server{ - Addr: listener.Addr().String(), + Addr: addr.String(), + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, Handler: http.HandlerFunc(handler), BaseContext: func(net.Listener) context.Context { return requestCtx }, } - serverError := make(chan error) - go func() { defer close(serverError); serverError <- server.Serve(listener) }() - + ok = true + go func() { defer close(serverError); serverError <- server.ServeTLS(listener, "", "") }() t.Cleanup(func() { cancelRequests() if !assert.NoError(t, server.Shutdown(context.Background()), "Couldn't shutdown HTTP server") { @@ -96,5 +186,5 @@ func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.R assert.ErrorIs(t, <-serverError, http.ErrServerClosed, "HTTP server terminated unexpectedly") }) - return &url.URL{Scheme: "http", Host: server.Addr} + return &url.URL{Scheme: "https", Host: server.Addr}, certData }