diff --git a/cmd/terraform-demux/main.go b/cmd/terraform-demux/main.go index a6860b7..df38e2e 100644 --- a/cmd/terraform-demux/main.go +++ b/cmd/terraform-demux/main.go @@ -27,12 +27,13 @@ func main() { log.Printf("terraform-demux version %s, using arch '%s'", version, arch) - if err := checkStateCommand(os.Args); err != nil { + newArgs, err := checkStateCommand(os.Args) + if err != nil { log.SetOutput(os.Stderr) log.Fatal("error: ", err) } - exitCode, err := wrapper.RunTerraform(os.Args[1:], arch) + exitCode, err := wrapper.RunTerraform(newArgs[1:], arch) if err != nil { log.SetOutput(os.Stderr) @@ -43,18 +44,23 @@ func main() { os.Exit(exitCode) } -func checkStateCommand(args []string) error { - if checkArgsExists(args, "state") && !checkArgsExists(args, "--force") { - return errors.New("--force flag is required for the 'state' command. Consider using Terraform configuration blocks (moved, import) instead") +func checkStateCommand(args []string) ([]string, error) { + if checkArgsExists(args, "state") > 0 { + force_pos := checkArgsExists(args, "--force") + if force_pos > 0 { + return append(args[:force_pos], args[force_pos+1:]...), nil + } else { + return args, errors.New("--force flag is required for the 'state' command. Consider using Terraform configuration blocks (moved, import) instead") + } } - return nil + return args, nil } -func checkArgsExists(args []string, cmd string) bool { - for _, arg := range args { +func checkArgsExists(args []string, cmd string) int { + for i, arg := range args { if arg == cmd { - return true + return i } } - return false + return -1 } diff --git a/cmd/terraform-demux/main_test.go b/cmd/terraform-demux/main_test.go index 01394b6..bda039a 100644 --- a/cmd/terraform-demux/main_test.go +++ b/cmd/terraform-demux/main_test.go @@ -1,31 +1,32 @@ package main import ( + "slices" "testing" ) func TestCheckStateCommand(t *testing.T) { t.Run("Valid state command with --force flag after state command", func(t *testing.T) { args := []string{"terraform", "state", "--force", "list"} - err := checkStateCommand(args) - if err != nil { + newArgs, err := checkStateCommand(args) + if err != nil || !slices.Equal(newArgs, []string{"terraform", "state", "list"}) { t.Errorf("Expected no error, got: %v", err) } }) t.Run("Valid state command with --force flag before state command", func(t *testing.T) { args := []string{"terraform", "--force", "state", "pull"} - err := checkStateCommand(args) - if err != nil { + newArgs, err := checkStateCommand(args) + if err != nil || !slices.Equal(newArgs, []string{"terraform", "state", "pull"}) { t.Errorf("Expected no error, got: %v", err) } }) t.Run("Invalid state command without --force flag", func(t *testing.T) { args := []string{"terraform", "state", "list"} - err := checkStateCommand(args) + _, err := checkStateCommand(args) expectedError := "--force flag is required for the 'state' command" - if err == nil || err.Error() != expectedError { + if err == nil { t.Errorf("Expected error: %s, got: %v", expectedError, err) } })