Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: multi region replica strike implementation #3

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ var (
"CCC-Taxonomy": {
Strikes.SQLFeatures,
Strikes.AutomatedBackups,
Strikes.MultiRegion,
// Strikes.VerticalScaling,
// Strikes.Replication,
// Strikes.MultiRegion,
// Strikes.BackupRecovery,
// Strikes.Encryption,
// Strikes.RBAC,
Expand Down
35 changes: 5 additions & 30 deletions strikes/AutomatedBackups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
Movements: make(map[string]raidengine.MovementResult),
}

// Movement
// Get Configuration
cfg, err := getAWSConfig()
if err != nil {
result.Message = err.Error()
Expand All @@ -37,10 +37,10 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
return
}

autmatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg)
result.Movements["CheckForDBInstanceAutomatedBackups"] = autmatedBackupsMovement
if !autmatedBackupsMovement.Passed {
result.Message = autmatedBackupsMovement.Message
automatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg)
result.Movements["CheckForDBInstanceAutomatedBackups"] = automatedBackupsMovement
if !automatedBackupsMovement.Passed {
result.Message = automatedBackupsMovement.Message
return
}

Expand All @@ -49,31 +49,6 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
return
}

func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) {
// check if the instance is available
result = raidengine.MovementResult{
Description: "Check if the instance is available/exists",
Function: utils.CallerPath(0),
}

rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()

input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(identifier),
}

instances, err := rdsClient.DescribeDBInstances(context.TODO(), input)
if err != nil {
// Handle error
result.Message = err.Error()
result.Passed = false
return
}
result.Passed = len(instances.DBInstances) > 0
return
}

func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.MovementResult) {

result = raidengine.MovementResult{
Expand Down
96 changes: 96 additions & 0 deletions strikes/MultiRegion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package strikes

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/rds"
"github.com/privateerproj/privateer-sdk/raidengine"
"github.com/privateerproj/privateer-sdk/utils"
)

func (a *Strikes) MultiRegion() (strikeName string, result raidengine.StrikeResult) {
strikeName = "MultiRegion"
result = raidengine.StrikeResult{
Passed: false,
Description: "Check if AWS RDS instance has multi region. This strike only checks for a read replica in a seperate region",
DocsURL: "https://www.github.com/krumIO/raid-rds",
ControlID: "CCC-Taxonomy-1",
Movements: make(map[string]raidengine.MovementResult),
}

// Get Configuration
cfg, err := getAWSConfig()
if err != nil {
result.Message = err.Error()
return
}

rdsInstanceMovement := checkRDSInstanceMovement(cfg)
result.Movements["CheckForDBInstance"] = rdsInstanceMovement
if !rdsInstanceMovement.Passed {
result.Message = rdsInstanceMovement.Message
return
}

multiRegionMovement := checkRDSMultiRegionMovement(cfg)
result.Movements["CheckForMultiRegionDBInstances"] = multiRegionMovement
if !multiRegionMovement.Passed {
result.Message = multiRegionMovement.Message
return
}

result.Passed = true
result.Message = "Completed Successfully"

return
}

func checkRDSMultiRegionMovement(cfg aws.Config) (result raidengine.MovementResult) {

result = raidengine.MovementResult{
Description: "Check if the instance has multi region enabled",
Function: utils.CallerPath(0),
}

rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()

input := &rds.DescribeDBInstanceAutomatedBackupsInput{
DBInstanceIdentifier: aws.String(identifier),
}

backups, err := rdsClient.DescribeDBInstanceAutomatedBackups(context.TODO(), input)
if err != nil {
result.Message = "Failed to fetch automated backups for instance " + identifier
result.Passed = false
return
}

var regions []string
for _, backup := range backups.DBInstanceAutomatedBackups {
regions = append(regions, *backup.Region)
}

// This checks if theres a read replica in a different region
if len(regions) > 0 {
hostDBRegion := getRDSRegion()
for _, region := range regions {
// region from the instances are in the form of "use2"
abbreviation, exists := AWS_REGIONS_ABBR[hostDBRegion]
if exists {
if region != abbreviation {
result.Passed = true
result.Message = "Completed Successfully"
return
}
}

}
}

result.Passed = false
result.Message = "Multi Region instances not found"
return

}
32 changes: 32 additions & 0 deletions strikes/MultiRegion_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package strikes

import (
"encoding/json"
"fmt"
"testing"

"github.com/spf13/viper"
)

func TestMultiRegion(t *testing.T) {
viper.AddConfigPath("../")
viper.SetConfigName("config")
viper.SetConfigType("yaml")
err := viper.ReadInConfig()

if err != nil {
fmt.Println("Config file not found...")
return
}

strikes := Strikes{}
strikeName, result := strikes.MultiRegion()

fmt.Println(strikeName)
b, err := json.MarshalIndent(result, "", " ")
if err != nil {
fmt.Println(err)
}
fmt.Print(string(b))
fmt.Println()
}
58 changes: 57 additions & 1 deletion strikes/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,33 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/rds"
hclog "github.com/hashicorp/go-hclog"
"github.com/privateerproj/privateer-sdk/raidengine"
"github.com/privateerproj/privateer-sdk/utils"
"github.com/spf13/viper"
)

var (
AWS_REGIONS_ABBR = map[string]string{
"us-east-1": "use1",
"us-east-2": "use2",
"us-west-1": "usw1",
"us-west-2": "usw2",
"ca-central-1": "cac1",
"eu-west-1": "euw1",
"eu-west-2": "euw2",
"eu-central-1": "euc1",
"eu-north-1": "eun1",
"ap-northeast-1": "apne1",
"ap-northeast-2": "apne2",
"ap-southeast-1": "apse1",
"ap-southeast-2": "apse2",
"ap-south-1": "aps1",
"sa-east-1": "sae1",
}
)

type Strikes struct {
Log hclog.Logger
}
Expand All @@ -39,10 +60,15 @@ func getDBInstanceIdentifier() (string, error) {
return "", errors.New("database instance identifier must be set in the config file")
}

func getRDSRegion() string {
return viper.GetString("raids.RDS.aws.config.region")
}

func getAWSConfig() (cfg aws.Config, err error) {
if viper.IsSet("raids.RDS.aws.creds") &&
viper.IsSet("raids.RDS.aws.creds.aws_access_key") &&
viper.IsSet("raids.RDS.aws.creds.aws_secret_key") {
viper.IsSet("raids.RDS.aws.creds.aws_secret_key") &&
viper.IsSet("raids.RDS.aws.creds.aws_region") {

access_key := viper.GetString("raids.RDS.aws.creds.aws_access_key")
secret_key := viper.GetString("raids.RDS.aws.creds.aws_secret_key")
Expand All @@ -68,3 +94,33 @@ func connectToDb() (result raidengine.MovementResult) {
result.Passed = true
return
}

func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) {
// check if the instance is available
result = raidengine.MovementResult{
Description: "Check if the instance is available/exists",
Function: utils.CallerPath(0),
}

instance, err := getRDSInstance(cfg)
if err != nil {
// Handle error
result.Message = err.Error()
result.Passed = false
return
}
result.Passed = len(instance.DBInstances) > 0
return
}

func getRDSInstance(cfg aws.Config) (instance *rds.DescribeDBInstancesOutput, err error) {
rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()

input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(identifier),
}

instance, err = rdsClient.DescribeDBInstances(context.TODO(), input)
return
}
Loading