Skip to content

Commit

Permalink
Add PSK support
Browse files Browse the repository at this point in the history
  • Loading branch information
janvrska committed May 6, 2024
1 parent 677ce13 commit 6856b44
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 119 deletions.
27 changes: 23 additions & 4 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// Use ClientConfig struct to set various client's options and features.
//
// Example:
//
// import (
// "fmt"
// "time"
Expand Down Expand Up @@ -78,10 +79,16 @@ import (
type ClientConfig struct {
// UseDTLS controls whether DTLS should be used to secure the connection
// to the MQTT-SN gateway.
UseDTLS bool
Certificate *tls.Certificate
PrivateKey crypto.PrivateKey
CACertificates []*x509.Certificate
UseDTLS bool
Certificate *tls.Certificate
PrivateKey crypto.PrivateKey
UsePSK bool
PSK map[string][]byte
PSKIdentity string
PSKApiBasicAuthUsername string
PSKApiBasicAuthPassword string
PSKApiEndpoint string
CACertificates []*x509.Certificate
// SelfSigned controls whether the client should use a self-signed
// certificate and key. If SelfSigned is false and UseDTLS is true, you
// must provide CertFile and KeyFile.
Expand Down Expand Up @@ -188,6 +195,18 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err
RootCAs: certPool,
}

if c.cfg.UsePSK && c.cfg.UseDTLS {
config.PSK = func(hint []byte) ([]byte, error) {
psk, ok := c.cfg.PSK[string(hint)]
if !ok {
return nil, errors.New("PSK not found")
}

return psk, nil
}
config.PSKIdentityHint = []byte(c.cfg.PSKIdentity)
}

// Connect to a DTLS server
ctx2, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
Expand Down
82 changes: 67 additions & 15 deletions cmd/bisquitt-pub/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"os/signal"
"path"
Expand Down Expand Up @@ -33,6 +36,11 @@ func handleAction() cli.ActionFunc {

useDTLS := c.Bool(DtlsFlag)
useSelfSigned := c.Bool(SelfSignedFlag)
usePSK := c.Bool(PskFlag)
pskIdentity := c.String(PskIdentityFlag)
pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag)
pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag)
pskApiEndpoint := c.String(PskApiEndpointFlag)
certFile := c.Path(CertFlag)
keyFile := c.Path(KeyFlag)
caFile := c.Path(CAFileFlag)
Expand Down Expand Up @@ -132,21 +140,26 @@ func handleAction() cli.ActionFunc {
password := []byte(c.String(PasswordFlag))

clientCfg := &snClient.ClientConfig{
ClientID: clientID,
UseDTLS: useDTLS,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
PredefinedTopics: predefinedTopics,
User: user,
Password: password,
ClientID: clientID,
UseDTLS: useDTLS,
UsePSK: usePSK,
PSKIdentity: pskIdentity,
PSKApiBasicAuthUsername: pskApiBasicAuthUsername,
PSKApiBasicAuthPassword: pskApiBasicAuthPassword,
PSKApiEndpoint: pskApiEndpoint,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
PredefinedTopics: predefinedTopics,
User: user,
Password: password,
}

var logger util.Logger
Expand All @@ -157,6 +170,10 @@ func handleAction() cli.ActionFunc {
}
defer logger.Sync()

if usePSK && useDTLS {
go loadPsk(clientCfg, logger)
}

client := snClient.NewClient(logger, clientCfg)

signalCh := make(chan os.Signal, 1)
Expand Down Expand Up @@ -208,3 +225,38 @@ func handleAction() cli.ActionFunc {
return nil
}
}

func loadPsk(snClientConfig *snClient.ClientConfig, logger util.Logger) {
req, err := http.NewRequest("GET", snClientConfig.PSKApiEndpoint, nil)
if err != nil {
logger.Error("Error in creating request: %s", err)
}

req.SetBasicAuth(snClientConfig.PSKApiBasicAuthUsername, snClientConfig.PSKApiBasicAuthPassword)
client := &http.Client{}

resp, err := client.Do(req)
if err != nil {
logger.Error("Error in sending request: %s", err)
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Error("Error in reading response body: %s", err)
}

type Response struct {
Data map[string][]byte `json:"data"`
}

var response Response

err = json.Unmarshal(body, &response)
if err != nil {
logger.Error("Error in unmarshalling response body: %s", err)
}

snClientConfig.PSK = response.Data
}
78 changes: 59 additions & 19 deletions cmd/bisquitt-pub/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,30 @@ import (
)

