Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: multi region replica strike implementation #3

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ var (
"CCC-Taxonomy": {
Strikes.SQLFeatures,
Strikes.AutomatedBackups,
Strikes.MultiRegion,
// Strikes.VerticalScaling,
// Strikes.Replication,
// Strikes.MultiRegion,
// Strikes.BackupRecovery,
// Strikes.Encryption,
// Strikes.RBAC,
Expand Down
1 change: 1 addition & 0 deletions example-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ raids:
config:
instance_identifier: unique-id-name
database: test
region: us-east-1
grudra7714 marked this conversation as resolved.
Show resolved Hide resolved
host: localhost
password: password
port: 3306
Expand Down
39 changes: 7 additions & 32 deletions strikes/AutomatedBackups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
Movements: make(map[string]raidengine.MovementResult),
}

// Movement
// Get Configuration
cfg, err := getAWSConfig()
if err != nil {
result.Message = err.Error()
Expand All @@ -37,10 +37,10 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
return
}

autmatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg)
result.Movements["CheckForDBInstanceAutomatedBackups"] = autmatedBackupsMovement
if !autmatedBackupsMovement.Passed {
result.Message = autmatedBackupsMovement.Message
automatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg)
result.Movements["CheckForDBInstanceAutomatedBackups"] = automatedBackupsMovement
if !automatedBackupsMovement.Passed {
result.Message = automatedBackupsMovement.Message
return
}

Expand All @@ -49,31 +49,6 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
return
}

func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) {
// check if the instance is available
result = raidengine.MovementResult{
Description: "Check if the instance is available/exists",
Function: utils.CallerPath(0),
}

rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()

input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(identifier),
}

instances, err := rdsClient.DescribeDBInstances(context.TODO(), input)
if err != nil {
// Handle error
result.Message = err.Error()
result.Passed = false
return
}
result.Passed = len(instances.DBInstances) > 0
return
}

func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.MovementResult) {

result = raidengine.MovementResult{
Expand All @@ -82,10 +57,10 @@ func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.Movement
}

rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()
instanceIdentifier, _ := getHostDBInstanceIdentifier()

input := &rds.DescribeDBInstanceAutomatedBackupsInput{
DBInstanceIdentifier: aws.String(identifier),
DBInstanceIdentifier: aws.String(instanceIdentifier),
}

backups, err := rdsClient.DescribeDBInstanceAutomatedBackups(context.TODO(), input)
Expand Down
98 changes: 98 additions & 0 deletions strikes/MultiRegion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package strikes

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/privateerproj/privateer-sdk/raidengine"
"github.com/privateerproj/privateer-sdk/utils"
)

func (a *Strikes) MultiRegion() (strikeName string, result raidengine.StrikeResult) {
strikeName = "MultiRegion"
result = raidengine.StrikeResult{
Passed: false,
Description: "Check if AWS RDS instance has multi region. This strike only checks for a read replica in a seperate region",
DocsURL: "https://www.github.com/krumIO/raid-rds",
ControlID: "CCC-Taxonomy-1",
Movements: make(map[string]raidengine.MovementResult),
}

// Get Configuration
cfg, err := getAWSConfig()
if err != nil {
result.Message = err.Error()
return
}

rdsInstanceMovement := checkRDSInstanceMovement(cfg)
result.Movements["CheckForDBInstance"] = rdsInstanceMovement
if !rdsInstanceMovement.Passed {
result.Message = rdsInstanceMovement.Message
return
}

multiRegionMovement := checkRDSMultiRegionMovement(cfg)
result.Movements["CheckForMultiRegionDBInstances"] = multiRegionMovement
if !multiRegionMovement.Passed {
result.Message = multiRegionMovement.Message
return
}

result.Passed = true
result.Message = "Completed Successfully"

return
}

