Skip to content

Commit

Permalink
Add PSK support
Browse files Browse the repository at this point in the history
  • Loading branch information
janvrska committed May 7, 2024
1 parent 677ce13 commit 5597414
Show file tree
Hide file tree
Showing 8 changed files with 490 additions and 148 deletions.
61 changes: 44 additions & 17 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// Use ClientConfig struct to set various client's options and features.
//
// Example:
//
// import (
// "fmt"
// "time"
Expand Down Expand Up @@ -78,10 +79,16 @@ import (
type ClientConfig struct {
// UseDTLS controls whether DTLS should be used to secure the connection
// to the MQTT-SN gateway.
UseDTLS bool
Certificate *tls.Certificate
PrivateKey crypto.PrivateKey
CACertificates []*x509.Certificate
UseDTLS bool
Certificate *tls.Certificate
PrivateKey crypto.PrivateKey
UsePSK bool
PSK map[string][]byte
PSKIdentity string
PSKApiBasicAuthUsername string
PSKApiBasicAuthPassword string
PSKApiEndpoint string
CACertificates []*x509.Certificate
// SelfSigned controls whether the client should use a self-signed
// certificate and key. If SelfSigned is false and UseDTLS is true, you
// must provide CertFile and KeyFile.
Expand Down Expand Up @@ -147,21 +154,24 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err
var certificate *tls.Certificate
var err error

if c.cfg.SelfSigned {
var cert tls.Certificate
cert, err = selfsign.GenerateSelfSigned()
certificate = &cert
} else {
privateKey := c.cfg.PrivateKey
if privateKey == nil {
err = errors.New("private key is missing")
}
if certificate = c.cfg.Certificate; certificate != nil {
certificate.PrivateKey = privateKey
if !c.cfg.UsePSK && c.cfg.UseDTLS {
if c.cfg.SelfSigned {
var cert tls.Certificate
cert, err = selfsign.GenerateSelfSigned()
certificate = &cert
} else {
err = errors.New("TLS certificate is missing")
privateKey := c.cfg.PrivateKey
if privateKey == nil {
err = errors.New("private key is missing")
}
if certificate = c.cfg.Certificate; certificate != nil {
certificate.PrivateKey = privateKey
} else {
err = errors.New("TLS certificate is missing")
}
}
}

if err != nil {
return nil, err
}
Expand All @@ -182,12 +192,29 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
Certificates: []tls.Certificate{*certificate},
InsecureSkipVerify: c.cfg.Insecure,
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
RootCAs: certPool,
}

if !c.cfg.UsePSK && c.cfg.UseDTLS && certificate != nil {
config.Certificates = []tls.Certificate{*certificate}
}

if c.cfg.UsePSK && c.cfg.UseDTLS {
c.log.Info("Using PSK with DTLS")
config.CipherSuites = []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256}
config.PSK = func(hint []byte) ([]byte, error) {
psk, ok := c.cfg.PSK[string(hint)]
if !ok {
return nil, errors.New("PSK not found")
}

return psk, nil
}
config.PSKIdentityHint = []byte(c.cfg.PSKIdentity)
}

// Connect to a DTLS server
ctx2, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
Expand Down
88 changes: 72 additions & 16 deletions cmd/bisquitt-pub/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"os/signal"
"path"
Expand Down Expand Up @@ -33,13 +36,18 @@ func handleAction() cli.ActionFunc {

useDTLS := c.Bool(DtlsFlag)
useSelfSigned := c.Bool(SelfSignedFlag)
usePSK := c.Bool(PskFlag)
pskIdentity := c.String(PskIdentityFlag)
pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag)
pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag)
pskApiEndpoint := c.String(PskApiEndpointFlag)
certFile := c.Path(CertFlag)
keyFile := c.Path(KeyFlag)
caFile := c.Path(CAFileFlag)
caPath := c.Path(CAPathFlag)
debug := c.Bool(DebugFlag)

if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned {
if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK {
return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`,
CertFlag, KeyFlag, SelfSignedFlag)
}
Expand Down Expand Up @@ -132,21 +140,26 @@ func handleAction() cli.ActionFunc {
password := []byte(c.String(PasswordFlag))

clientCfg := &snClient.ClientConfig{
ClientID: clientID,
UseDTLS: useDTLS,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
PredefinedTopics: predefinedTopics,
User: user,
Password: password,
ClientID: clientID,
UseDTLS: useDTLS,
UsePSK: usePSK,
PSKIdentity: pskIdentity,
PSKApiBasicAuthUsername: pskApiBasicAuthUsername,
PSKApiBasicAuthPassword: pskApiBasicAuthPassword,
PSKApiEndpoint: pskApiEndpoint,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
PredefinedTopics: predefinedTopics,
User: user,
Password: password,
}

var logger util.Logger
Expand All @@ -157,6 +170,10 @@ func handleAction() cli.ActionFunc {
}
defer logger.Sync()

if usePSK && useDTLS {
loadPsk(clientCfg, logger)
}

client := snClient.NewClient(logger, clientCfg)

signalCh := make(chan os.Signal, 1)
Expand Down Expand Up @@ -208,3 +225,42 @@ func handleAction() cli.ActionFunc {
return nil
}
}

func loadPsk(snClientConfig *snClient.ClientConfig, logger util.Logger) {
req, err := http.NewRequest("GET", snClientConfig.PSKApiEndpoint, nil)
if err != nil {
logger.Error("Error in creating request: %s", err)
return
}

req.SetBasicAuth(snClientConfig.PSKApiBasicAuthUsername, snClientConfig.PSKApiBasicAuthPassword)
client := &http.Client{}

resp, err := client.Do(req)
if err != nil {
logger.Error("Error in sending request: %s", err)
return
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Error("Error in reading response body: %s", err)
return
}

type Response struct {
Data map[string][]byte `json:"data"`
}

var response Response

err = json.Unmarshal(body, &response)
if err != nil {
logger.Error("Error in unmarshalling response body: %s", err)
return
}

snClientConfig.PSK = response.Data
}
78 changes: 59 additions & 19 deletions cmd/bisquitt-pub/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,30 @@ import (
)

const (
HostFlag = "host"
PortFlag = "port"
DtlsFlag = "dtls"
SelfSignedFlag = "self-signed"
CertFlag = "cert"
KeyFlag = "key"
CAFileFlag = "cafile"
CAPathFlag = "capath"
InsecureFlag = "insecure"
DebugFlag = "debug"
TopicFlag = "topic"
MessageFlag = "message"
RetainFlag = "retain"
PredefinedTopicFlag = "predefined-topic"
PredefinedTopicsFileFlag = "predefined-topics-file"
QOSFlag = "qos"
ClientIDFlag = "client-id"
UserFlag = "user"
PasswordFlag = "password"
HostFlag = "host"
PortFlag = "port"
DtlsFlag = "dtls"
SelfSignedFlag = "self-signed"
PskFlag = "psk"
PskIdentityFlag = "psk-identity"
PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username"
PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password"
PskApiEndpointFlag = "psk-api-endpoint"
CertFlag = "cert"
KeyFlag = "key"
CAFileFlag = "cafile"
CAPathFlag = "capath"
InsecureFlag = "insecure"
DebugFlag = "debug"
TopicFlag = "topic"
MessageFlag = "message"
RetainFlag = "retain"
PredefinedTopicFlag = "predefined-topic"
PredefinedTopicsFileFlag = "predefined-topics-file"
QOSFlag = "qos"
ClientIDFlag = "client-id"
UserFlag = "user"
PasswordFlag = "password"
)

func init() {
Expand Down Expand Up @@ -76,6 +81,41 @@ var Application = cli.App{
"SELF_SIGNED",
},
},
&cli.BoolFlag{
Name: PskFlag,
Usage: "use PSK",
EnvVars: []string{
"PSK_ENABLED",
},
},
&cli.StringFlag{
Name: PskIdentityFlag,
Usage: "PSK identity",
EnvVars: []string{
"PSK_IDENTITY",
},
},
&cli.StringFlag{
Name: PskApiBasicAuthUsernameFlag,
Usage: "PSK API basic auth username",
EnvVars: []string{
"PSK_API_BASIC_AUTH_USERNAME",
},
},
&cli.StringFlag{
Name: PskApiBasicAuthPasswordFlag,
Usage: "PSK API basic auth password",
EnvVars: []string{
"PSK_API_BASIC_AUTH_PASSWORD",
},
},
&cli.StringFlag{
Name: PskApiEndpointFlag,
Usage: "PSK API endpoint",
EnvVars: []string{
"PSK_API_ENDPOINT",
},
},
&cli.PathFlag{
Name: CertFlag,
Usage: "DTLS certificate file",
Expand Down
Loading

0 comments on commit 5597414

Please sign in to comment.