Skip to content

Commit

Permalink
Merge pull request #4653 from twz123/rest-joinclient
Browse files Browse the repository at this point in the history
Use the Kubernetes REST client for joining
  • Loading branch information
twz123 authored Jun 21, 2024
2 parents 99b600f + 2676a27 commit bc88876
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 90 deletions.
10 changes: 7 additions & 3 deletions pkg/component/controller/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (e *Etcd) Init(_ context.Context) error {
return assets.Stage(e.K0sVars.BinDir, "etcd", constant.BinDirMode)
}

func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) {
func (e *Etcd) syncEtcdConfig(ctx context.Context, etcdRequest v1beta1.EtcdRequest, etcdCaCert, etcdCaCertKey string) ([]string, error) {
logrus.Info("Synchronizing etcd config with existing cluster via ", e.JoinClient.Address())

var etcdResponse v1beta1.EtcdResponse
Expand All @@ -103,7 +103,7 @@ func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCe
func() error {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
etcdResponse, err = e.JoinClient.JoinEtcd(ctx, peerURL)
etcdResponse, err = e.JoinClient.JoinEtcd(ctx, etcdRequest)
return err
},
retry.Context(ctx),
Expand Down Expand Up @@ -190,7 +190,11 @@ func (e *Etcd) Start(ctx context.Context) error {
if file.Exists(filepath.Join(e.K0sVars.EtcdDataDir, "member", "snap", "db")) {
logrus.Warnf("etcd db file(s) already exist, not gonna run join process")
} else if e.JoinClient != nil {
initialCluster, err := e.syncEtcdConfig(ctx, peerURL, etcdCaCert, etcdCaCertKey)
etcdRequest := v1beta1.EtcdRequest{
Node: name,
PeerAddress: peerURL,
}
initialCluster, err := e.syncEtcdConfig(ctx, etcdRequest, etcdCaCert, etcdCaCertKey)
if err != nil {
return fmt.Errorf("failed to sync etcd config: %w", err)
}
Expand Down
104 changes: 26 additions & 78 deletions pkg/token/joinclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,23 @@ package token
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"os"

"github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
"github.com/k0sproject/k0s/pkg/kubernetes"

"k8s.io/client-go/dynamic"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/client-go/tools/clientcmd/api"
)

// JoinClient is the client we can use to call k0s join APIs
type JoinClient struct {
joinAddress string
httpClient http.Client
bearerToken string
joinTokenType string
joinAddress string
restClient *rest.RESTClient
}

// JoinClientFromToken creates a new join api client from a token
Expand All @@ -46,35 +45,27 @@ func JoinClientFromToken(encodedToken string) (*JoinClient, error) {
return nil, fmt.Errorf("failed to decode token: %w", err)
}

clientConfig, err := clientcmd.NewClientConfigFromBytes(tokenBytes)
kubeconfig, err := clientcmd.Load(tokenBytes)
if err != nil {
return nil, err
}
config, err := clientConfig.ClientConfig()

restConfig, err := kubernetes.ClientConfig(func() (*api.Config, error) { return kubeconfig, nil })
if err != nil {
return nil, err
}

raw, err := clientConfig.RawConfig()
restConfig = dynamic.ConfigFor(restConfig)
restClient, err := rest.UnversionedRESTClientFor(restConfig)
if err != nil {
return nil, err
}

ca := x509.NewCertPool()
ca.AppendCertsFromPEM(config.CAData)
tlsConfig := &tls.Config{
InsecureSkipVerify: false,
RootCAs: ca,
}
tr := &http.Transport{TLSClientConfig: tlsConfig}
c := &JoinClient{
httpClient: http.Client{Transport: tr},
bearerToken: config.BearerToken,
}
c.joinAddress = config.Host
c.joinTokenType = GetTokenType(&raw)

return c, nil
return &JoinClient{
joinAddress: restConfig.Host,
joinTokenType: GetTokenType(kubeconfig),
restClient: restClient,
}, nil
}

func (j *JoinClient) Address() string {
Expand All @@ -88,71 +79,28 @@ func (j *JoinClient) JoinTokenType() string {
// GetCA calls the CA sync API
func (j *JoinClient) GetCA(ctx context.Context) (v1beta1.CaResponse, error) {
var caData v1beta1.CaResponse
req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.joinAddress+"/v1beta1/ca", nil)
if err != nil {
return caData, fmt.Errorf("failed to create join request: %w", err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken))

resp, err := j.httpClient.Do(req)
if err != nil {
return caData, err
b, err := j.restClient.Get().AbsPath("v1beta1", "ca").Do(ctx).Raw()
if err == nil {
err = json.Unmarshal(b, &caData)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return caData, fmt.Errorf("unexpected response status: %s", resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return caData, err
}
err = json.Unmarshal(b, &caData)
if err != nil {
return caData, err
}
return caData, nil
return caData, err
}

// JoinEtcd calls the etcd join API
func (j *JoinClient) JoinEtcd(ctx context.Context, peerAddress string) (v1beta1.EtcdResponse, error) {
func (j *JoinClient) JoinEtcd(ctx context.Context, etcdRequest v1beta1.EtcdRequest) (v1beta1.EtcdResponse, error) {
var etcdResponse v1beta1.EtcdResponse
etcdRequest := v1beta1.EtcdRequest{
PeerAddress: peerAddress,
}
name, err := os.Hostname()
if err != nil {
return etcdResponse, err
}
etcdRequest.Node = name

buf := new(bytes.Buffer)
if err := json.NewEncoder(buf).Encode(etcdRequest); err != nil {
return etcdResponse, err
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf)
if err != nil {
return etcdResponse, fmt.Errorf("failed to create join request: %w", err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken))
resp, err := j.httpClient.Do(req)
if err != nil {
return etcdResponse, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return etcdResponse, fmt.Errorf("unexpected response status when trying to join etcd cluster: %s", resp.Status)
}

b, err := io.ReadAll(resp.Body)
if err != nil {
return etcdResponse, err
}
err = json.Unmarshal(b, &etcdResponse)
if err != nil {
return etcdResponse, err
b, err := j.restClient.Post().AbsPath("v1beta1", "etcd", "members").Body(buf).Do(ctx).Raw()
if err == nil {
err = json.Unmarshal(b, &etcdResponse)
}

return etcdResponse, nil
return etcdResponse, err
}
108 changes: 99 additions & 9 deletions pkg/token/joinclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,86 @@ package token_test
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"testing"

k0sv1beta1 "github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
"github.com/k0sproject/k0s/pkg/token"

"github.com/cloudflare/cfssl/csr"
"github.com/cloudflare/cfssl/initca"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestJoinClient_GetCA(t *testing.T) {
t.Parallel()

joinURL, certData := startFakeJoinServer(t, func(res http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/some/sub/path/v1beta1/ca", req.RequestURI)
assert.Equal(t, []string{"Bearer the-token"}, req.Header["Authorization"])
_, err := res.Write([]byte("{}"))
assert.NoError(t, err)
})

joinURL.Path = "/some/sub/path"
kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), certData, t.Name(), "the-token")
require.NoError(t, err)
tok, err := token.JoinEncode(bytes.NewReader(kubeconfig))
require.NoError(t, err)

underTest, err := token.JoinClientFromToken(tok)
require.NoError(t, err)

response, err := underTest.GetCA(context.TODO())
assert.NoError(t, err)
assert.Zero(t, response)
}

func TestJoinClient_JoinEtcd(t *testing.T) {
t.Parallel()

joinURL, certData := startFakeJoinServer(t, func(res http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/some/sub/path/v1beta1/etcd/members", req.RequestURI)
assert.Equal(t, []string{"Bearer the-token"}, req.Header["Authorization"])

if body, err := io.ReadAll(req.Body); assert.NoError(t, err) {
var data map[string]string
if assert.NoError(t, json.Unmarshal(body, &data)) {
assert.Equal(t, map[string]string{
"node": "the-node",
"peerAddress": "the-peer-address",
}, data)
}
}

_, err := res.Write([]byte("{}"))
assert.NoError(t, err)
})

joinURL.Path = "/some/sub/path"
kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), certData, t.Name(), "the-token")
require.NoError(t, err)
tok, err := token.JoinEncode(bytes.NewReader(kubeconfig))
require.NoError(t, err)

underTest, err := token.JoinClientFromToken(tok)
require.NoError(t, err)

response, err := underTest.JoinEtcd(context.TODO(), k0sv1beta1.EtcdRequest{
Node: "the-node",
PeerAddress: "the-peer-address",
})
assert.NoError(t, err)
assert.Zero(t, response)
}

func TestJoinClient_Cancellation(t *testing.T) {
t.Parallel()

Expand All @@ -42,7 +111,7 @@ func TestJoinClient_Cancellation(t *testing.T) {
return err
}},
{"JoinEtcd", func(ctx context.Context, c *token.JoinClient) error {
_, err := c.JoinEtcd(ctx, "")
_, err := c.JoinEtcd(ctx, k0sv1beta1.EtcdRequest{})
return err
}},
} {
Expand All @@ -51,12 +120,12 @@ func TestJoinClient_Cancellation(t *testing.T) {
t.Parallel()

clientContext, cancelClientContext := context.WithCancelCause(context.Background())
joinURL := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) {
joinURL, certData := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) {
cancelClientContext(assert.AnError) // cancel the client's context
<-req.Context().Done() // block forever
})

kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), nil, "", "")
kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), certData, "", "")
require.NoError(t, err)
tok, err := token.JoinEncode(bytes.NewReader(kubeconfig))
require.NoError(t, err)
Expand All @@ -71,23 +140,44 @@ func TestJoinClient_Cancellation(t *testing.T) {
}
}