func checkRDSMultiRegionMovement(cfg aws.Config) (result raidengine.MovementResult) {

result = raidengine.MovementResult{
Description: "Check if the instance has multi region enabled",
Function: utils.CallerPath(0),
}
instanceIdentifier, _ := getHostDBInstanceIdentifier()

instance, _ := getRDSInstanceFromIdentifier(cfg, instanceIdentifier)

// get read replicas from the instance
readReplicas := instance.DBInstances[0].ReadReplicaDBInstanceIdentifiers

if len(readReplicas) == 0 {
result.Passed = false
result.Message = "Multi Region instances not found"
return
}

hostRDSRegion, _ := getHostRDSRegion()

// loop over the read replicas and check if they are in a different region
for _, replica := range readReplicas {
// we are getting the instance identifier the read replicas
// get instance from the replica identifier
replicaInstance, err := getRDSInstanceFromIdentifier(cfg, replica)

if err != nil {
result.Passed = false
result.Message = err.Error()
return
}

if len(replicaInstance.DBInstances) == 0 {
result.Passed = false
result.Message = "Cannot access the replica instance " + replica
return
}

// check if replica region matches the host region
az := *replicaInstance.DBInstances[0].AvailabilityZone
// db instance doesnt contain the region so we need to remove the last character from the az
if az[:len(az)-1] == hostRDSRegion {
result.Passed = false
result.Message = "Multi Region instances not found"
return
}
}

result.Passed = true
return

}
32 changes: 32 additions & 0 deletions strikes/MultiRegion_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package strikes

import (
"encoding/json"
"fmt"
"testing"

"github.com/spf13/viper"
)

func TestMultiRegion(t *testing.T) {
viper.AddConfigPath("../")
viper.SetConfigName("config")
viper.SetConfigType("yaml")
err := viper.ReadInConfig()

if err != nil {
fmt.Println("Config file not found...")
return
}

strikes := Strikes{}
strikeName, result := strikes.MultiRegion()

fmt.Println(strikeName)
b, err := json.MarshalIndent(result, "", " ")
if err != nil {
fmt.Println(err)
}
fmt.Print(string(b))
fmt.Println()
}
44 changes: 42 additions & 2 deletions strikes/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/rds"
hclog "github.com/hashicorp/go-hclog"
"github.com/privateerproj/privateer-sdk/raidengine"
"github.com/privateerproj/privateer-sdk/utils"
Expand All @@ -32,17 +33,25 @@ func getDBConfig() (string, error) {
return "", errors.New("database url must be set in the config file")
}

func getDBInstanceIdentifier() (string, error) {
func getHostDBInstanceIdentifier() (string, error) {
if viper.IsSet("raids.RDS.aws.config.instance_identifier") {
return viper.GetString("raids.RDS.aws.config.instance_identifier"), nil
}
return "", errors.New("database instance identifier must be set in the config file")
}

func getHostRDSRegion() (string, error) {
if viper.IsSet("raids.RDS.aws.config.region") {
return viper.GetString("raids.RDS.aws.config.region"), nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably for a later PR, but RDS is specifically Amazon, right? If anything, we'd probably do raids.aws.rds... 🤔 Noting this for later discussion.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was more so reading this as raiding a database against different cloud provider. RDS term is aws specific, you are right, maybe we change that to DB or RDMS in a later pr. Also noting that we need to change the repo name as well

}
return "", errors.New("database instance identifier must be set in the config file")
}

func getAWSConfig() (cfg aws.Config, err error) {
if viper.IsSet("raids.RDS.aws.creds") &&
viper.IsSet("raids.RDS.aws.creds.aws_access_key") &&
viper.IsSet("raids.RDS.aws.creds.aws_secret_key") {
viper.IsSet("raids.RDS.aws.creds.aws_secret_key") &&
viper.IsSet("raids.RDS.aws.creds.aws_region") {

access_key := viper.GetString("raids.RDS.aws.creds.aws_access_key")
secret_key := viper.GetString("raids.RDS.aws.creds.aws_secret_key")
Expand All @@ -68,3 +77,34 @@ func connectToDb() (result raidengine.MovementResult) {
result.Passed = true
return
}

func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) {
// check if the instance is available
result = raidengine.MovementResult{
Description: "Check if the instance is available/exists",
Function: utils.CallerPath(0),
}

instanceIdentifier, _ := getHostDBInstanceIdentifier()

instance, err := getRDSInstanceFromIdentifier(cfg, instanceIdentifier)
if err != nil {
// Handle error
result.Message = err.Error()
result.Passed = false
return
}
result.Passed = len(instance.DBInstances) > 0
return
}

func getRDSInstanceFromIdentifier(cfg aws.Config, identifier string) (instance *rds.DescribeDBInstancesOutput, err error) {
rdsClient := rds.NewFromConfig(cfg)

input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(identifier),
}

instance, err = rdsClient.DescribeDBInstances(context.TODO(), input)
return
}
Loading