Skip to content

Commit

Permalink
Add more tests for error cases in workflow_commands.go (#6442)
Browse files Browse the repository at this point in the history
* Add more tests for missing flags cases

* Add more tests
  • Loading branch information
neil-xie authored Oct 30, 2024
1 parent e412f8a commit 5af23ef
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 2 deletions.
3 changes: 1 addition & 2 deletions tools/cli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"reflect"
"regexp"
Expand Down Expand Up @@ -814,7 +813,7 @@ func processJSONInputHelper(c *cli.Context, jType jsonType) (string, error) {
inputFile := c.String(flagNameOfInputFileName)
// This method is purely used to parse input from the CLI. The input comes from a trusted user
// #nosec
data, err := ioutil.ReadFile(inputFile)
data, err := os.ReadFile(inputFile)
if err != nil {
return "", fmt.Errorf("error reading input file: %w", err)
}
Expand Down
276 changes: 276 additions & 0 deletions tools/cli/workflow_commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package cli

import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
Expand All @@ -42,6 +43,7 @@ import (
"github.com/uber/cadence/client/frontend"
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/types"
"github.com/uber/cadence/tools/cli/clitest"
)

func TestConstructStartWorkflowRequest(t *testing.T) {
Expand Down Expand Up @@ -1874,3 +1876,277 @@ func Test_RestartWorkflow_MissingFlags(t *testing.T) {
err = RestartWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))
}

func Test_DiagnoseWorkflow_MissingFlags(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
err := DiagnoseWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

set.String(FlagDomain, "test-domain", "domain")
err = DiagnoseWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))

set.String(FlagWorkflowID, "test-workflow-id", "workflow_id")
err = DiagnoseWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagRunID))
}

func Test_TerminateWorkflow_MissingFlags(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
err := TerminateWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

set.String(FlagDomain, "test-domain", "domain")
err = TerminateWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))
}

func Test_ShowHistory_MissingWorkflowID(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
err := ShowHistory(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))
}

func Test_ShowHistoryWithID_MissingWorkflowID(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
err := ShowHistoryWithWID(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))

set.Parse([]string{"test-workflow-id", "test-run-id"})
err = ShowHistoryWithWID(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))
}

func Test_ConstructStartWorkflowRequest_MissingFlags(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
_, err := constructStartWorkflowRequest(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

set.String(FlagDomain, "test-domain", "domain")
_, err = constructStartWorkflowRequest(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagTaskList))

set.String(FlagTaskList, "test-tasklist", "tasklist")
_, err = constructStartWorkflowRequest(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowType))

set.String(FlagWorkflowType, "test-workflow-type", "workflow_type")
_, err = constructStartWorkflowRequest(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s format is invalid", FlagExecutionTimeout))

set.String(FlagExecutionTimeout, "10", "execution_timeout")
set.Int(FlagWorkflowIDReusePolicy, 1, "workflowidreusepolicy")
_, err = constructStartWorkflowRequest(c)
assert.NoError(t, err)

// invalid workflowID reuse policy
ctx := clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, -10))
_, err = constructStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "value is not in supported range")

// process Json error
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, 1), clitest.StringArgument(FlagInput, "invalid json"))
_, err = constructStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "input is not valid JSON")

// error processing first run at
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, 1), clitest.StringArgument(FlagCronSchedule, "* * * * *"),
clitest.StringArgument(FirstRunAtTime, "10:00"))
_, err = constructStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "time format invalid")

// error processing header
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, 1), clitest.StringArgument(FlagCronSchedule, "* * * * *"),
clitest.StringArgument(FlagHeaderFile, "invalid file"))
_, err = constructStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "error when process header")

// error processing memo
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, 1), clitest.StringArgument(FlagCronSchedule, "* * * * *"),
clitest.StringArgument(FlagMemoFile, "invalid file"), clitest.StringArgument(FlagSearchAttributesKey, "key"))
_, err = constructStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "Error processing memo")

// error processing search attributes
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, 1), clitest.StringArgument(FlagCronSchedule, "* * * * *"),
clitest.StringArgument(FlagSearchAttributesKey, "key"))
_, err = constructStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "error processing search attributes")

ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.IntArgument(FlagWorkflowIDReusePolicy, 1), clitest.StringArgument(FlagCronSchedule, "* * * * *"),
clitest.StringArgument(FlagSearchAttributesKey, "key"), clitest.StringArgument(FlagSearchAttributesVal, "val"))
_, err = constructStartWorkflowRequest(ctx)
assert.NoError(t, err)
}

func Test_NewTest(t *testing.T) {

}

func Test_ProcessSearchAttr(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
set.String(FlagSearchAttributesKey, "key", "search attribute key")
c := cli.NewContext(app, set, nil)
_, err := processSearchAttr(c)
assert.ErrorContains(t, err, "keys and values are not equal")

set.String(FlagSearchAttributesVal, "value", "search attribute value")
resp, err := processSearchAttr(c)
assert.NoError(t, err)
expectedVal, _ := json.Marshal("value")
expectedResp := map[string][]byte{"key": expectedVal}
assert.Equal(t, expectedResp, resp)
}