func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *url.URL {
func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*url.URL, []byte) {
requestCtx, cancelRequests := context.WithCancel(context.Background())
var ok bool

listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
require.NoError(t, err)
}
defer func() {
if !ok {
assert.NoError(t, listener.Close())
}
}()

addr := listener.Addr().(*net.TCPAddr)
certData, _, keyData, err := initca.New(&csr.CertificateRequest{
KeyRequest: csr.NewKeyRequest(),
CN: fmt.Sprintf("localhost:%d", addr.Port),
Hosts: []string{addr.IP.String()},
})
if !assert.NoError(t, err) {
assert.NoError(t, listener.Close())
t.FailNow()
}
cert, err := tls.X509KeyPair(certData, keyData)
require.NoError(t, err)

server := &http.Server{
Addr: listener.Addr().String(),
Addr: addr.String(),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
Handler: http.HandlerFunc(handler),
BaseContext: func(net.Listener) context.Context { return requestCtx },
}

serverError := make(chan error)
go func() { defer close(serverError); serverError <- server.Serve(listener) }()

ok = true
go func() { defer close(serverError); serverError <- server.ServeTLS(listener, "", "") }()
t.Cleanup(func() {
cancelRequests()
if !assert.NoError(t, server.Shutdown(context.Background()), "Couldn't shutdown HTTP server") {
Expand All @@ -96,5 +186,5 @@ func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.R
assert.ErrorIs(t, <-serverError, http.ErrServerClosed, "HTTP server terminated unexpectedly")
})

return &url.URL{Scheme: "http", Host: server.Addr}
return &url.URL{Scheme: "https", Host: server.Addr}, certData
}

0 comments on commit bc88876

Please sign in to comment.