From 809dc7997a2b15127aae071a632539a7804cf300 Mon Sep 17 00:00:00 2001 From: janvrska <1644599+janvrska@users.noreply.github.com> Date: Mon, 6 May 2024 11:00:23 +0200 Subject: [PATCH] Add PSK support --- client/client.go | 121 ++++++++++++++++++++++++++----- cmd/bisquitt-pub/actions.go | 47 ++++++++---- cmd/bisquitt-pub/application.go | 88 +++++++++++++++++----- cmd/bisquitt-sub/actions.go | 46 ++++++++---- cmd/bisquitt-sub/application.go | 92 +++++++++++++++++------ cmd/bisquitt/actions.go | 42 +++++++---- cmd/bisquitt/application.go | 91 +++++++++++++++++------ gateway/gateway.go | 125 +++++++++++++++++++++++++++----- go.mod | 1 + go.sum | 2 + 10 files changed, 513 insertions(+), 142 deletions(-) diff --git a/client/client.go b/client/client.go index 73cba07..2a7eb90 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ // Use ClientConfig struct to set various client's options and features. // // Example: +// // import ( // "fmt" // "time" @@ -58,12 +59,16 @@ import ( "crypto" "crypto/tls" "crypto/x509" + "encoding/json" "errors" "fmt" + "io" "net" + "net/http" "sync" "time" + "github.com/patrickmn/go-cache" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "golang.org/x/sync/errgroup" @@ -78,10 +83,23 @@ import ( type ClientConfig struct { // UseDTLS controls whether DTLS should be used to secure the connection // to the MQTT-SN gateway. - UseDTLS bool - Certificate *tls.Certificate - PrivateKey crypto.PrivateKey - CACertificates []*x509.Certificate + UseDTLS bool + Certificate *tls.Certificate + PrivateKey crypto.PrivateKey + // UsePSK controls whether pre-shared key should be used to secure the + // connection to the MQTT-SN gateway. If UsePSK is true, you must provide + // PSKIdentity, PSKApiBasicAuthUsername, PSKApiBasicAuthPassword and + // PSKApiEndpoint. + // If UsePSK is true, the client will use PSK instead of the certificate + // and private key. + UsePSK bool + PSK *cache.Cache + PSKCacheExpiration time.Duration + PSKIdentity string + PSKApiBasicAuthUsername string + PSKApiBasicAuthPassword string + PSKApiEndpoint string + CACertificates []*x509.Certificate // SelfSigned controls whether the client should use a self-signed // certificate and key. If SelfSigned is false and UseDTLS is true, you // must provide CertFile and KeyFile. @@ -147,21 +165,24 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err var certificate *tls.Certificate var err error - if c.cfg.SelfSigned { - var cert tls.Certificate - cert, err = selfsign.GenerateSelfSigned() - certificate = &cert - } else { - privateKey := c.cfg.PrivateKey - if privateKey == nil { - err = errors.New("private key is missing") - } - if certificate = c.cfg.Certificate; certificate != nil { - certificate.PrivateKey = privateKey + if !c.cfg.UsePSK && c.cfg.UseDTLS { + if c.cfg.SelfSigned { + var cert tls.Certificate + cert, err = selfsign.GenerateSelfSigned() + certificate = &cert } else { - err = errors.New("TLS certificate is missing") + privateKey := c.cfg.PrivateKey + if privateKey == nil { + err = errors.New("private key is missing") + } + if certificate = c.cfg.Certificate; certificate != nil { + certificate.PrivateKey = privateKey + } else { + err = errors.New("TLS certificate is missing") + } } } + if err != nil { return nil, err } @@ -182,12 +203,34 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err // Prepare the configuration of the DTLS connection config := &dtls.Config{ - Certificates: []tls.Certificate{*certificate}, InsecureSkipVerify: c.cfg.Insecure, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, RootCAs: certPool, } + if !c.cfg.UsePSK && c.cfg.UseDTLS && certificate != nil { + config.Certificates = []tls.Certificate{*certificate} + } + + if c.cfg.UsePSK && c.cfg.UseDTLS { + config.CipherSuites = []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256} + config.PSK = func(hint []byte) ([]byte, error) { + psk, ok := c.cfg.PSK.Get(string(hint)) + if ok { + return psk.([]byte), nil + } + + psk, ok = c.getPSK(string(hint)) + if ok { + c.cfg.PSK.Set(string(hint), psk, c.cfg.PSKCacheExpiration) + return psk.([]byte), nil + } + + return nil, errors.New("PSK not found") + } + config.PSKIdentityHint = []byte(c.cfg.PSKIdentity) + } + // Connect to a DTLS server ctx2, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -556,3 +599,47 @@ func (c *Client) Disconnect() error { return c.group.Wait() } } + +func (c *Client) getPSK(hint string) ([]byte, bool) { + req, err := http.NewRequest("GET", fmt.Sprintf(c.cfg.PSKApiEndpoint+"/%s", hint), nil) + if err != nil { + c.log.Error("Error in creating request: %s", err) + return nil, false + } + + req.SetBasicAuth(c.cfg.PSKApiBasicAuthUsername, c.cfg.PSKApiBasicAuthPassword) + client := &http.Client{} + resp, err := client.Do(req) + + if err != nil { + c.log.Error("Error in sending request: %s", err) + return nil, false + } + + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + c.log.Debug("ID not found") + return nil, false + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + c.log.Error("Error in reading response body: %s", err) + return nil, false + } + + type Response struct { + Data map[string][]byte `json:"data"` + } + + var response Response + + err = json.Unmarshal(body, &response) + if err != nil { + c.log.Error("Error in unmarshalling response body: %s", err) + return nil, false + } + + return response.Data[hint], true +} diff --git a/cmd/bisquitt-pub/actions.go b/cmd/bisquitt-pub/actions.go index 93bcbf5..283b5fc 100644 --- a/cmd/bisquitt-pub/actions.go +++ b/cmd/bisquitt-pub/actions.go @@ -13,6 +13,7 @@ import ( "syscall" "time" + "github.com/patrickmn/go-cache" "github.com/urfave/cli/v2" snClient "github.com/energomonitor/bisquitt/client" @@ -33,13 +34,20 @@ func handleAction() cli.ActionFunc { useDTLS := c.Bool(DtlsFlag) useSelfSigned := c.Bool(SelfSignedFlag) + usePSK := c.Bool(PskFlag) + pskCacheExpiration := c.Duration(PskCacheExpirationFlag) + + pskIdentity := c.String(PskIdentityFlag) + pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag) + pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag) + pskApiEndpoint := c.String(PskApiEndpointFlag) certFile := c.Path(CertFlag) keyFile := c.Path(KeyFlag) caFile := c.Path(CAFileFlag) caPath := c.Path(CAPathFlag) debug := c.Bool(DebugFlag) - if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned { + if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK { return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`, CertFlag, KeyFlag, SelfSignedFlag) } @@ -132,21 +140,28 @@ func handleAction() cli.ActionFunc { password := []byte(c.String(PasswordFlag)) clientCfg := &snClient.ClientConfig{ - ClientID: clientID, - UseDTLS: useDTLS, - SelfSigned: useSelfSigned, - Insecure: insecure, - Certificate: certificate, - PrivateKey: privateKey, - CACertificates: caCertificates, - RetryDelay: 10 * time.Second, - RetryCount: 4, - ConnectTimeout: 20 * time.Second, - KeepAlive: 60 * time.Second, - CleanSession: true, - PredefinedTopics: predefinedTopics, - User: user, - Password: password, + ClientID: clientID, + UseDTLS: useDTLS, + UsePSK: usePSK, + PSK: cache.New(pskCacheExpiration, 5*time.Minute), + PSKCacheExpiration: pskCacheExpiration, + PSKIdentity: pskIdentity, + PSKApiBasicAuthUsername: pskApiBasicAuthUsername, + PSKApiBasicAuthPassword: pskApiBasicAuthPassword, + PSKApiEndpoint: pskApiEndpoint, + SelfSigned: useSelfSigned, + Insecure: insecure, + Certificate: certificate, + PrivateKey: privateKey, + CACertificates: caCertificates, + RetryDelay: 10 * time.Second, + RetryCount: 4, + ConnectTimeout: 20 * time.Second, + KeepAlive: 60 * time.Second, + CleanSession: true, + PredefinedTopics: predefinedTopics, + User: user, + Password: password, } var logger util.Logger diff --git a/cmd/bisquitt-pub/application.go b/cmd/bisquitt-pub/application.go index f32cb5b..b4630ca 100644 --- a/cmd/bisquitt-pub/application.go +++ b/cmd/bisquitt-pub/application.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "time" "github.com/urfave/cli/v2" @@ -9,25 +10,31 @@ import ( ) const ( - HostFlag = "host" - PortFlag = "port" - DtlsFlag = "dtls" - SelfSignedFlag = "self-signed" - CertFlag = "cert" - KeyFlag = "key" - CAFileFlag = "cafile" - CAPathFlag = "capath" - InsecureFlag = "insecure" - DebugFlag = "debug" - TopicFlag = "topic" - MessageFlag = "message" - RetainFlag = "retain" - PredefinedTopicFlag = "predefined-topic" - PredefinedTopicsFileFlag = "predefined-topics-file" - QOSFlag = "qos" - ClientIDFlag = "client-id" - UserFlag = "user" - PasswordFlag = "password" + HostFlag = "host" + PortFlag = "port" + DtlsFlag = "dtls" + SelfSignedFlag = "self-signed" + PskFlag = "psk" + PskCacheExpirationFlag = "psk-cache-expiration" + PskIdentityFlag = "psk-identity" + PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username" + PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password" + PskApiEndpointFlag = "psk-api-endpoint" + CertFlag = "cert" + KeyFlag = "key" + CAFileFlag = "cafile" + CAPathFlag = "capath" + InsecureFlag = "insecure" + DebugFlag = "debug" + TopicFlag = "topic" + MessageFlag = "message" + RetainFlag = "retain" + PredefinedTopicFlag = "predefined-topic" + PredefinedTopicsFileFlag = "predefined-topics-file" + QOSFlag = "qos" + ClientIDFlag = "client-id" + UserFlag = "user" + PasswordFlag = "password" ) func init() { @@ -76,6 +83,49 @@ var Application = cli.App{ "SELF_SIGNED", }, }, + &cli.BoolFlag{ + Name: PskFlag, + Usage: "use PSK", + EnvVars: []string{ + "PSK_ENABLED", + }, + }, + &cli.DurationFlag{ + Name: PskCacheExpirationFlag, + Usage: "PSK cache expiration", + Value: 5 * time.Minute, + EnvVars: []string{ + "PSK_CACHE_EXPIRATION", + }, + }, + &cli.StringFlag{ + Name: PskIdentityFlag, + Usage: "PSK identity", + EnvVars: []string{ + "PSK_IDENTITY", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthUsernameFlag, + Usage: "PSK API basic auth username", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_USERNAME", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthPasswordFlag, + Usage: "PSK API basic auth password", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_PASSWORD", + }, + }, + &cli.StringFlag{ + Name: PskApiEndpointFlag, + Usage: "PSK API endpoint", + EnvVars: []string{ + "PSK_API_ENDPOINT", + }, + }, &cli.PathFlag{ Name: CertFlag, Usage: "DTLS certificate file", diff --git a/cmd/bisquitt-sub/actions.go b/cmd/bisquitt-sub/actions.go index 7224c02..7915748 100644 --- a/cmd/bisquitt-sub/actions.go +++ b/cmd/bisquitt-sub/actions.go @@ -15,6 +15,7 @@ import ( "syscall" "time" + "github.com/patrickmn/go-cache" "github.com/urfave/cli/v2" snClient "github.com/energomonitor/bisquitt/client" @@ -34,13 +35,19 @@ func handleAction() cli.ActionFunc { useDTLS := c.Bool(DtlsFlag) useSelfSigned := c.Bool(SelfSignedFlag) + usePSK := c.Bool(PskFlag) + pskCacheExpiration := c.Duration(PskCacheExpirationFlag) + pskIdentity := c.String(PskIdentityFlag) + pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag) + pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag) + pskApiEndpoint := c.String(PskApiEndpointFlag) certFile := c.Path(CertFlag) keyFile := c.Path(KeyFlag) caFile := c.Path(CAFileFlag) caPath := c.Path(CAPathFlag) debug := c.Bool(DebugFlag) - if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned { + if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK { return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`, CertFlag, KeyFlag, SelfSignedFlag) } @@ -133,21 +140,28 @@ func handleAction() cli.ActionFunc { password := []byte(c.String(PasswordFlag)) clientCfg := &snClient.ClientConfig{ - ClientID: clientID, - UseDTLS: useDTLS, - SelfSigned: useSelfSigned, - Insecure: insecure, - Certificate: certificate, - PrivateKey: privateKey, - CACertificates: caCertificates, - PredefinedTopics: predefinedTopics, - RetryDelay: 10 * time.Second, - RetryCount: 4, - ConnectTimeout: 20 * time.Second, - KeepAlive: 60 * time.Second, - CleanSession: true, - User: user, - Password: password, + ClientID: clientID, + UseDTLS: useDTLS, + SelfSigned: useSelfSigned, + UsePSK: usePSK, + PSK: cache.New(pskCacheExpiration, 5*time.Minute), + PSKCacheExpiration: pskCacheExpiration, + PSKIdentity: pskIdentity, + PSKApiBasicAuthUsername: pskApiBasicAuthUsername, + PSKApiBasicAuthPassword: pskApiBasicAuthPassword, + PSKApiEndpoint: pskApiEndpoint, + Insecure: insecure, + Certificate: certificate, + PrivateKey: privateKey, + CACertificates: caCertificates, + PredefinedTopics: predefinedTopics, + RetryDelay: 10 * time.Second, + RetryCount: 4, + ConnectTimeout: 20 * time.Second, + KeepAlive: 60 * time.Second, + CleanSession: true, + User: user, + Password: password, } if c.IsSet(WillTopicFlag) { diff --git a/cmd/bisquitt-sub/application.go b/cmd/bisquitt-sub/application.go index 1d8a889..714cdc5 100644 --- a/cmd/bisquitt-sub/application.go +++ b/cmd/bisquitt-sub/application.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "time" "github.com/urfave/cli/v2" @@ -9,27 +10,33 @@ import ( ) const ( - HostFlag = "host" - PortFlag = "port" - DtlsFlag = "dtls" - SelfSignedFlag = "self-signed" - CertFlag = "cert" - KeyFlag = "key" - CAFileFlag = "cafile" - CAPathFlag = "capath" - InsecureFlag = "insecure" - DebugFlag = "debug" - TopicFlag = "topic" - PredefinedTopicFlag = "predefined-topic" - PredefinedTopicsFileFlag = "predefined-topics-file" - QOSFlag = "qos" - ClientIDFlag = "client-id" - WillTopicFlag = "will-topic" - WillMessageFlag = "will-message" - WillQOSFlag = "will-qos" - WillRetainFlag = "will-retain" - UserFlag = "user" - PasswordFlag = "password" + HostFlag = "host" + PortFlag = "port" + DtlsFlag = "dtls" + SelfSignedFlag = "self-signed" + PskFlag = "psk" + PskCacheExpirationFlag = "psk-cache-expiration" + PskIdentityFlag = "psk-identity" + PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username" + PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password" + PskApiEndpointFlag = "psk-api-endpoint" + CertFlag = "cert" + KeyFlag = "key" + CAFileFlag = "cafile" + CAPathFlag = "capath" + InsecureFlag = "insecure" + DebugFlag = "debug" + TopicFlag = "topic" + PredefinedTopicFlag = "predefined-topic" + PredefinedTopicsFileFlag = "predefined-topics-file" + QOSFlag = "qos" + ClientIDFlag = "client-id" + WillTopicFlag = "will-topic" + WillMessageFlag = "will-message" + WillQOSFlag = "will-qos" + WillRetainFlag = "will-retain" + UserFlag = "user" + PasswordFlag = "password" ) func init() { @@ -71,6 +78,49 @@ var Application = cli.App{ "DTLS_ENABLED", }, }, + &cli.BoolFlag{ + Name: PskFlag, + Usage: "use PSK", + EnvVars: []string{ + "PSK_ENABLED", + }, + }, + &cli.DurationFlag{ + Name: PskCacheExpirationFlag, + Value: 5 * time.Minute, + Usage: "PSK cache expiration", + EnvVars: []string{ + "PSK_CACHE_EXPIRATION", + }, + }, + &cli.StringFlag{ + Name: PskIdentityFlag, + Usage: "PSK identity", + EnvVars: []string{ + "PSK_IDENTITY", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthUsernameFlag, + Usage: "PSK API basic auth username", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_USERNAME", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthPasswordFlag, + Usage: "PSK API basic auth password", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_PASSWORD", + }, + }, + &cli.StringFlag{ + Name: PskApiEndpointFlag, + Usage: "PSK API endpoint", + EnvVars: []string{ + "PSK_API_ENDPOINT", + }, + }, &cli.BoolFlag{ Name: SelfSignedFlag, Usage: "generate self-signed certificate", diff --git a/cmd/bisquitt/actions.go b/cmd/bisquitt/actions.go index 38b8e50..4696664 100644 --- a/cmd/bisquitt/actions.go +++ b/cmd/bisquitt/actions.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/patrickmn/go-cache" "github.com/urfave/cli/v2" "github.com/energomonitor/bisquitt/gateway" @@ -25,12 +26,18 @@ func handleAction() cli.ActionFunc { return func(c *cli.Context) error { useDTLS := c.Bool(DtlsFlag) useSelfSigned := c.Bool(SelfSignedFlag) + usePSK := c.Bool(PskFlag) + pskCacheExpiration := c.Duration(PskCacheExpirationFlag) + pskIdentity := c.String(PskIdentityFlag) + pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag) + pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag) + pskApiEndpoint := c.String(PskApiEndpointFlag) certFile := c.Path(CertFlag) keyFile := c.Path(KeyFlag) debug := c.Bool(DebugFlag) syslog := c.Bool(SyslogFlag) - if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned { + if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK { return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`, CertFlag, KeyFlag, SelfSignedFlag) } @@ -116,19 +123,26 @@ func handleAction() cli.ActionFunc { performanceLogTime := c.Duration(PerformanceLogTimeFlag) gwConfig := &gateway.GatewayConfig{ - MqttBrokerAddress: mqttBrokerAddress, - MqttConnectionTimeout: mqttConnectionTimeout, - MqttUser: mqttUser, - MqttPassword: mqttPassword, - UseDTLS: useDTLS, - SelfSigned: useSelfSigned, - Certificate: certificate, - PrivateKey: privateKey, - PerformanceLogTime: performanceLogTime, - PredefinedTopics: predefinedTopics, - AuthEnabled: authEnabled, - RetryDelay: 10 * time.Second, - RetryCount: 4, + MqttBrokerAddress: mqttBrokerAddress, + MqttConnectionTimeout: mqttConnectionTimeout, + MqttUser: mqttUser, + MqttPassword: mqttPassword, + UseDTLS: useDTLS, + UsePSK: usePSK, + PSK: cache.New(pskCacheExpiration, 5*time.Minute), + PSKCacheExpiration: pskCacheExpiration, + PSKIdentity: pskIdentity, + PSKApiBasicAuthUsername: pskApiBasicAuthUsername, + PSKApiBasicAuthPassword: pskApiBasicAuthPassword, + PSKApiEndpoint: pskApiEndpoint, + SelfSigned: useSelfSigned, + Certificate: certificate, + PrivateKey: privateKey, + PerformanceLogTime: performanceLogTime, + PredefinedTopics: predefinedTopics, + AuthEnabled: authEnabled, + RetryDelay: 10 * time.Second, + RetryCount: 4, } logTag := "gw" diff --git a/cmd/bisquitt/application.go b/cmd/bisquitt/application.go index ae4bfbb..096dfd8 100644 --- a/cmd/bisquitt/application.go +++ b/cmd/bisquitt/application.go @@ -11,27 +11,33 @@ import ( ) const ( - MqttHostFlag = "mqtt-host" - MqttPortFlag = "mqtt-port" - MqttUserFlag = "mqtt-user" - MqttPasswordFlag = "mqtt-password" - MqttPasswordFileFlag = "mqtt-password-file" - MqttTimeoutFlag = "mqtt-timeout" - HostFlag = "host" - PortFlag = "port" - DtlsFlag = "dtls" - SelfSignedFlag = "self-signed" - CertFlag = "cert" - KeyFlag = "key" - PredefinedTopicFlag = "predefined-topic" - PredefinedTopicsFileFlag = "predefined-topics-file" - SyslogFlag = "syslog" - DebugFlag = "debug" - PerformanceLogTimeFlag = "performance-log-time" - InsecureFlag = "insecure" - AuthFlag = "auth" - UserFlag = "user" - GroupFlag = "group" + MqttHostFlag = "mqtt-host" + MqttPortFlag = "mqtt-port" + MqttUserFlag = "mqtt-user" + MqttPasswordFlag = "mqtt-password" + MqttPasswordFileFlag = "mqtt-password-file" + MqttTimeoutFlag = "mqtt-timeout" + HostFlag = "host" + PortFlag = "port" + DtlsFlag = "dtls" + PskFlag = "psk" + PskCacheExpirationFlag = "psk-cache-expiration" + PskIdentityFlag = "psk-identity" + PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username" + PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password" + PskApiEndpointFlag = "psk-api-endpoint" + SelfSignedFlag = "self-signed" + CertFlag = "cert" + KeyFlag = "key" + PredefinedTopicFlag = "predefined-topic" + PredefinedTopicsFileFlag = "predefined-topics-file" + SyslogFlag = "syslog" + DebugFlag = "debug" + PerformanceLogTimeFlag = "performance-log-time" + InsecureFlag = "insecure" + AuthFlag = "auth" + UserFlag = "user" + GroupFlag = "group" ) var Application = cli.App{ @@ -109,6 +115,49 @@ var Application = cli.App{ "DTLS_ENABLED", }, }, + &cli.BoolFlag{ + Name: PskFlag, + Usage: "use PSK", + EnvVars: []string{ + "PSK_ENABLED", + }, + }, + &cli.DurationFlag{ + Name: PskCacheExpirationFlag, + Value: 5 * time.Minute, + Usage: "PSK cache expiration", + EnvVars: []string{ + "PSK_CACHE_EXPIRATION", + }, + }, + &cli.StringFlag{ + Name: PskIdentityFlag, + Usage: "PSK identity", + EnvVars: []string{ + "PSK_IDENTITY", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthUsernameFlag, + Usage: "PSK API basic auth username", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_USERNAME", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthPasswordFlag, + Usage: "PSK API basic auth password", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_PASSWORD", + }, + }, + &cli.StringFlag{ + Name: PskApiEndpointFlag, + Usage: "PSK API endpoint", + EnvVars: []string{ + "PSK_API_ENDPOINT", + }, + }, &cli.BoolFlag{ Name: SelfSignedFlag, Usage: "generate self-signed certificate", diff --git a/gateway/gateway.go b/gateway/gateway.go index 7efb5db..d8b0f5a 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -6,11 +6,15 @@ import ( "context" "crypto" "crypto/tls" + "encoding/json" "errors" "fmt" + "io" "net" + "net/http" "time" + "github.com/patrickmn/go-cache" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/udp" @@ -26,11 +30,24 @@ type GatewayConfig struct { MqttPassword []byte UseDTLS bool SelfSigned bool - Certificate *tls.Certificate - PrivateKey crypto.PrivateKey - PerformanceLogTime time.Duration - PredefinedTopics topics.PredefinedTopics - AuthEnabled bool + // UsePSK controls whether pre-shared key should be used to secure the + // connection to the MQTT-SN gateway. If UsePSK is true, you must provide + // PSKIdentity, PSKApiBasicAuthUsername, PSKApiBasicAuthPassword and + // PSKApiEndpoint. + // If UsePSK is true, the client will use PSK instead of the certificate + // and private key. + UsePSK bool + PSK *cache.Cache + PSKCacheExpiration time.Duration + PSKIdentity string + PSKApiBasicAuthUsername string + PSKApiBasicAuthPassword string + PSKApiEndpoint string + Certificate *tls.Certificate + PrivateKey crypto.PrivateKey + PerformanceLogTime time.Duration + PredefinedTopics topics.PredefinedTopics + AuthEnabled bool // TRetry in MQTT-SN specification RetryDelay time.Duration // NRetry in MQTT-SN specification @@ -56,32 +73,60 @@ func newDTLSListener(ctx context.Context, cfg *GatewayConfig, address *net.UDPAd var certificate *tls.Certificate var err error - if cfg.SelfSigned { - var cert tls.Certificate - cert, err = selfsign.GenerateSelfSigned() - certificate = &cert - } else { - privateKey := cfg.PrivateKey - if privateKey == nil { - err = errors.New("private key is missing") - } - if certificate = cfg.Certificate; certificate != nil { - certificate.PrivateKey = privateKey + logger := util.NewProductionLogger("gateway") + + if !cfg.UsePSK && cfg.UseDTLS { + if cfg.SelfSigned { + var cert tls.Certificate + cert, err = selfsign.GenerateSelfSigned() + certificate = &cert } else { - err = errors.New("TLS certificate is missing") + privateKey := cfg.PrivateKey + if privateKey == nil { + err = errors.New("private key is missing") + } + if certificate = cfg.Certificate; certificate != nil { + certificate.PrivateKey = privateKey + } else { + err = errors.New("TLS certificate is missing") + } } } + if err != nil { return nil, err } dtlsConfig := &dtls.Config{ - Certificates: []tls.Certificate{*certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(ctx, dtlsConnectTimeout) }, } + + if !cfg.UsePSK && cfg.UseDTLS && certificate != nil { + dtlsConfig.Certificates = []tls.Certificate{*certificate} + } + + if cfg.UsePSK && cfg.UseDTLS { + dtlsConfig.CipherSuites = []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256} + dtlsConfig.PSK = func(hint []byte) ([]byte, error) { + psk, ok := cfg.PSK.Get(string(hint)) + if ok { + return psk.([]byte), nil + } + + psk, ok = getPSK(string(hint), cfg, logger) + if ok { + cfg.PSK.Set(string(hint), psk, cfg.PSKCacheExpiration) + return psk.([]byte), nil + } + + return nil, errors.New("PSK not found") + } + dtlsConfig.PSKIdentityHint = []byte(cfg.PSKIdentity) + } + return dtls.Listen("udp", address, dtlsConfig) } @@ -154,3 +199,47 @@ func (gw *Gateway) ListenAndServe(ctx context.Context, address string) error { }() } } + +func getPSK(hint string, gwConfig *GatewayConfig, logger util.Logger) ([]byte, bool) { + req, err := http.NewRequest("GET", fmt.Sprintf(gwConfig.PSKApiEndpoint+"/%s", hint), nil) + if err != nil { + logger.Error("Error in creating request: %s", err) + return nil, false + } + + req.SetBasicAuth(gwConfig.PSKApiBasicAuthUsername, gwConfig.PSKApiBasicAuthPassword) + client := &http.Client{} + resp, err := client.Do(req) + + if err != nil { + logger.Error("Error in sending request: %s", err) + return nil, false + } + + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + logger.Debug("ID not found") + return nil, false + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.Error("Error in reading response body: %s", err) + return nil, false + } + + type Response struct { + Data map[string][]byte `json:"data"` + } + + var response Response + + err = json.Unmarshal(body, &response) + if err != nil { + logger.Error("Error in unmarshalling response body: %s", err) + return nil, false + } + + return response.Data[hint], true +} diff --git a/go.mod b/go.mod index ea84309..12c213e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.16 require ( github.com/eclipse/paho.mqtt.golang v1.3.5 + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pion/dtls/v2 v2.1.3 github.com/pion/udp v0.1.1 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 9e41fed..938ac64 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/eclipse/paho.mqtt.golang v1.3.5 h1:sWtmgNxYM9P2sP+xEItMozsR3w0cqZFlqnNN1bdl41Y= github.com/eclipse/paho.mqtt.golang v1.3.5/go.mod h1:eTzb4gxwwyWpqBUHGQZ4ABAV7+Jgm1PklsYT/eo8Hcc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pion/dtls/v2 v2.1.3 h1:3UF7udADqous+M2R5Uo2q/YaP4EzUoWKdfX2oscCUio= github.com/pion/dtls/v2 v2.1.3/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=