diff --git a/client/client.go b/client/client.go index 73cba07..2e192a3 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ // Use ClientConfig struct to set various client's options and features. // // Example: +// // import ( // "fmt" // "time" @@ -78,10 +79,16 @@ import ( type ClientConfig struct { // UseDTLS controls whether DTLS should be used to secure the connection // to the MQTT-SN gateway. - UseDTLS bool - Certificate *tls.Certificate - PrivateKey crypto.PrivateKey - CACertificates []*x509.Certificate + UseDTLS bool + Certificate *tls.Certificate + PrivateKey crypto.PrivateKey + UsePSK bool + PSK map[string][]byte + PSKIdentity string + PSKApiBasicAuthUsername string + PSKApiBasicAuthPassword string + PSKApiEndpoint string + CACertificates []*x509.Certificate // SelfSigned controls whether the client should use a self-signed // certificate and key. If SelfSigned is false and UseDTLS is true, you // must provide CertFile and KeyFile. @@ -147,21 +154,24 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err var certificate *tls.Certificate var err error - if c.cfg.SelfSigned { - var cert tls.Certificate - cert, err = selfsign.GenerateSelfSigned() - certificate = &cert - } else { - privateKey := c.cfg.PrivateKey - if privateKey == nil { - err = errors.New("private key is missing") - } - if certificate = c.cfg.Certificate; certificate != nil { - certificate.PrivateKey = privateKey + if !c.cfg.UsePSK && c.cfg.UseDTLS { + if c.cfg.SelfSigned { + var cert tls.Certificate + cert, err = selfsign.GenerateSelfSigned() + certificate = &cert } else { - err = errors.New("TLS certificate is missing") + privateKey := c.cfg.PrivateKey + if privateKey == nil { + err = errors.New("private key is missing") + } + if certificate = c.cfg.Certificate; certificate != nil { + certificate.PrivateKey = privateKey + } else { + err = errors.New("TLS certificate is missing") + } } } + if err != nil { return nil, err } @@ -182,12 +192,29 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err // Prepare the configuration of the DTLS connection config := &dtls.Config{ - Certificates: []tls.Certificate{*certificate}, InsecureSkipVerify: c.cfg.Insecure, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, RootCAs: certPool, } + if !c.cfg.UsePSK && c.cfg.UseDTLS && certificate != nil { + config.Certificates = []tls.Certificate{*certificate} + } + + if c.cfg.UsePSK && c.cfg.UseDTLS { + c.log.Info("Using PSK with DTLS") + config.CipherSuites = []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256} + config.PSK = func(hint []byte) ([]byte, error) { + psk, ok := c.cfg.PSK[string(hint)] + if !ok { + return nil, errors.New("PSK not found") + } + + return psk, nil + } + config.PSKIdentityHint = []byte(c.cfg.PSKIdentity) + } + // Connect to a DTLS server ctx2, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/cmd/bisquitt-pub/actions.go b/cmd/bisquitt-pub/actions.go index 93bcbf5..f8bada8 100644 --- a/cmd/bisquitt-pub/actions.go +++ b/cmd/bisquitt-pub/actions.go @@ -4,8 +4,11 @@ import ( "crypto" "crypto/tls" "crypto/x509" + "encoding/json" "fmt" + "io" "math/rand" + "net/http" "os" "os/signal" "path" @@ -33,13 +36,18 @@ func handleAction() cli.ActionFunc { useDTLS := c.Bool(DtlsFlag) useSelfSigned := c.Bool(SelfSignedFlag) + usePSK := c.Bool(PskFlag) + pskIdentity := c.String(PskIdentityFlag) + pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag) + pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag) + pskApiEndpoint := c.String(PskApiEndpointFlag) certFile := c.Path(CertFlag) keyFile := c.Path(KeyFlag) caFile := c.Path(CAFileFlag) caPath := c.Path(CAPathFlag) debug := c.Bool(DebugFlag) - if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned { + if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK { return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`, CertFlag, KeyFlag, SelfSignedFlag) } @@ -132,21 +140,26 @@ func handleAction() cli.ActionFunc { password := []byte(c.String(PasswordFlag)) clientCfg := &snClient.ClientConfig{ - ClientID: clientID, - UseDTLS: useDTLS, - SelfSigned: useSelfSigned, - Insecure: insecure, - Certificate: certificate, - PrivateKey: privateKey, - CACertificates: caCertificates, - RetryDelay: 10 * time.Second, - RetryCount: 4, - ConnectTimeout: 20 * time.Second, - KeepAlive: 60 * time.Second, - CleanSession: true, - PredefinedTopics: predefinedTopics, - User: user, - Password: password, + ClientID: clientID, + UseDTLS: useDTLS, + UsePSK: usePSK, + PSKIdentity: pskIdentity, + PSKApiBasicAuthUsername: pskApiBasicAuthUsername, + PSKApiBasicAuthPassword: pskApiBasicAuthPassword, + PSKApiEndpoint: pskApiEndpoint, + SelfSigned: useSelfSigned, + Insecure: insecure, + Certificate: certificate, + PrivateKey: privateKey, + CACertificates: caCertificates, + RetryDelay: 10 * time.Second, + RetryCount: 4, + ConnectTimeout: 20 * time.Second, + KeepAlive: 60 * time.Second, + CleanSession: true, + PredefinedTopics: predefinedTopics, + User: user, + Password: password, } var logger util.Logger @@ -157,6 +170,10 @@ func handleAction() cli.ActionFunc { } defer logger.Sync() + if usePSK && useDTLS { + loadPsk(clientCfg, logger) + } + client := snClient.NewClient(logger, clientCfg) signalCh := make(chan os.Signal, 1) @@ -208,3 +225,42 @@ func handleAction() cli.ActionFunc { return nil } } + +func loadPsk(snClientConfig *snClient.ClientConfig, logger util.Logger) { + req, err := http.NewRequest("GET", snClientConfig.PSKApiEndpoint, nil) + if err != nil { + logger.Error("Error in creating request: %s", err) + return + } + + req.SetBasicAuth(snClientConfig.PSKApiBasicAuthUsername, snClientConfig.PSKApiBasicAuthPassword) + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + logger.Error("Error in sending request: %s", err) + return + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.Error("Error in reading response body: %s", err) + return + } + + type Response struct { + Data map[string][]byte `json:"data"` + } + + var response Response + + err = json.Unmarshal(body, &response) + if err != nil { + logger.Error("Error in unmarshalling response body: %s", err) + return + } + + snClientConfig.PSK = response.Data +} diff --git a/cmd/bisquitt-pub/application.go b/cmd/bisquitt-pub/application.go index f32cb5b..f02bd99 100644 --- a/cmd/bisquitt-pub/application.go +++ b/cmd/bisquitt-pub/application.go @@ -9,25 +9,30 @@ import ( ) const ( - HostFlag = "host" - PortFlag = "port" - DtlsFlag = "dtls" - SelfSignedFlag = "self-signed" - CertFlag = "cert" - KeyFlag = "key" - CAFileFlag = "cafile" - CAPathFlag = "capath" - InsecureFlag = "insecure" - DebugFlag = "debug" - TopicFlag = "topic" - MessageFlag = "message" - RetainFlag = "retain" - PredefinedTopicFlag = "predefined-topic" - PredefinedTopicsFileFlag = "predefined-topics-file" - QOSFlag = "qos" - ClientIDFlag = "client-id" - UserFlag = "user" - PasswordFlag = "password" + HostFlag = "host" + PortFlag = "port" + DtlsFlag = "dtls" + SelfSignedFlag = "self-signed" + PskFlag = "psk" + PskIdentityFlag = "psk-identity" + PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username" + PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password" + PskApiEndpointFlag = "psk-api-endpoint" + CertFlag = "cert" + KeyFlag = "key" + CAFileFlag = "cafile" + CAPathFlag = "capath" + InsecureFlag = "insecure" + DebugFlag = "debug" + TopicFlag = "topic" + MessageFlag = "message" + RetainFlag = "retain" + PredefinedTopicFlag = "predefined-topic" + PredefinedTopicsFileFlag = "predefined-topics-file" + QOSFlag = "qos" + ClientIDFlag = "client-id" + UserFlag = "user" + PasswordFlag = "password" ) func init() { @@ -76,6 +81,41 @@ var Application = cli.App{ "SELF_SIGNED", }, }, + &cli.BoolFlag{ + Name: PskFlag, + Usage: "use PSK", + EnvVars: []string{ + "PSK_ENABLED", + }, + }, + &cli.StringFlag{ + Name: PskIdentityFlag, + Usage: "PSK identity", + EnvVars: []string{ + "PSK_IDENTITY", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthUsernameFlag, + Usage: "PSK API basic auth username", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_USERNAME", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthPasswordFlag, + Usage: "PSK API basic auth password", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_PASSWORD", + }, + }, + &cli.StringFlag{ + Name: PskApiEndpointFlag, + Usage: "PSK API endpoint", + EnvVars: []string{ + "PSK_API_ENDPOINT", + }, + }, &cli.PathFlag{ Name: CertFlag, Usage: "DTLS certificate file", diff --git a/cmd/bisquitt-sub/actions.go b/cmd/bisquitt-sub/actions.go index 7224c02..b3b8809 100644 --- a/cmd/bisquitt-sub/actions.go +++ b/cmd/bisquitt-sub/actions.go @@ -4,9 +4,12 @@ import ( "crypto" "crypto/tls" "crypto/x509" + "encoding/json" "errors" "fmt" + "io" "math/rand" + "net/http" "os" "os/signal" "path" @@ -34,13 +37,18 @@ func handleAction() cli.ActionFunc { useDTLS := c.Bool(DtlsFlag) useSelfSigned := c.Bool(SelfSignedFlag) + usePSK := c.Bool(PskFlag) + pskIdentity := c.String(PskIdentityFlag) + pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag) + pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag) + pskApiEndpoint := c.String(PskApiEndpointFlag) certFile := c.Path(CertFlag) keyFile := c.Path(KeyFlag) caFile := c.Path(CAFileFlag) caPath := c.Path(CAPathFlag) debug := c.Bool(DebugFlag) - if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned { + if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK { return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`, CertFlag, KeyFlag, SelfSignedFlag) } @@ -133,21 +141,26 @@ func handleAction() cli.ActionFunc { password := []byte(c.String(PasswordFlag)) clientCfg := &snClient.ClientConfig{ - ClientID: clientID, - UseDTLS: useDTLS, - SelfSigned: useSelfSigned, - Insecure: insecure, - Certificate: certificate, - PrivateKey: privateKey, - CACertificates: caCertificates, - PredefinedTopics: predefinedTopics, - RetryDelay: 10 * time.Second, - RetryCount: 4, - ConnectTimeout: 20 * time.Second, - KeepAlive: 60 * time.Second, - CleanSession: true, - User: user, - Password: password, + ClientID: clientID, + UseDTLS: useDTLS, + SelfSigned: useSelfSigned, + UsePSK: usePSK, + PSKIdentity: pskIdentity, + PSKApiBasicAuthUsername: pskApiBasicAuthUsername, + PSKApiBasicAuthPassword: pskApiBasicAuthPassword, + PSKApiEndpoint: pskApiEndpoint, + Insecure: insecure, + Certificate: certificate, + PrivateKey: privateKey, + CACertificates: caCertificates, + PredefinedTopics: predefinedTopics, + RetryDelay: 10 * time.Second, + RetryCount: 4, + ConnectTimeout: 20 * time.Second, + KeepAlive: 60 * time.Second, + CleanSession: true, + User: user, + Password: password, } if c.IsSet(WillTopicFlag) { @@ -174,6 +187,10 @@ func handleAction() cli.ActionFunc { } defer logger.Sync() + if usePSK && useDTLS { + loadPsk(clientCfg, logger) + } + client := snClient.NewClient(logger, clientCfg) signalCh := make(chan os.Signal, 1) @@ -238,3 +255,42 @@ func handleAction() cli.ActionFunc { return nil } } + +func loadPsk(gwConfig *snClient.ClientConfig, logger util.Logger) { + req, err := http.NewRequest("GET", gwConfig.PSKApiEndpoint, nil) + if err != nil { + logger.Error("Error in creating request: %s", err) + return + } + + req.SetBasicAuth(gwConfig.PSKApiBasicAuthUsername, gwConfig.PSKApiBasicAuthPassword) + client := &http.Client{} + resp, err := client.Do(req) + + if err != nil { + logger.Error("Error in sending request: %s", err) + return + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.Error("Error in reading response body: %s", err) + return + } + + type Response struct { + Data map[string][]byte `json:"data"` + } + + var response Response + + err = json.Unmarshal(body, &response) + if err != nil { + logger.Error("Error in unmarshalling response body: %s", err) + return + } + + gwConfig.PSK = response.Data +} diff --git a/cmd/bisquitt-sub/application.go b/cmd/bisquitt-sub/application.go index 1d8a889..630e9c7 100644 --- a/cmd/bisquitt-sub/application.go +++ b/cmd/bisquitt-sub/application.go @@ -9,27 +9,32 @@ import ( ) const ( - HostFlag = "host" - PortFlag = "port" - DtlsFlag = "dtls" - SelfSignedFlag = "self-signed" - CertFlag = "cert" - KeyFlag = "key" - CAFileFlag = "cafile" - CAPathFlag = "capath" - InsecureFlag = "insecure" - DebugFlag = "debug" - TopicFlag = "topic" - PredefinedTopicFlag = "predefined-topic" - PredefinedTopicsFileFlag = "predefined-topics-file" - QOSFlag = "qos" - ClientIDFlag = "client-id" - WillTopicFlag = "will-topic" - WillMessageFlag = "will-message" - WillQOSFlag = "will-qos" - WillRetainFlag = "will-retain" - UserFlag = "user" - PasswordFlag = "password" + HostFlag = "host" + PortFlag = "port" + DtlsFlag = "dtls" + SelfSignedFlag = "self-signed" + PskFlag = "psk" + PskIdentityFlag = "psk-identity" + PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username" + PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password" + PskApiEndpointFlag = "psk-api-endpoint" + CertFlag = "cert" + KeyFlag = "key" + CAFileFlag = "cafile" + CAPathFlag = "capath" + InsecureFlag = "insecure" + DebugFlag = "debug" + TopicFlag = "topic" + PredefinedTopicFlag = "predefined-topic" + PredefinedTopicsFileFlag = "predefined-topics-file" + QOSFlag = "qos" + ClientIDFlag = "client-id" + WillTopicFlag = "will-topic" + WillMessageFlag = "will-message" + WillQOSFlag = "will-qos" + WillRetainFlag = "will-retain" + UserFlag = "user" + PasswordFlag = "password" ) func init() { @@ -71,6 +76,41 @@ var Application = cli.App{ "DTLS_ENABLED", }, }, + &cli.BoolFlag{ + Name: PskFlag, + Usage: "use PSK", + EnvVars: []string{ + "PSK_ENABLED", + }, + }, + &cli.StringFlag{ + Name: PskIdentityFlag, + Usage: "PSK identity", + EnvVars: []string{ + "PSK_IDENTITY", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthUsernameFlag, + Usage: "PSK API basic auth username", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_USERNAME", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthPasswordFlag, + Usage: "PSK API basic auth password", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_PASSWORD", + }, + }, + &cli.StringFlag{ + Name: PskApiEndpointFlag, + Usage: "PSK API endpoint", + EnvVars: []string{ + "PSK_API_ENDPOINT", + }, + }, &cli.BoolFlag{ Name: SelfSignedFlag, Usage: "generate self-signed certificate", diff --git a/cmd/bisquitt/actions.go b/cmd/bisquitt/actions.go index 38b8e50..1895d45 100644 --- a/cmd/bisquitt/actions.go +++ b/cmd/bisquitt/actions.go @@ -4,9 +4,12 @@ import ( "context" "crypto" "crypto/tls" + "encoding/json" "fmt" + "io" "io/ioutil" "net" + "net/http" "os" "os/signal" "syscall" @@ -25,12 +28,17 @@ func handleAction() cli.ActionFunc { return func(c *cli.Context) error { useDTLS := c.Bool(DtlsFlag) useSelfSigned := c.Bool(SelfSignedFlag) + usePSK := c.Bool(PskFlag) + pskIdentity := c.String(PskIdentityFlag) + pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag) + pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag) + pskApiEndpoint := c.String(PskApiEndpointFlag) certFile := c.Path(CertFlag) keyFile := c.Path(KeyFlag) debug := c.Bool(DebugFlag) syslog := c.Bool(SyslogFlag) - if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned { + if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK { return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`, CertFlag, KeyFlag, SelfSignedFlag) } @@ -116,19 +124,24 @@ func handleAction() cli.ActionFunc { performanceLogTime := c.Duration(PerformanceLogTimeFlag) gwConfig := &gateway.GatewayConfig{ - MqttBrokerAddress: mqttBrokerAddress, - MqttConnectionTimeout: mqttConnectionTimeout, - MqttUser: mqttUser, - MqttPassword: mqttPassword, - UseDTLS: useDTLS, - SelfSigned: useSelfSigned, - Certificate: certificate, - PrivateKey: privateKey, - PerformanceLogTime: performanceLogTime, - PredefinedTopics: predefinedTopics, - AuthEnabled: authEnabled, - RetryDelay: 10 * time.Second, - RetryCount: 4, + MqttBrokerAddress: mqttBrokerAddress, + MqttConnectionTimeout: mqttConnectionTimeout, + MqttUser: mqttUser, + MqttPassword: mqttPassword, + UseDTLS: useDTLS, + UsePSK: usePSK, + PSKIdentity: pskIdentity, + PSKApiBasicAuthUsername: pskApiBasicAuthUsername, + PSKApiBasicAuthPassword: pskApiBasicAuthPassword, + PSKApiEndpoint: pskApiEndpoint, + SelfSigned: useSelfSigned, + Certificate: certificate, + PrivateKey: privateKey, + PerformanceLogTime: performanceLogTime, + PredefinedTopics: predefinedTopics, + AuthEnabled: authEnabled, + RetryDelay: 10 * time.Second, + RetryCount: 4, } logTag := "gw" @@ -194,8 +207,51 @@ func handleAction() cli.ActionFunc { logger.Info("switched to %s:%s", currentUser.Username, currentGroup.Name) } + if usePSK && useDTLS { + loadPsk(gwConfig, logger) + } + gw := gateway.NewGateway(logger, gwConfig) return gw.ListenAndServe(ctx, fmt.Sprintf("%s:%d", host, port)) } } + +func loadPsk(gwConfig *gateway.GatewayConfig, logger util.Logger) { + req, err := http.NewRequest("GET", gwConfig.PSKApiEndpoint, nil) + if err != nil { + logger.Error("Error in creating request: %s", err) + return + } + + req.SetBasicAuth(gwConfig.PSKApiBasicAuthUsername, gwConfig.PSKApiBasicAuthPassword) + client := &http.Client{} + resp, err := client.Do(req) + + if err != nil { + logger.Error("Error in sending request: %s", err) + return + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.Error("Error in reading response body: %s", err) + return + } + + type Response struct { + Data map[string][]byte `json:"data"` + } + + var response Response + + err = json.Unmarshal(body, &response) + if err != nil { + logger.Error("Error in unmarshalling response body: %s", err) + return + } + + gwConfig.PSK = response.Data +} diff --git a/cmd/bisquitt/application.go b/cmd/bisquitt/application.go index ae4bfbb..b77ea1a 100644 --- a/cmd/bisquitt/application.go +++ b/cmd/bisquitt/application.go @@ -11,27 +11,32 @@ import ( ) const ( - MqttHostFlag = "mqtt-host" - MqttPortFlag = "mqtt-port" - MqttUserFlag = "mqtt-user" - MqttPasswordFlag = "mqtt-password" - MqttPasswordFileFlag = "mqtt-password-file" - MqttTimeoutFlag = "mqtt-timeout" - HostFlag = "host" - PortFlag = "port" - DtlsFlag = "dtls" - SelfSignedFlag = "self-signed" - CertFlag = "cert" - KeyFlag = "key" - PredefinedTopicFlag = "predefined-topic" - PredefinedTopicsFileFlag = "predefined-topics-file" - SyslogFlag = "syslog" - DebugFlag = "debug" - PerformanceLogTimeFlag = "performance-log-time" - InsecureFlag = "insecure" - AuthFlag = "auth" - UserFlag = "user" - GroupFlag = "group" + MqttHostFlag = "mqtt-host" + MqttPortFlag = "mqtt-port" + MqttUserFlag = "mqtt-user" + MqttPasswordFlag = "mqtt-password" + MqttPasswordFileFlag = "mqtt-password-file" + MqttTimeoutFlag = "mqtt-timeout" + HostFlag = "host" + PortFlag = "port" + DtlsFlag = "dtls" + PskFlag = "psk" + PskIdentityFlag = "psk-identity" + PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username" + PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password" + PskApiEndpointFlag = "psk-api-endpoint" + SelfSignedFlag = "self-signed" + CertFlag = "cert" + KeyFlag = "key" + PredefinedTopicFlag = "predefined-topic" + PredefinedTopicsFileFlag = "predefined-topics-file" + SyslogFlag = "syslog" + DebugFlag = "debug" + PerformanceLogTimeFlag = "performance-log-time" + InsecureFlag = "insecure" + AuthFlag = "auth" + UserFlag = "user" + GroupFlag = "group" ) var Application = cli.App{ @@ -109,6 +114,41 @@ var Application = cli.App{ "DTLS_ENABLED", }, }, + &cli.BoolFlag{ + Name: PskFlag, + Usage: "use PSK", + EnvVars: []string{ + "PSK_ENABLED", + }, + }, + &cli.StringFlag{ + Name: PskIdentityFlag, + Usage: "PSK identity", + EnvVars: []string{ + "PSK_IDENTITY", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthUsernameFlag, + Usage: "PSK API basic auth username", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_USERNAME", + }, + }, + &cli.StringFlag{ + Name: PskApiBasicAuthPasswordFlag, + Usage: "PSK API basic auth password", + EnvVars: []string{ + "PSK_API_BASIC_AUTH_PASSWORD", + }, + }, + &cli.StringFlag{ + Name: PskApiEndpointFlag, + Usage: "PSK API endpoint", + EnvVars: []string{ + "PSK_API_ENDPOINT", + }, + }, &cli.BoolFlag{ Name: SelfSignedFlag, Usage: "generate self-signed certificate", diff --git a/gateway/gateway.go b/gateway/gateway.go index 7efb5db..dc0dbd2 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -20,17 +20,23 @@ import ( ) type GatewayConfig struct { - MqttBrokerAddress *net.TCPAddr - MqttConnectionTimeout time.Duration - MqttUser *string - MqttPassword []byte - UseDTLS bool - SelfSigned bool - Certificate *tls.Certificate - PrivateKey crypto.PrivateKey - PerformanceLogTime time.Duration - PredefinedTopics topics.PredefinedTopics - AuthEnabled bool + MqttBrokerAddress *net.TCPAddr + MqttConnectionTimeout time.Duration + MqttUser *string + MqttPassword []byte + UseDTLS bool + SelfSigned bool + UsePSK bool + PSK map[string][]byte + PSKIdentity string + PSKApiBasicAuthUsername string + PSKApiBasicAuthPassword string + PSKApiEndpoint string + Certificate *tls.Certificate + PrivateKey crypto.PrivateKey + PerformanceLogTime time.Duration + PredefinedTopics topics.PredefinedTopics + AuthEnabled bool // TRetry in MQTT-SN specification RetryDelay time.Duration // NRetry in MQTT-SN specification @@ -56,32 +62,53 @@ func newDTLSListener(ctx context.Context, cfg *GatewayConfig, address *net.UDPAd var certificate *tls.Certificate var err error - if cfg.SelfSigned { - var cert tls.Certificate - cert, err = selfsign.GenerateSelfSigned() - certificate = &cert - } else { - privateKey := cfg.PrivateKey - if privateKey == nil { - err = errors.New("private key is missing") - } - if certificate = cfg.Certificate; certificate != nil { - certificate.PrivateKey = privateKey + if !cfg.UsePSK && cfg.UseDTLS { + if cfg.SelfSigned { + var cert tls.Certificate + cert, err = selfsign.GenerateSelfSigned() + certificate = &cert } else { - err = errors.New("TLS certificate is missing") + privateKey := cfg.PrivateKey + if privateKey == nil { + err = errors.New("private key is missing") + } + if certificate = cfg.Certificate; certificate != nil { + certificate.PrivateKey = privateKey + } else { + err = errors.New("TLS certificate is missing") + } } } + if err != nil { return nil, err } dtlsConfig := &dtls.Config{ - Certificates: []tls.Certificate{*certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(ctx, dtlsConnectTimeout) }, } + + if !cfg.UsePSK && cfg.UseDTLS && certificate != nil { + dtlsConfig.Certificates = []tls.Certificate{*certificate} + } + + if cfg.UsePSK && cfg.UseDTLS { + fmt.Println("Using PSK") + dtlsConfig.CipherSuites = []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256} + dtlsConfig.PSK = func(hint []byte) ([]byte, error) { + psk, ok := cfg.PSK[string(hint)] + if !ok { + return nil, errors.New("PSK not found") + } + + return psk, nil + } + dtlsConfig.PSKIdentityHint = []byte(cfg.PSKIdentity) + } + return dtls.Listen("udp", address, dtlsConfig) }