func Test_CancelWorkflow_MissingFlags(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
err := CancelWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

set.String(FlagDomain, "test-domain", "domain")
err = CancelWorkflow(c)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))
}

func Test_QueryWorkflowHelper_MissingFlags(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
set := flag.NewFlagSet("test", 0)
c := cli.NewContext(app, set, nil)
err := queryWorkflowHelper(c, "")
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

set.String(FlagDomain, "test-domain", "domain")
err = queryWorkflowHelper(c, "")
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))

content := "wid1,wid2,wid3\n\nwid4,wid5\nwid6\n"
fileName, cleanup := createTempFileWithContent(t, content)
defer cleanup()
ctx := clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagWorkflowID, "test-workflow-id"),
clitest.StringArgument(FlagInputFile, fileName))
err = QueryWorkflowUsingQueryTypes(ctx)
assert.ErrorContains(t, err, "Error processing json")
}

func Test_ProcessJsonInputHelper(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
content := "wid1,wid2,wid3\n\nwid4,wid5\nwid6\n"
fileName, cleanup := createTempFileWithContent(t, content)
defer cleanup()

ctx := clitest.NewCLIContext(t, app, clitest.StringArgument(FlagInputFile, fileName))
_, err := processJSONInputHelper(ctx, jsonTypeInput)
assert.ErrorContains(t, err, "input is not valid JSON")

ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagInputFile, "non exist file"))
_, err = processJSONInputHelper(ctx, jsonTypeInput)
assert.ErrorContains(t, err, "error reading input file")

resp, err := processJSONInputHelper(ctx, -1)
assert.Equal(t, "", resp)
assert.NoError(t, err)
}

func Test_ConstructSignalWithStartWorkflowRequest_Errors(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
ctx := clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"))

_, err := constructSignalWithStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagName))

ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"), clitest.StringArgument(FlagTaskList, "test-tasklist"),
clitest.StringArgument(FlagWorkflowType, "test-workflow-type"), clitest.StringArgument(FlagExecutionTimeout, "10"),
clitest.StringArgument(FlagName, "test-signal-name"), clitest.StringArgument(FlagSignalInputFile, "invalid json"))
_, err = constructSignalWithStartWorkflowRequest(ctx)
assert.ErrorContains(t, err, "error processing json input signal")
}

func Test_ListWorkflow_Errors(t *testing.T) {
mockCtrl := gomock.NewController(t)
serverFrontendClient := frontend.NewMockClient(mockCtrl)
app := NewCliApp(&clientFactoryMock{
serverFrontendClient: serverFrontendClient,
})
ctx := clitest.NewCLIContext(t, app)
err := ListWorkflow(ctx)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

serverFrontendClient.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
serverFrontendClient.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
serverFrontendClient.EXPECT().ScanWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("test-error")).Times(1)
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"),
clitest.StringArgument(FlagWorkflowID, "test-workflow-id"), clitest.StringArgument(FlagExcludeWorkflowIDByQuery, "test-exclude"),
clitest.StringArgument(FlagListQuery, "test-query"))
err = ListWorkflow(ctx)
assert.ErrorContains(t, err, "test-error")
}

func Test_ListAllWorkflow_Errors(t *testing.T) {
mockCtrl := gomock.NewController(t)
serverFrontendClient := frontend.NewMockClient(mockCtrl)
app := NewCliApp(&clientFactoryMock{
serverFrontendClient: serverFrontendClient,
})
ctx := clitest.NewCLIContext(t, app)
err := ListAllWorkflow(ctx)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))

serverFrontendClient.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
serverFrontendClient.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
serverFrontendClient.EXPECT().ScanWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("test-error")).Times(1)
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"),
clitest.StringArgument(FlagWorkflowID, "test-workflow-id"), clitest.StringArgument(FlagExcludeWorkflowIDByQuery, "test-exclude"),
clitest.StringArgument(FlagListQuery, "test-query"))
err = ListAllWorkflow(ctx)
assert.ErrorContains(t, err, "test-error")
}

func Test_CountWorkflow_Errors(t *testing.T) {
mockCtrl := gomock.NewController(t)
serverFrontendClient := frontend.NewMockClient(mockCtrl)
app := NewCliApp(&clientFactoryMock{
serverFrontendClient: serverFrontendClient,
})
ctx := clitest.NewCLIContext(t, app)
err := CountWorkflow(ctx)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagDomain))
ctx = clitest.NewCLIContext(t, app, clitest.StringArgument(FlagDomain, "test-domain"))
serverFrontendClient.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("test-error")).AnyTimes()
err = CountWorkflow(ctx)
assert.ErrorContains(t, err, "test-error")
}

func Test_DescribeWorkflow_Errors(t *testing.T) {
app := NewCliApp(&clientFactoryMock{})
ctx := clitest.NewCLIContext(t, app)
err := DescribeWorkflow(ctx)
assert.ErrorContains(t, err, fmt.Sprintf("%s is required", FlagWorkflowID))
}

0 comments on commit 5af23ef

Please sign in to comment.