Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure joining nodes are immediately trusted #89

Closed
wants to merge 9 commits into from
134 changes: 82 additions & 52 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,6 @@ func (d *Daemon) reloadIfBootstrapped() error {
return err
}

err = d.setDaemonConfig(nil)
if err != nil {
return fmt.Errorf("Failed to retrieve daemon configuration yaml: %w", err)
}

err = d.StartAPI(false, nil, nil)
if err != nil {
return err
Expand Down Expand Up @@ -315,11 +310,9 @@ func (d *Daemon) initServer(resources ...*resources.Resources) *http.Server {
// StartAPI starts up the admin and consumer APIs, and generates a cluster cert
// if we are bootstrapping the first node.
func (d *Daemon) StartAPI(bootstrap bool, initConfig map[string]string, newConfig *trust.Location, joinAddresses ...string) error {
if newConfig != nil {
err := d.setDaemonConfig(newConfig)
if err != nil {
return fmt.Errorf("Failed to apply and save new daemon configuration: %w", err)
}
_, err := d.setDaemonConfig(newConfig)
if err != nil {
return fmt.Errorf("Failed to apply and save new daemon configuration: %w", err)
}

if bootstrap {
Expand Down Expand Up @@ -420,24 +413,44 @@ func (d *Daemon) StartAPI(bootstrap bool, initConfig map[string]string, newConfi
}

// Get a client for every other cluster member in the newly refreshed local store.
cluster := make(client.Cluster, 0, d.trustStore.Remotes().Count()-1)
for _, addr := range d.trustStore.Remotes().Addresses() {
if d.address.URL.Host == addr.String() {
continue
}
publicKey, err := d.ClusterCert().PublicKeyX509()
if err != nil {
return err
}

publicKey, err := d.ClusterCert().PublicKeyX509()
if err != nil {
return err
cluster, err := d.trustStore.Remotes().Cluster(false, d.ServerCert(), publicKey)
if err != nil {
return err
}

localMemberInfo := internalTypes.ClusterMemberLocal{Name: localNode.Name, Address: localNode.Address, Certificate: localNode.Certificate}
var clusterConfirmation bool
var lastErr error
err = cluster.Query(d.ShutdownCtx, false, func(ctx context.Context, c *client.Client) error {
// No need to send a request to ourselves.
if d.address.URL.Host == c.URL().URL.Host {
return nil
}

url := api.NewURL().Scheme("https").Host(addr.String())
c, err := internalClient.New(*url, d.ServerCert(), publicKey, true)
if err != nil {
return err
// At this point the joiner is only trusted on the node that was leader at the time,
// so find it and have it instruct all dqlite members to trust this system now that it is functional.
if !clusterConfirmation {
err = c.RegisterClusterMember(ctx, internalTypes.ClusterMember{ClusterMemberLocal: localMemberInfo})
if err != nil {
lastErr = err
} else {
clusterConfirmation = true
}
}

cluster = append(cluster, client.Client{Client: *c})
return nil
})
if err != nil {
return err
}

if !clusterConfirmation {
return fmt.Errorf("Failed to confirm new member on any existing system: %w", lastErr)
}

if len(joinAddresses) > 0 {
Expand All @@ -447,38 +460,23 @@ func (d *Daemon) StartAPI(bootstrap bool, initConfig map[string]string, newConfi
}
}

// Send notification that this node is upgraded to all other cluster members.
// Tell the other nodes that this system is up.
err = cluster.Query(d.ShutdownCtx, true, func(ctx context.Context, c *client.Client) error {
path := c.URL()
parts := strings.Split(string(internalClient.InternalEndpoint), "/")
parts = append(parts, "database")
path = *path.Path(parts...)
upgradeRequest, err := http.NewRequest("PATCH", path.String(), nil)
if err != nil {
return err
}

upgradeRequest.Header.Set("X-Dqlite-Version", fmt.Sprintf("%d", 1))
upgradeRequest = upgradeRequest.WithContext(ctx)
c.SetClusterNotification()

resp, err := c.Client.Do(upgradeRequest)
if err != nil {
logger.Error("Failed to send database upgrade request", logger.Ctx{"error": err})
// No need to send a request to ourselves.
if d.address.URL.Host == c.URL().URL.Host {
return nil
}

defer resp.Body.Close()
_, err = io.Copy(io.Discard, resp.Body)
// Send notification about this node's dqlite version to all other cluster members.
err = d.sendUpgradeNotification(ctx, c)
if err != nil {
logger.Error("Failed to read upgrade notification response body", logger.Ctx{"error": err})
}

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Database upgrade notification failed: %s", resp.Status)
return err
}

// If this was a join request, instruct all peers to run their OnNewMember hook.
if len(joinAddresses) > 0 {
localMemberInfo := internalTypes.ClusterMemberLocal{Name: localNode.Name, Address: localNode.Address, Certificate: localNode.Certificate}
_, err = c.AddClusterMember(ctx, internalTypes.ClusterMember{ClusterMemberLocal: localMemberInfo})
if err != nil {
return err
Expand All @@ -498,6 +496,38 @@ func (d *Daemon) StartAPI(bootstrap bool, initConfig map[string]string, newConfi
return nil
}

// sendUpgradeNotification sends a notification about this node's dqlite version to the cluster member behind
// the given client, by issuing a PATCH request to the "database" path under the internal endpoint.
//
// A failure to send the request is only logged and nil is returned. A failure to drain the response body is
// logged as well. An error is returned if the request cannot be constructed, or if the peer answers with a
// status other than 200 OK.
func (d *Daemon) sendUpgradeNotification(ctx context.Context, c *client.Client) error {
	path := c.URL()
	parts := strings.Split(string(internalClient.InternalEndpoint), "/")
	parts = append(parts, "database")
	path = *path.Path(parts...)

	// Bind the context when the request is built instead of copying the request afterwards.
	upgradeRequest, err := http.NewRequestWithContext(ctx, http.MethodPatch, path.String(), nil)
	if err != nil {
		return err
	}

	upgradeRequest.Header.Set("X-Dqlite-Version", "1")

	resp, err := c.Client.Do(upgradeRequest)
	if err != nil {
		// Deliberately best-effort: the failure is recorded but not propagated to the caller.
		logger.Error("Failed to send database upgrade request", logger.Ctx{"error": err})
		return nil
	}

	defer resp.Body.Close()

	// Drain the body before it is closed.
	_, err = io.Copy(io.Discard, resp.Body)
	if err != nil {
		logger.Error("Failed to read upgrade notification response body", logger.Ctx{"error": err})
	}

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("Database upgrade notification failed: %s", resp.Status)
	}

	return nil
}

// ClusterCert ensures both the daemon and state have the same cluster cert.
func (d *Daemon) ClusterCert() *shared.CertInfo {
d.clusterMu.RLock()
Expand Down Expand Up @@ -587,32 +617,32 @@ func (d *Daemon) Stop() error {

// setDaemonConfig sets the daemon's address and name from the given location information. If none is supplied, the file
// at `state-dir/daemon.yaml` will be read for the information.
// When location information is supplied, it is also written to `state-dir/daemon.yaml`.
// With the two-value signature, the location that was applied is returned on success, and nil on error.
func (d *Daemon) setDaemonConfig(config *trust.Location) error {
func (d *Daemon) setDaemonConfig(config *trust.Location) (*trust.Location, error) {
if config != nil {
// A location was supplied: serialize it and persist it to the state directory.
bytes, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("Failed to parse daemon config to yaml: %w", err)
return nil, fmt.Errorf("Failed to parse daemon config to yaml: %w", err)
}

err = os.WriteFile(filepath.Join(d.os.StateDir, "daemon.yaml"), bytes, 0644)
if err != nil {
return fmt.Errorf("Failed to write daemon configuration yaml: %w", err)
return nil, fmt.Errorf("Failed to write daemon configuration yaml: %w", err)
}
} else {
// No location was supplied: load it from the file in the state directory.
data, err := os.ReadFile(filepath.Join(d.os.StateDir, "daemon.yaml"))
if err != nil {
return fmt.Errorf("Failed to find daemon configuration: %w", err)
return nil, fmt.Errorf("Failed to find daemon configuration: %w", err)
}

config = &trust.Location{}
err = yaml.Unmarshal(data, config)
if err != nil {
return fmt.Errorf("Failed to parse daemon config from yaml: %w", err)
return nil, fmt.Errorf("Failed to parse daemon config from yaml: %w", err)
}
}

// Apply the location to the daemon: the address is stored as an https URL for the configured host.
d.address = *api.NewURL().Scheme("https").Host(config.Address.String())
d.name = config.Name

return nil
return config, nil
}
4 changes: 4 additions & 0 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (db *DB) Open(bootstrap bool, project string) error {

// Transaction handles performing a transaction on the dqlite database.
func (db *DB) Transaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error {
if !db.IsOpen() {
return fmt.Errorf("Cannot complete transaction, the database is not running")
}

return db.retry(func() error {
err := query.Transaction(ctx, db.db, f)
if errors.Is(err, context.DeadlineExceeded) {
Expand Down
66 changes: 66 additions & 0 deletions internal/rest/access/authentication.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package access

import (
"context"
"crypto/x509"
"fmt"
"net/http"

"github.com/canonical/lxd/lxd/request"
"github.com/canonical/lxd/lxd/util"
"github.com/canonical/lxd/shared/logger"

"github.com/canonical/microcluster/internal/state"
)

// TrustedRequest describes the level of trust established for a request.
type TrustedRequest struct {
	// Trusted is true when the request is to be treated as trusted.
	Trusted bool
}

// SetRequestAuthentication sets the trusted status for the request. A trusted request will be treated as having come from a trusted system.
func SetRequestAuthentication(r *http.Request, trusted bool) *http.Request {
r = r.WithContext(context.WithValue(r.Context(), any(request.CtxAccess), TrustedRequest{Trusted: trusted}))

return r
}

// Authenticate reports whether the request is trusted.
//   - Requests over the unix socket (remote address "@") are always allowed.
//   - While the system has no address set (un-initialized), every request is allowed.
//   - Otherwise the request must be addressed to this system's address, and one of its TLS peer certificates
//     must be trusted against the truststore remotes' certificates and still map back to a known remote.
//
// An error is returned only when the request's Host does not match this system's address.
func Authenticate(state *state.State, r *http.Request) (bool, error) {
	// The unix socket needs no further checks.
	if r.RemoteAddr == "@" {
		return true, nil
	}

	if state.Address().URL.Host == "" {
		logger.Info("Allowing unauthenticated request to un-initialized system")
		return true, nil
	}

	// Only requests addressed to our own address are considered.
	if r.Host != state.Address().URL.Host {
		return false, fmt.Errorf("Invalid request address %q", r.Host)
	}

	var trustedCerts map[string]x509.Certificate = state.Remotes().CertificatesNative()

	// Without TLS there are no peer certificates to check.
	if r.TLS == nil {
		return false, nil
	}

	for _, peerCert := range r.TLS.PeerCertificates {
		trusted, fingerprint := util.CheckTrustState(*peerCert, trustedCerts, nil, false)
		if !trusted {
			continue
		}

		// The first trusted certificate decides the outcome. If its fingerprint can no longer be matched back
		// against what is in the truststore (e.g. file was deleted), we are no longer trusted.
		clusterRemote := state.Remotes().RemoteByCertificateFingerprint(fingerprint)

		return clusterRemote != nil, nil
	}

	return false, nil
}
22 changes: 0 additions & 22 deletions internal/rest/access/handlers.go

This file was deleted.

5 changes: 5 additions & 0 deletions internal/rest/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ func tlsHTTPClient(clientCert *shared.CertInfo, remoteCert *x509.Certificate, pr
return client, nil
}

// SetClusterNotification installs forwardingProxy as the proxy function of the client's HTTP transport, so
// that the forwarding headers are applied to requests made with this client.
// The client's transport must be an *http.Transport; the type assertion panics otherwise.
func (c *Client) SetClusterNotification() {
	transport := c.Transport.(*http.Transport)
	transport.Proxy = forwardingProxy
}

func forwardingProxy(r *http.Request) (*url.URL, error) {
r.Header.Set("User-Agent", clusterRequest.UserAgentNotifier)

Expand Down
8 changes: 8 additions & 0 deletions internal/rest/client/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ func (c *Client) AddClusterMember(ctx context.Context, args types.ClusterMember)
return &tokenResponse, nil
}

// RegisterClusterMember instructs the dqlite leader to inform all existing cluster members to update their
// local records to include a newly joined system. It issues a PUT request with the given member to the
// "cluster" path of the public endpoint, bounded by a 30 second timeout.
func (c *Client) RegisterClusterMember(ctx context.Context, args types.ClusterMember) error {
	timeoutCtx, cancelTimeout := context.WithTimeout(ctx, 30*time.Second)
	defer cancelTimeout()

	clusterPath := api.NewURL().Path("cluster")

	return c.QueryStruct(timeoutCtx, "PUT", PublicEndpoint, clusterPath, args, nil)
}

// GetClusterMembers returns the database record of cluster members.
func (c *Client) GetClusterMembers(ctx context.Context) ([]types.ClusterMember, error) {
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
Expand Down
3 changes: 1 addition & 2 deletions internal/rest/resources/api_1.0.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/canonical/lxd/lxd/response"

"github.com/canonical/microcluster/internal/rest/access"
internalTypes "github.com/canonical/microcluster/internal/rest/types"
"github.com/canonical/microcluster/internal/state"
"github.com/canonical/microcluster/rest"
Expand All @@ -15,7 +14,7 @@ import (
var api10Cmd = rest.Endpoint{
AllowedBeforeInit: true,

// NOTE(review): both sides of the diff are shown here. Presumably the AccessHandler variant is the removed
// line (the internal access import is dropped in the same change) and the AllowUntrusted variant replaces it.
Get: rest.EndpointAction{Handler: api10Get, AccessHandler: access.AllowAuthenticated},
Get: rest.EndpointAction{Handler: api10Get, AllowUntrusted: true},
}

func api10Get(s *state.State, r *http.Request) response.Response {
Expand Down
2 changes: 1 addition & 1 deletion internal/rest/resources/certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
"github.com/canonical/lxd/lxd/response"

"github.com/canonical/microcluster/client"
"github.com/canonical/microcluster/internal/rest/access"
"github.com/canonical/microcluster/internal/state"
"github.com/canonical/microcluster/rest"
"github.com/canonical/microcluster/rest/access"
"github.com/canonical/microcluster/rest/types"
)

Expand Down
Loading
Loading