diff --git a/common/persistence/data_manager_interfaces.go b/common/persistence/data_manager_interfaces.go index 61faf199894..ba7b586c3de 100644 --- a/common/persistence/data_manager_interfaces.go +++ b/common/persistence/data_manager_interfaces.go @@ -20,7 +20,7 @@ // THE SOFTWARE. // Geneate rate limiter wrappers. -//go:generate mockgen -package $GOPACKAGE -destination dataManagerInterfaces_mock.go -self_package github.com/uber/cadence/common/persistence github.com/uber/cadence/common/persistence Task,ShardManager,ExecutionManager,ExecutionManagerFactory,TaskManager,HistoryManager,DomainManager,QueueManager,ConfigStoreManager +//go:generate mockgen -package $GOPACKAGE -destination data_manager_interfaces_mock.go -self_package github.com/uber/cadence/common/persistence github.com/uber/cadence/common/persistence Task,ShardManager,ExecutionManager,ExecutionManagerFactory,TaskManager,HistoryManager,DomainManager,QueueManager,ConfigStoreManager //go:generate gowrap gen -g -p . -i ConfigStoreManager -t ./wrappers/templates/ratelimited.tmpl -o wrappers/ratelimited/configstore_generated.go //go:generate gowrap gen -g -p . -i DomainManager -t ./wrappers/templates/ratelimited.tmpl -o wrappers/ratelimited/domain_generated.go //go:generate gowrap gen -g -p . -i HistoryManager -t ./wrappers/templates/ratelimited.tmpl -o wrappers/ratelimited/history_generated.go diff --git a/common/persistence/dataManagerInterfaces_mock.go b/common/persistence/data_manager_interfaces_mock.go similarity index 100% rename from common/persistence/dataManagerInterfaces_mock.go rename to common/persistence/data_manager_interfaces_mock.go diff --git a/common/persistence/operationModeValidator.go b/common/persistence/operation_mode_validator.go similarity index 100% rename from common/persistence/operationModeValidator.go rename to common/persistence/operation_mode_validator.go diff --git a/common/persistence/operationModeValidator_test.go b/common/persistence/operation_mode_validator_test.go similarity index 100% rename from common/persistence/operationModeValidator_test.go rename to common/persistence/operation_mode_validator_test.go diff --git a/common/persistence/queueManager.go b/common/persistence/queue_manager.go similarity index 100% rename from common/persistence/queueManager.go rename to common/persistence/queue_manager.go diff --git a/common/persistence/dataVisibilityManagerInterfaces.go b/common/persistence/visibility_manager_interfaces.go similarity index 98% rename from common/persistence/dataVisibilityManagerInterfaces.go rename to common/persistence/visibility_manager_interfaces.go index 0e15451ecbe..d6ce4abdd18 100644 --- a/common/persistence/dataVisibilityManagerInterfaces.go +++ b/common/persistence/visibility_manager_interfaces.go @@ -18,7 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -//go:generate mockgen -package $GOPACKAGE -destination dataVisibilityManagerInterfaces_mock.go -self_package github.com/uber/cadence/common/persistence github.com/uber/cadence/common/persistence VisibilityManager +//go:generate mockgen -package $GOPACKAGE -destination visibility_manager_interfaces_mock.go -self_package github.com/uber/cadence/common/persistence github.com/uber/cadence/common/persistence VisibilityManager // Generate rate limiter wrapper. //go:generate gowrap gen -g -p . -i VisibilityManager -t ./wrappers/templates/ratelimited.tmpl -o wrappers/ratelimited/visibility_generated.go diff --git a/common/persistence/dataVisibilityManagerInterfaces_mock.go b/common/persistence/visibility_manager_interfaces_mock.go similarity index 100% rename from common/persistence/dataVisibilityManagerInterfaces_mock.go rename to common/persistence/visibility_manager_interfaces_mock.go diff --git a/common/reconciliation/fetcher/concrete_test.go b/common/reconciliation/fetcher/concrete_test.go index b8d702f6a06..97b49de3d14 100644 --- a/common/reconciliation/fetcher/concrete_test.go +++ b/common/reconciliation/fetcher/concrete_test.go @@ -23,83 +23,332 @@ package fetcher import ( + "context" + "fmt" "testing" + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/uber/cadence/.gen/go/shared" "github.com/uber/cadence/common" "github.com/uber/cadence/common/codec" + "github.com/uber/cadence/common/pagination" "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/reconciliation/entity" ) -const ( - testTreeID = "test-tree-id" - testBranchID = "test-branch-id" -) +func TestConcreteExecutionIterator(t *testing.T) { + ctrl := gomock.NewController(t) + retryer := persistence.NewMockRetryer(ctrl) + retryer.EXPECT().ListConcreteExecutions(gomock.Any(), gomock.Any()). + Return(&persistence.ListConcreteExecutionsResponse{}, nil). + Times(1) -var ( - validBranchToken = []byte{89, 11, 0, 10, 0, 0, 0, 12, 116, 101, 115, 116, 45, 116, 114, 101, 101, 45, 105, 100, 11, 0, 20, 0, 0, 0, 14, 116, 101, 115, 116, 45, 98, 114, 97, 110, 99, 104, 45, 105, 100, 0} - invalidBranchToken = []byte("invalid") -) + iterator := ConcreteExecutionIterator( + context.Background(), + retryer, + 10, + ) + require.NotNil(t, iterator) +} + +func TestConcreteExecution(t *testing.T) { + encoder := codec.NewThriftRWEncoder() + tests := []struct { + desc string + req ExecutionRequest + mockFn func(retryer *persistence.MockRetryer) + wantEntity entity.Entity + wantErr bool + }{ + { + desc: "success", + req: ExecutionRequest{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + DomainName: "test-domain-name", + }, + mockFn: func(retryer *persistence.MockRetryer) { + retryer.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()). + Return( + &persistence.GetWorkflowExecutionResponse{ + State: &persistence.WorkflowMutableState{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id", "test-branch-id"), + State: persistence.WorkflowStateRunning, + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + }, + }, + nil, + ).Times(1) + + retryer.EXPECT().GetShardID().Return(355).Times(1) + }, + wantEntity: &entity.ConcreteExecution{ + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id", "test-branch-id"), + TreeID: "test-tree-id", + BranchID: "test-branch-id", + Execution: entity.Execution{ + ShardID: 355, + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateRunning, + }, + }, + }, + { + desc: "GetWorkflowExecution failed", + req: ExecutionRequest{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + DomainName: "test-domain-name", + }, + mockFn: func(retryer *persistence.MockRetryer) { + retryer.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + ctrl := gomock.NewController(t) + retryer := persistence.NewMockRetryer(ctrl) + + tc.mockFn(retryer) + + gotEntity, err := ConcreteExecution(context.Background(), retryer, tc.req) + if (err != nil) != tc.wantErr { + t.Fatalf("ConcreteExecution() err: %v, wantErr %v", err, tc.wantErr) + } + + if diff := cmp.Diff(tc.wantEntity, gotEntity); diff != "" { + t.Errorf("ConcreteExecution() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestGetConcreteExecutions(t *testing.T) { + encoder := codec.NewThriftRWEncoder() + testExecutions := []*persistence.ListConcreteExecutionsEntity{ + { + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id-1", "test-branch-id-1"), + State: persistence.WorkflowStateRunning, + DomainID: "test-domain-id-1", + WorkflowID: "test-workflow-id-1", + RunID: "test-run-id-1", + }, + }, + { + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id-2", "test-branch-id-2"), + State: persistence.WorkflowStateCompleted, + DomainID: "test-domain-id-2", + WorkflowID: "test-workflow-id-2", + RunID: "test-run-id-2", + }, + }, + } + + tests := []struct { + desc string + pageSize int + pageToken pagination.PageToken + mockFn func(*testing.T, *persistence.MockRetryer) + wantPage pagination.Page + wantErr bool + }{ + { + desc: "success", + pageSize: 2, + pageToken: []byte("test-page-token"), + mockFn: func(t *testing.T, retryer *persistence.MockRetryer) { + retryer.EXPECT().ListConcreteExecutions(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *persistence.ListConcreteExecutionsRequest) (*persistence.ListConcreteExecutionsResponse, error) { + wantReq := &persistence.ListConcreteExecutionsRequest{ + PageSize: 2, + PageToken: []byte("test-page-token"), + } + if diff := cmp.Diff(wantReq, req); diff != "" { + t.Errorf("Request mismatch (-want +got):\n%s", diff) + } + return &persistence.ListConcreteExecutionsResponse{ + PageToken: []byte("test-next-page-token"), + Executions: testExecutions, + }, nil + }).Times(1) + + // will be called for each execution in the response + retryer.EXPECT().GetShardID().Return(355).Times(2) + }, + wantPage: pagination.Page{ + CurrentToken: []byte("test-page-token"), + NextToken: []byte("test-next-page-token"), + Entities: concreteExecutionsToEntities(testExecutions, 355, encoder), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + ctrl := gomock.NewController(t) + retryer := persistence.NewMockRetryer(ctrl) + + tc.mockFn(t, retryer) + + fetchFn := getConcreteExecutions(retryer, tc.pageSize, encoder) + gotPage, err := fetchFn(context.Background(), tc.pageToken) + if (err != nil) != tc.wantErr { + t.Fatalf("ConcreteExecution() err: %v, wantErr %v", err, tc.wantErr) + } + + if diff := cmp.Diff(tc.wantPage, gotPage); diff != "" { + t.Errorf("ConcreteExecution() mismatch (-want +got):\n%s", diff) + } + }) + } +} func TestGetBranchToken(t *testing.T) { encoder := codec.NewThriftRWEncoder() testCases := []struct { - name string - entity *persistence.ListConcreteExecutionsEntity - expectError bool - branchToken []byte - treeID string - branchID string + name string + entity *persistence.ListConcreteExecutionsEntity + wantErr bool + wantBranchToken []byte + wantHistoryBranch shared.HistoryBranch }{ { - name: "ValidBranchToken", + name: "valid branch token - no version histories", + entity: &persistence.ListConcreteExecutionsEntity{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id", "test-branch-id"), + }, + }, + wantBranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id", "test-branch-id"), + wantHistoryBranch: shared.HistoryBranch{ + TreeID: common.StringPtr("test-tree-id"), + BranchID: common.StringPtr("test-branch-id"), + }, + }, + { + name: "valid branch token - with version histories", entity: &persistence.ListConcreteExecutionsEntity{ ExecutionInfo: &persistence.WorkflowExecutionInfo{ - BranchToken: getValidBranchToken(t, encoder), + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id", "test-branch-id"), }, + VersionHistories: &persistence.VersionHistories{ + CurrentVersionHistoryIndex: 1, + Histories: []*persistence.VersionHistory{ + {}, // this will be ignored because index is 1 + { + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id-from-versionhistory", "test-branch-id-from-versionhistory"), + }, + }, + }, + }, + wantBranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id-from-versionhistory", "test-branch-id-from-versionhistory"), + wantHistoryBranch: shared.HistoryBranch{ + TreeID: common.StringPtr("test-tree-id-from-versionhistory"), + BranchID: common.StringPtr("test-branch-id-from-versionhistory"), }, - expectError: false, - branchToken: validBranchToken, - treeID: testTreeID, - branchID: testBranchID, }, { - name: "InvalidBranchToken", + name: "version history index out of bound", entity: &persistence.ListConcreteExecutionsEntity{ ExecutionInfo: &persistence.WorkflowExecutionInfo{ - BranchToken: invalidBranchToken, + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id", "test-branch-id"), + }, + VersionHistories: &persistence.VersionHistories{ + CurrentVersionHistoryIndex: 2, + Histories: []*persistence.VersionHistory{ + {}, + { + BranchToken: mustGetValidBranchToken(t, encoder, "test-tree-id-from-versionhistory", "test-branch-id-from-versionhistory"), + }, + }, }, }, - expectError: true, + wantErr: true, + }, + { + name: "invalid branch token", + entity: &persistence.ListConcreteExecutionsEntity{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + BranchToken: []byte("invalid"), + }, + }, + wantErr: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - branchToken, branch, err := getBranchToken(tc.entity.ExecutionInfo.BranchToken, tc.entity.VersionHistories, encoder) - if tc.expectError { + branchToken, branch, err := getBranchToken( + tc.entity.ExecutionInfo.BranchToken, + tc.entity.VersionHistories, + encoder, + ) + + if tc.wantErr { require.Error(t, err) require.Nil(t, branchToken) require.Empty(t, branch.GetTreeID()) require.Empty(t, branch.GetBranchID()) } else { - require.NoError(t, err) - require.Equal(t, tc.branchToken, branchToken) - require.Equal(t, tc.treeID, branch.GetTreeID()) - require.Equal(t, tc.branchID, branch.GetBranchID()) + if diff := cmp.Diff(tc.wantHistoryBranch, branch); diff != "" { + t.Fatalf("HistoryBranch mismatch (-want +got):\n%s", diff) + } + require.Equal(t, tc.wantBranchToken, branchToken) } }) } } -func getValidBranchToken(t *testing.T, encoder *codec.ThriftRWEncoder) []byte { +func mustGetValidBranchToken(t *testing.T, encoder *codec.ThriftRWEncoder, treeID, branchID string) []byte { hb := &shared.HistoryBranch{ - TreeID: common.StringPtr(testTreeID), - BranchID: common.StringPtr(testBranchID), + TreeID: common.StringPtr(treeID), + BranchID: common.StringPtr(branchID), } bytes, err := encoder.Encode(hb) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to encode branch token: %v", err) + } + return bytes } + +func concreteExecutionsToEntities(execs []*persistence.ListConcreteExecutionsEntity, shardID int, encoder *codec.ThriftRWEncoder) []pagination.Entity { + entities := make([]pagination.Entity, len(execs)) + for i, e := range execs { + branchToken, branch, err := getBranchToken(e.ExecutionInfo.BranchToken, e.VersionHistories, encoder) + if err != nil { + return nil + } + concreteExec := &entity.ConcreteExecution{ + BranchToken: branchToken, + TreeID: branch.GetTreeID(), + BranchID: branch.GetBranchID(), + Execution: entity.Execution{ + ShardID: shardID, + DomainID: e.ExecutionInfo.DomainID, + WorkflowID: e.ExecutionInfo.WorkflowID, + RunID: e.ExecutionInfo.RunID, + State: e.ExecutionInfo.State, + }, + } + entities[i] = concreteExec + } + return entities +} diff --git a/common/reconciliation/fetcher/current_test.go b/common/reconciliation/fetcher/current_test.go index c2f7e6268cd..641403daf06 100644 --- a/common/reconciliation/fetcher/current_test.go +++ b/common/reconciliation/fetcher/current_test.go @@ -23,6 +23,7 @@ package fetcher import ( "context" + "fmt" "testing" "github.com/golang/mock/gomock" @@ -33,21 +34,108 @@ import ( "github.com/uber/cadence/common/reconciliation/entity" ) -func TestGetCurrentExecution(t *testing.T) { +func TestCurrentExecutionIterator(t *testing.T) { ctrl := gomock.NewController(t) - mockRetryer := persistence.NewMockRetryer(ctrl) + retryer := persistence.NewMockRetryer(ctrl) + retryer.EXPECT().ListCurrentExecutions(gomock.Any(), gomock.Any()). + Return(&persistence.ListCurrentExecutionsResponse{}, nil). + Times(1) + + iterator := CurrentExecutionIterator( + context.Background(), + retryer, + 10, + ) + require.NotNil(t, iterator) +} + +func TestCurrentExecution(t *testing.T) { + ctx := context.Background() + + testCases := []struct { + name string + setupMock func(*persistence.MockRetryer) + request ExecutionRequest + wantEntity entity.Entity + wantErr bool + }{ + { + name: "success", + request: ExecutionRequest{ + DomainID: "testDomainID", + WorkflowID: "testWorkflowID", + DomainName: "testDomainName", + }, + setupMock: func(mockRetryer *persistence.MockRetryer) { + mockRetryer.EXPECT().GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{ + DomainID: "testDomainID", + WorkflowID: "testWorkflowID", + DomainName: "testDomainName", + }).Return(&persistence.GetCurrentExecutionResponse{ + RunID: "testRunID", + State: persistence.WorkflowStateRunning, + }, nil).Times(1) + + mockRetryer.EXPECT().GetShardID().Return(123).Times(1) + }, + wantEntity: &entity.CurrentExecution{ + CurrentRunID: "testRunID", + Execution: entity.Execution{ + ShardID: 123, + DomainID: "testDomainID", + WorkflowID: "testWorkflowID", + RunID: "testRunID", + State: persistence.WorkflowStateRunning, + }, + }, + }, + { + name: "GetCurrentExecution failed", + request: ExecutionRequest{ + DomainID: "testDomainID", + WorkflowID: "testWorkflowID", + DomainName: "testDomainName", + }, + setupMock: func(mockRetryer *persistence.MockRetryer) { + mockRetryer.EXPECT().GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{ + DomainID: "testDomainID", + WorkflowID: "testWorkflowID", + DomainName: "testDomainName", + }).Return(nil, fmt.Errorf("failed")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockRetryer := persistence.NewMockRetryer(ctrl) + + tc.setupMock(mockRetryer) + gotEntity, err := CurrentExecution(ctx, mockRetryer, tc.request) + if (err != nil) != tc.wantErr { + t.Errorf("CurrentExecution() error = %v, wantErr %v", err, tc.wantErr) + } + + require.Equal(t, tc.wantEntity, gotEntity) + }) + } +} + +func TestGetCurrentExecution(t *testing.T) { ctx := context.Background() pageSize := 10 testCases := []struct { - name string - setupMock func() - expectedPage pagination.Page - expectedError bool + name string + setupMock func(*persistence.MockRetryer) + wantPage pagination.Page + wantErr bool }{ { - name: "Success", - setupMock: func() { + name: "success", + setupMock: func(mockRetryer *persistence.MockRetryer) { executions := []*persistence.CurrentWorkflowExecution{ { DomainID: "testDomainID", @@ -65,11 +153,11 @@ func TestGetCurrentExecution(t *testing.T) { Return(&persistence.ListCurrentExecutionsResponse{ Executions: executions, PageToken: nil, - }, nil) + }, nil).Times(1) - mockRetryer.EXPECT().GetShardID().Return(123) + mockRetryer.EXPECT().GetShardID().Return(123).Times(1) }, - expectedPage: pagination.Page{ + wantPage: pagination.Page{ Entities: []pagination.Entity{ &entity.CurrentExecution{ CurrentRunID: "testCurrentRunID", // This should match with the mocked data @@ -83,22 +171,35 @@ func TestGetCurrentExecution(t *testing.T) { }, }, }, - expectedError: false, + wantErr: false, + }, + { + name: "ListCurrentExecutions failed", + setupMock: func(mockRetryer *persistence.MockRetryer) { + mockRetryer.EXPECT(). + ListCurrentExecutions(ctx, &persistence.ListCurrentExecutionsRequest{ + PageSize: pageSize, + }). + Return(nil, fmt.Errorf("failed")). + Times(1) + }, + wantErr: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - tc.setupMock() + ctrl := gomock.NewController(t) + mockRetryer := persistence.NewMockRetryer(ctrl) + + tc.setupMock(mockRetryer) fetchFn := getCurrentExecution(mockRetryer, pageSize) page, err := fetchFn(ctx, nil) - - if tc.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tc.expectedPage, page) + if (err != nil) != tc.wantErr { + t.Errorf("getCurrentExecution() error = %v, wantErr %v", err, tc.wantErr) } + + require.Equal(t, tc.wantPage, page) }) } } diff --git a/common/reconciliation/fetcher/timer_test.go b/common/reconciliation/fetcher/timer_test.go index a0bfe770c3e..1b1ed4a0726 100644 --- a/common/reconciliation/fetcher/timer_test.go +++ b/common/reconciliation/fetcher/timer_test.go @@ -36,6 +36,23 @@ import ( "github.com/uber/cadence/common/reconciliation/entity" ) +func TestTimerIterator(t *testing.T) { + ctrl := gomock.NewController(t) + retryer := persistence.NewMockRetryer(ctrl) + retryer.EXPECT().GetTimerIndexTasks(gomock.Any(), gomock.Any()). + Return(&persistence.GetTimerIndexTasksResponse{}, nil). + Times(1) + + iterator := TimerIterator( + context.Background(), + retryer, + time.Now(), + time.Now(), + 10, + ) + require.NotNil(t, iterator) +} + func TestGetUserTimers(t *testing.T) { fixedTimestamp, err := time.Parse(time.RFC3339, "2023-12-12T22:08:41Z") if err != nil {