const (
HostFlag = "host"
PortFlag = "port"
DtlsFlag = "dtls"
SelfSignedFlag = "self-signed"
CertFlag = "cert"
KeyFlag = "key"
CAFileFlag = "cafile"
CAPathFlag = "capath"
InsecureFlag = "insecure"
DebugFlag = "debug"
TopicFlag = "topic"
MessageFlag = "message"
RetainFlag = "retain"
PredefinedTopicFlag = "predefined-topic"
PredefinedTopicsFileFlag = "predefined-topics-file"
QOSFlag = "qos"
ClientIDFlag = "client-id"
UserFlag = "user"
PasswordFlag = "password"
HostFlag = "host"
PortFlag = "port"
DtlsFlag = "dtls"
SelfSignedFlag = "self-signed"
PskFlag = "psk"
PskIdentityFlag = "psk-identity"
PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username"
PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password"
PskApiEndpointFlag = "psk-api-endpoint"
CertFlag = "cert"
KeyFlag = "key"
CAFileFlag = "cafile"
CAPathFlag = "capath"
InsecureFlag = "insecure"
DebugFlag = "debug"
TopicFlag = "topic"
MessageFlag = "message"
RetainFlag = "retain"
PredefinedTopicFlag = "predefined-topic"
PredefinedTopicsFileFlag = "predefined-topics-file"
QOSFlag = "qos"
ClientIDFlag = "client-id"
UserFlag = "user"
PasswordFlag = "password"
)

func init() {
Expand Down Expand Up @@ -76,6 +81,41 @@ var Application = cli.App{
"SELF_SIGNED",
},
},
&cli.BoolFlag{
Name: PskFlag,
Usage: "use PSK",
EnvVars: []string{
"PSK_ENABLED",
},
},
&cli.StringFlag{
Name: PskIdentityFlag,
Usage: "PSK identity",
EnvVars: []string{
"PSK_IDENTITY",
},
},
&cli.StringFlag{
Name: PskApiBasicAuthUsernameFlag,
Usage: "PSK API basic auth username",
EnvVars: []string{
"PSK_API_BASIC_AUTH_USERNAME",
},
},
&cli.StringFlag{
Name: PskApiBasicAuthPasswordFlag,
Usage: "PSK API basic auth password",
EnvVars: []string{
"PSK_API_BASIC_AUTH_PASSWORD",
},
},
&cli.StringFlag{
Name: PskApiEndpointFlag,
Usage: "PSK API endpoint",
EnvVars: []string{
"PSK_API_ENDPOINT",
},
},
&cli.PathFlag{
Name: CertFlag,
Usage: "DTLS certificate file",
Expand Down
82 changes: 67 additions & 15 deletions cmd/bisquitt-sub/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"os/signal"
"path"
Expand Down Expand Up @@ -34,6 +37,11 @@ func handleAction() cli.ActionFunc {

useDTLS := c.Bool(DtlsFlag)
useSelfSigned := c.Bool(SelfSignedFlag)
usePSK := c.Bool(PskFlag)
pskIdentity := c.String(PskIdentityFlag)
pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag)
pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag)
pskApiEndpoint := c.String(PskApiEndpointFlag)
certFile := c.Path(CertFlag)
keyFile := c.Path(KeyFlag)
caFile := c.Path(CAFileFlag)
Expand Down Expand Up @@ -133,21 +141,26 @@ func handleAction() cli.ActionFunc {
password := []byte(c.String(PasswordFlag))

clientCfg := &snClient.ClientConfig{
ClientID: clientID,
UseDTLS: useDTLS,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
PredefinedTopics: predefinedTopics,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
User: user,
Password: password,
ClientID: clientID,
UseDTLS: useDTLS,
SelfSigned: useSelfSigned,
UsePSK: usePSK,
PSKIdentity: pskIdentity,
PSKApiBasicAuthUsername: pskApiBasicAuthUsername,
PSKApiBasicAuthPassword: pskApiBasicAuthPassword,
PSKApiEndpoint: pskApiEndpoint,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
PredefinedTopics: predefinedTopics,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
User: user,
Password: password,
}

if c.IsSet(WillTopicFlag) {
Expand All @@ -174,6 +187,10 @@ func handleAction() cli.ActionFunc {
}
defer logger.Sync()

if usePSK && useDTLS {
go loadPsk(clientCfg, logger)
}

client := snClient.NewClient(logger, clientCfg)

signalCh := make(chan os.Signal, 1)
Expand Down Expand Up @@ -238,3 +255,38 @@ func handleAction() cli.ActionFunc {
return nil
}
}

func loadPsk(gwConfig *snClient.ClientConfig, logger util.Logger) {
req, err := http.NewRequest("GET", gwConfig.PSKApiEndpoint, nil)
if err != nil {
logger.Error("Error in creating request: %s", err)
}

req.SetBasicAuth(gwConfig.PSKApiBasicAuthUsername, gwConfig.PSKApiBasicAuthPassword)
client := &http.Client{}
resp, err := client.Do(req)

if err != nil {
logger.Error("Error in sending request: %s", err)
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Error("Error in reading response body: %s", err)
}

type Response struct {
Data map[string][]byte `json:"data"`
}

var response Response

err = json.Unmarshal(body, &response)
if err != nil {
logger.Error("Error in unmarshalling response body: %s", err)
}

gwConfig.PSK = response.Data
}
Loading

0 comments on commit 6856b44

Please sign in to comment.