From 7df40de8e69a5482331a53565bcefaa8a0a78ce6 Mon Sep 17 00:00:00 2001 From: Zijian Date: Fri, 1 Nov 2024 17:47:20 +0000 Subject: [PATCH] Implement new APIs --- client/matching/client.go | 12 +- client/matching/client_test.go | 92 ++++ common/metrics/defs.go | 32 +- common/types/mapper/proto/matching_test.go | 24 + common/types/matching.go | 14 + common/types/matching_test.go | 60 +++ common/types/testdata/service_matching.go | 14 + service/matching/handler/engine.go | 79 +++ service/matching/handler/engine_test.go | 406 ++++++++++++++ service/matching/handler/handler.go | 44 +- service/matching/handler/handler_test.go | 113 ++++ service/matching/handler/interfaces.go | 2 + service/matching/handler/interfaces_mock.go | 30 ++ service/matching/tasklist/db.go | 22 + service/matching/tasklist/identifier.go | 4 +- service/matching/tasklist/interfaces.go | 2 + service/matching/tasklist/interfaces_mock.go | 28 + .../matching/tasklist/task_list_manager.go | 158 ++++-- .../tasklist/task_list_manager_test.go | 497 ++++++++++++++++-- 19 files changed, 1529 insertions(+), 104 deletions(-) diff --git a/client/matching/client.go b/client/matching/client.go index b3a872caf6c..21d53c20b3a 100644 --- a/client/matching/client.go +++ b/client/matching/client.go @@ -269,7 +269,11 @@ func (c *clientImpl) UpdateTaskListPartitionConfig( request *types.MatchingUpdateTaskListPartitionConfigRequest, opts ...yarpc.CallOption, ) (*types.MatchingUpdateTaskListPartitionConfigResponse, error) { - return nil, &types.BadRequestError{} + peer, err := c.peerResolver.FromTaskList(request.TaskList.GetName()) + if err != nil { + return nil, err + } + return c.client.UpdateTaskListPartitionConfig(ctx, request, append(opts, yarpc.WithShardKey(peer))...) } func (c *clientImpl) RefreshTaskListPartitionConfig( @@ -277,5 +281,9 @@ func (c *clientImpl) RefreshTaskListPartitionConfig( request *types.MatchingRefreshTaskListPartitionConfigRequest, opts ...yarpc.CallOption, ) (*types.MatchingRefreshTaskListPartitionConfigResponse, error) { - return nil, &types.BadRequestError{} + peer, err := c.peerResolver.FromTaskList(request.TaskList.GetName()) + if err != nil { + return nil, err + } + return c.client.RefreshTaskListPartitionConfig(ctx, request, append(opts, yarpc.WithShardKey(peer))...) } diff --git a/client/matching/client_test.go b/client/matching/client_test.go index d71ec7d2dc5..0a84709c7b6 100644 --- a/client/matching/client_test.go +++ b/client/matching/client_test.go @@ -431,6 +431,74 @@ func TestClient_withResponse(t *testing.T) { want: nil, wantError: true, }, + { + name: "UpdateTaskListPartitionConfig", + op: func(c Client) (any, error) { + return c.UpdateTaskListPartitionConfig(context.Background(), testMatchingUpdateTaskListPartitionConfigRequest()) + }, + mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { + p.EXPECT().FromTaskList(_testTaskList).Return("peer0", nil) + c.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.MatchingUpdateTaskListPartitionConfigResponse{}, nil) + }, + want: &types.MatchingUpdateTaskListPartitionConfigResponse{}, + }, + { + name: "UpdateTaskListPartitionConfig - Error in resolving peer", + op: func(c Client) (any, error) { + return c.UpdateTaskListPartitionConfig(context.Background(), testMatchingUpdateTaskListPartitionConfigRequest()) + }, + mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { + p.EXPECT().FromTaskList(_testTaskList).Return("peer0", assert.AnError) + }, + want: nil, + wantError: true, + }, + { + name: "UpdateTaskListPartitionConfig - Error while listing tasklist partitions", + op: func(c Client) (any, error) { + return c.UpdateTaskListPartitionConfig(context.Background(), testMatchingUpdateTaskListPartitionConfigRequest()) + }, + mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { + p.EXPECT().FromTaskList(_testTaskList).Return("peer0", nil) + c.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) + }, + want: nil, + wantError: true, + }, + { + name: "RefreshTaskListPartitionConfig", + op: func(c Client) (any, error) { + return c.RefreshTaskListPartitionConfig(context.Background(), testMatchingRefreshTaskListPartitionConfigRequest()) + }, + mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { + p.EXPECT().FromTaskList(_testTaskList).Return("peer0", nil) + c.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.MatchingRefreshTaskListPartitionConfigResponse{}, nil) + }, + want: &types.MatchingRefreshTaskListPartitionConfigResponse{}, + }, + { + name: "RefreshTaskListPartitionConfig - Error in resolving peer", + op: func(c Client) (any, error) { + return c.RefreshTaskListPartitionConfig(context.Background(), testMatchingRefreshTaskListPartitionConfigRequest()) + }, + mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { + p.EXPECT().FromTaskList(_testTaskList).Return("peer0", assert.AnError) + }, + want: nil, + wantError: true, + }, + { + name: "RefreshTaskListPartitionConfig - Error while listing tasklist partitions", + op: func(c Client) (any, error) { + return c.RefreshTaskListPartitionConfig(context.Background(), testMatchingRefreshTaskListPartitionConfigRequest()) + }, + mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { + p.EXPECT().FromTaskList(_testTaskList).Return("peer0", nil) + c.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) + }, + want: nil, + wantError: true, + }, } for _, tt := range tests { tt := tt @@ -526,3 +594,27 @@ func testGetTaskListsByDomainRequest() *types.GetTaskListsByDomainRequest { Domain: _testDomain, } } + +func testMatchingUpdateTaskListPartitionConfigRequest() *types.MatchingUpdateTaskListPartitionConfigRequest { + return &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: _testDomainUUID, + TaskList: &types.TaskList{Name: _testTaskList}, + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 2, + }, + } +} + +func testMatchingRefreshTaskListPartitionConfigRequest() *types.MatchingRefreshTaskListPartitionConfigRequest { + return &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: _testDomainUUID, + TaskList: &types.TaskList{Name: _testTaskList}, + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 2, + }, + } +} diff --git a/common/metrics/defs.go b/common/metrics/defs.go index b3eca90e706..e7fede89a44 100644 --- a/common/metrics/defs.go +++ b/common/metrics/defs.go @@ -1331,6 +1331,10 @@ const ( MatchingListTaskListPartitionsScope // MatchingGetTaskListsByDomainScope tracks GetTaskListsByDomain API calls received by service MatchingGetTaskListsByDomainScope + // MatchingUpdateTaskListPartitionConfigScope tracks UpdateTaskListPartitionConfig API calls received by service + MatchingUpdateTaskListPartitionConfigScope + // MatchingRefreshTaskListPartitionConfigScope tracks RefreshTaskListPartitionConfig API calls received by service + MatchingRefreshTaskListPartitionConfigScope NumMatchingScopes ) @@ -1976,18 +1980,20 @@ var ScopeDefs = map[ServiceIdx]map[int]scopeDefinition{ }, // Matching Scope Names Matching: { - MatchingPollForDecisionTaskScope: {operation: "PollForDecisionTask"}, - MatchingPollForActivityTaskScope: {operation: "PollForActivityTask"}, - MatchingAddActivityTaskScope: {operation: "AddActivityTask"}, - MatchingAddDecisionTaskScope: {operation: "AddDecisionTask"}, - MatchingAddTaskScope: {operation: "AddTask"}, - MatchingTaskListMgrScope: {operation: "TaskListMgr"}, - MatchingQueryWorkflowScope: {operation: "QueryWorkflow"}, - MatchingRespondQueryTaskCompletedScope: {operation: "RespondQueryTaskCompleted"}, - MatchingCancelOutstandingPollScope: {operation: "CancelOutstandingPoll"}, - MatchingDescribeTaskListScope: {operation: "DescribeTaskList"}, - MatchingListTaskListPartitionsScope: {operation: "ListTaskListPartitions"}, - MatchingGetTaskListsByDomainScope: {operation: "GetTaskListsByDomain"}, + MatchingPollForDecisionTaskScope: {operation: "PollForDecisionTask"}, + MatchingPollForActivityTaskScope: {operation: "PollForActivityTask"}, + MatchingAddActivityTaskScope: {operation: "AddActivityTask"}, + MatchingAddDecisionTaskScope: {operation: "AddDecisionTask"}, + MatchingAddTaskScope: {operation: "AddTask"}, + MatchingTaskListMgrScope: {operation: "TaskListMgr"}, + MatchingQueryWorkflowScope: {operation: "QueryWorkflow"}, + MatchingRespondQueryTaskCompletedScope: {operation: "RespondQueryTaskCompleted"}, + MatchingCancelOutstandingPollScope: {operation: "CancelOutstandingPoll"}, + MatchingDescribeTaskListScope: {operation: "DescribeTaskList"}, + MatchingListTaskListPartitionsScope: {operation: "ListTaskListPartitions"}, + MatchingGetTaskListsByDomainScope: {operation: "GetTaskListsByDomain"}, + MatchingUpdateTaskListPartitionConfigScope: {operation: "UpdateTaskListPartitionConfig"}, + MatchingRefreshTaskListPartitionConfigScope: {operation: "RefreshTaskListPartitionConfig"}, }, // Worker Scope Names Worker: { @@ -2575,6 +2581,7 @@ const ( IsolationTaskMatchPerTaskListCounter PollerPerTaskListCounter PollerInvalidIsolationGroupCounter + TaskListPartitionUpdateFailedCounter TaskListManagersGauge TaskLagPerTaskListGauge TaskBacklogPerTaskListGauge @@ -3257,6 +3264,7 @@ var MetricDefs = map[ServiceIdx]map[int]metricDefinition{ IsolationTaskMatchPerTaskListCounter: {metricName: "isolation_task_matches_per_tl", metricType: Counter}, PollerPerTaskListCounter: {metricName: "poller_count_per_tl", metricRollupName: "poller_count"}, PollerInvalidIsolationGroupCounter: {metricName: "poller_invalid_isolation_group_per_tl", metricType: Counter}, + TaskListPartitionUpdateFailedCounter: {metricName: "tasklist_partition_update_failed_per_tl", metricType: Counter}, TaskListManagersGauge: {metricName: "tasklist_managers", metricType: Gauge}, TaskLagPerTaskListGauge: {metricName: "task_lag_per_tl", metricType: Gauge}, TaskBacklogPerTaskListGauge: {metricName: "task_backlog_per_tl", metricType: Gauge}, diff --git a/common/types/mapper/proto/matching_test.go b/common/types/mapper/proto/matching_test.go index 92f68547fd9..63bb9fcc155 100644 --- a/common/types/mapper/proto/matching_test.go +++ b/common/types/mapper/proto/matching_test.go @@ -148,3 +148,27 @@ func TestMatchingGetTaskListsByDomainResponse(t *testing.T) { assert.Equal(t, item, ToMatchingGetTaskListsByDomainResponse(FromMatchingGetTaskListsByDomainResponse(item))) } } + +func TestMatchingUpdateTaskListPartitionConfigRequest(t *testing.T) { + for _, item := range []*types.MatchingUpdateTaskListPartitionConfigRequest{nil, {}, &testdata.MatchingUpdateTaskListPartitionConfigRequest} { + assert.Equal(t, item, ToMatchingUpdateTaskListPartitionConfigRequest(FromMatchingUpdateTaskListPartitionConfigRequest(item))) + } +} + +func TestMatchingUpdateTaskListPartitionConfigResponse(t *testing.T) { + for _, item := range []*types.MatchingUpdateTaskListPartitionConfigResponse{nil, {}} { + assert.Equal(t, item, ToMatchingUpdateTaskListPartitionConfigResponse(FromMatchingUpdateTaskListPartitionConfigResponse(item))) + } +} + +func TestMatchingRefreshTaskListPartitionConfigRequest(t *testing.T) { + for _, item := range []*types.MatchingRefreshTaskListPartitionConfigRequest{nil, {}, &testdata.MatchingRefreshTaskListPartitionConfigRequest} { + assert.Equal(t, item, ToMatchingRefreshTaskListPartitionConfigRequest(FromMatchingRefreshTaskListPartitionConfigRequest(item))) + } +} + +func TestMatchingRefreshTaskListPartitionConfigResponse(t *testing.T) { + for _, item := range []*types.MatchingRefreshTaskListPartitionConfigResponse{nil, {}} { + assert.Equal(t, item, ToMatchingRefreshTaskListPartitionConfigResponse(FromMatchingRefreshTaskListPartitionConfigResponse(item))) + } +} diff --git a/common/types/matching.go b/common/types/matching.go index da5a9ac41a6..d82892a7db6 100644 --- a/common/types/matching.go +++ b/common/types/matching.go @@ -654,6 +654,13 @@ type MatchingUpdateTaskListPartitionConfigRequest struct { PartitionConfig *TaskListPartitionConfig } +func (v *MatchingUpdateTaskListPartitionConfigRequest) GetTaskListType() (o TaskListType) { + if v != nil && v.TaskListType != nil { + return *v.TaskListType + } + return +} + type MatchingUpdateTaskListPartitionConfigResponse struct{} type MatchingRefreshTaskListPartitionConfigRequest struct { @@ -663,4 +670,11 @@ type MatchingRefreshTaskListPartitionConfigRequest struct { PartitionConfig *TaskListPartitionConfig } +func (v *MatchingRefreshTaskListPartitionConfigRequest) GetTaskListType() (o TaskListType) { + if v != nil && v.TaskListType != nil { + return *v.TaskListType + } + return +} + type MatchingRefreshTaskListPartitionConfigResponse struct{} diff --git a/common/types/matching_test.go b/common/types/matching_test.go index 9f2c3db2a97..f4711ad3768 100644 --- a/common/types/matching_test.go +++ b/common/types/matching_test.go @@ -1721,3 +1721,63 @@ func TestTaskSource_MarshalText(t *testing.T) { }) } } + +func TestMatchingUpdateTaskListPartitionConfigRequest_GetTaskListType(t *testing.T) { + tests := []struct { + name string + req *MatchingUpdateTaskListPartitionConfigRequest + want TaskListType + }{ + { + name: "nil request", + req: nil, + want: TaskListTypeDecision, + }, + { + name: "empty request", + req: &MatchingUpdateTaskListPartitionConfigRequest{}, + want: TaskListTypeDecision, + }, + { + name: "non empty request", + req: &MatchingUpdateTaskListPartitionConfigRequest{TaskListType: TaskListTypeActivity.Ptr()}, + want: TaskListTypeActivity, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.req.GetTaskListType() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMatchingRefreshTaskListPartitionConfigRequest_GetTaskListType(t *testing.T) { + tests := []struct { + name string + req *MatchingRefreshTaskListPartitionConfigRequest + want TaskListType + }{ + { + name: "nil request", + req: nil, + want: TaskListTypeDecision, + }, + { + name: "empty request", + req: &MatchingRefreshTaskListPartitionConfigRequest{}, + want: TaskListTypeDecision, + }, + { + name: "non empty request", + req: &MatchingRefreshTaskListPartitionConfigRequest{TaskListType: TaskListTypeActivity.Ptr()}, + want: TaskListTypeActivity, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.req.GetTaskListType() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/common/types/testdata/service_matching.go b/common/types/testdata/service_matching.go index db49bd79a7c..d721ae87195 100644 --- a/common/types/testdata/service_matching.go +++ b/common/types/testdata/service_matching.go @@ -174,4 +174,18 @@ var ( WorkflowType: &WorkflowType, WorkflowDomain: DomainName, } + + MatchingUpdateTaskListPartitionConfigRequest = types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: DomainID, + TaskList: &TaskList, + TaskListType: &TaskListType, + PartitionConfig: &TaskListPartitionConfig, + } + + MatchingRefreshTaskListPartitionConfigRequest = types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: DomainID, + TaskList: &TaskList, + TaskListType: &TaskListType, + PartitionConfig: &TaskListPartitionConfig, + } ) diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go index caab0347452..94e9080c4ce 100644 --- a/service/matching/handler/engine.go +++ b/service/matching/handler/engine.go @@ -1011,6 +1011,85 @@ func (e *matchingEngineImpl) GetTaskListsByDomain( return e.getTaskListByDomainLocked(domainID), nil } +func (e *matchingEngineImpl) UpdateTaskListPartitionConfig( + hCtx *handlerContext, + request *types.MatchingUpdateTaskListPartitionConfigRequest, +) (*types.MatchingUpdateTaskListPartitionConfigResponse, error) { + domainID := request.DomainUUID + taskListName := request.TaskList.GetName() + taskListKind := request.TaskList.GetKind() + taskListType := persistence.TaskListTypeDecision + if request.GetTaskListType() == types.TaskListTypeActivity { + taskListType = persistence.TaskListTypeActivity + } + if taskListKind != types.TaskListKindNormal { + return nil, &types.BadRequestError{Message: "Only normal tasklist's partition config can be updated."} + } + if request.PartitionConfig == nil { + return nil, &types.BadRequestError{Message: "Task list partition config is not set in the request."} + } + if request.PartitionConfig.NumWritePartitions > request.PartitionConfig.NumReadPartitions { + return nil, &types.BadRequestError{Message: "The number of write partitions cannot be larger than the number of read partitions."} + } + if request.PartitionConfig.NumWritePartitions <= 0 { + return nil, &types.BadRequestError{Message: "The number of partitions must be larger than 0."} + } + taskListID, err := tasklist.NewIdentifier(domainID, taskListName, taskListType) + if err != nil { + return nil, err + } + if !taskListID.IsRoot() { + return nil, &types.BadRequestError{Message: "Only root partition's partition config can be updated."} + } + tlMgr, err := e.getTaskListManager(taskListID, &taskListKind) + if err != nil { + return nil, err + } + err = tlMgr.UpdateTaskListPartitionConfig(hCtx.Context, request.PartitionConfig) + if err != nil { + return nil, err + } + return &types.MatchingUpdateTaskListPartitionConfigResponse{}, nil +} + +func (e *matchingEngineImpl) RefreshTaskListPartitionConfig( + hCtx *handlerContext, + request *types.MatchingRefreshTaskListPartitionConfigRequest, +) (*types.MatchingRefreshTaskListPartitionConfigResponse, error) { + domainID := request.DomainUUID + taskListName := request.TaskList.GetName() + taskListKind := request.TaskList.GetKind() + taskListType := persistence.TaskListTypeDecision + if request.GetTaskListType() == types.TaskListTypeActivity { + taskListType = persistence.TaskListTypeActivity + } + if taskListKind != types.TaskListKindNormal { + return nil, &types.BadRequestError{Message: "Only normal tasklist's partition config can be updated."} + } + if request.PartitionConfig != nil && request.PartitionConfig.NumWritePartitions > request.PartitionConfig.NumReadPartitions { + return nil, &types.BadRequestError{Message: "The number of write partitions cannot be larger than the number of read partitions."} + } + if request.PartitionConfig != nil && request.PartitionConfig.NumWritePartitions <= 0 { + return nil, &types.BadRequestError{Message: "The number of partitions must be larger than 0."} + } + taskListID, err := tasklist.NewIdentifier(domainID, taskListName, taskListType) + if err != nil { + return nil, err + } + if taskListID.IsRoot() && request.PartitionConfig != nil { + return nil, &types.BadRequestError{Message: "PartitionConfig must be nil for root partition."} + } + tlMgr, err := e.getTaskListManager(taskListID, &taskListKind) + if err != nil { + return nil, err + } + err = tlMgr.RefreshTaskListPartitionConfig(hCtx.Context, request.PartitionConfig) + if err != nil { + return nil, err + } + return &types.MatchingRefreshTaskListPartitionConfigResponse{}, nil +} + func (e *matchingEngineImpl) getHostInfo(partitionKey string) (string, error) { host, err := e.membershipResolver.Lookup(service.Matching, partitionKey) if err != nil { diff --git a/service/matching/handler/engine_test.go b/service/matching/handler/engine_test.go index 6e0437671d3..0f3baa5647e 100644 --- a/service/matching/handler/engine_test.go +++ b/service/matching/handler/engine_test.go @@ -725,3 +725,409 @@ func TestShutDownTasklistsNotOwned(t *testing.T) { assert.NoError(t, err) } + +func TestUpdateTaskListPartitionConfig(t *testing.T) { + testCases := []struct { + name string + req *types.MatchingUpdateTaskListPartitionConfigRequest + hCtx *handlerContext + mockSetup func(*tasklist.MockManager) + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }).Return(nil) + }, + expectError: false, + }, + { + name: "tasklist manager error", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }).Return(errors.New("tasklist manager error")) + }, + expectError: true, + expectedError: "tasklist manager error", + }, + { + name: "non root partition error", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/test-tasklist/1", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "Only root partition's partition config can be updated.", + }, + { + name: "invalid tasklist name", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "invalid partitioned task list name /__cadence_sys/test-tasklist", + }, + { + name: "invalid partition config", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 3, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "The number of write partitions cannot be larger than the number of read partitions.", + }, + { + name: "invalid partition config - 2", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: -1, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "The number of partitions must be larger than 0.", + }, + { + name: "nil partition config", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "Task list partition config is not set in the request.", + }, + { + name: "invalid tasklist kind", + req: &types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + Kind: types.TaskListKindSticky.Ptr(), + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "Only normal tasklist's partition config can be updated.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockManager := tasklist.NewMockManager(mockCtrl) + tc.mockSetup(mockManager) + tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 1) + require.NoError(t, err) + engine := &matchingEngineImpl{ + taskLists: map[tasklist.Identifier]tasklist.Manager{ + *tasklistID: mockManager, + }, + timeSource: clock.NewRealTimeSource(), + } + _, err = engine.UpdateTaskListPartitionConfig(tc.hCtx, tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRefreshTaskListPartitionConfig(t *testing.T) { + testCases := []struct { + name string + req *types.MatchingRefreshTaskListPartitionConfigRequest + hCtx *handlerContext + mockSetup func(*tasklist.MockManager) + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/test-tasklist/1", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }).Return(nil) + }, + expectError: false, + }, + { + name: "tasklist manager error", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/test-tasklist/1", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + mockManager.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }).Return(errors.New("tasklist manager error")) + }, + expectError: true, + expectedError: "tasklist manager error", + }, + { + name: "invalid tasklist name", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "/__cadence_sys/test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "invalid partitioned task list name /__cadence_sys/test-tasklist", + }, + { + name: "invalid partition config", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 3, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "The number of write partitions cannot be larger than the number of read partitions.", + }, + { + name: "invalid partition config - 2", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: -1, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "The number of partitions must be larger than 0.", + }, + { + name: "invalid tasklist kind", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + Kind: types.TaskListKindSticky.Ptr(), + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "Only normal tasklist's partition config can be updated.", + }, + { + name: "invalid request for root partition", + req: &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{ + Name: "test-tasklist", + }, + TaskListType: types.TaskListTypeActivity.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + hCtx: &handlerContext{ + Context: context.Background(), + }, + mockSetup: func(mockManager *tasklist.MockManager) { + }, + expectError: true, + expectedError: "PartitionConfig must be nil for root partition.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockManager := tasklist.NewMockManager(mockCtrl) + tc.mockSetup(mockManager) + tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 1) + require.NoError(t, err) + tasklistID2, err := tasklist.NewIdentifier("test-domain-id", "/__cadence_sys/test-tasklist/1", 1) + require.NoError(t, err) + engine := &matchingEngineImpl{ + taskLists: map[tasklist.Identifier]tasklist.Manager{ + *tasklistID: mockManager, + *tasklistID2: mockManager, + }, + timeSource: clock.NewRealTimeSource(), + } + _, err = engine.RefreshTaskListPartitionConfig(tc.hCtx, tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/service/matching/handler/handler.go b/service/matching/handler/handler.go index d3728ad6a84..af86351daf0 100644 --- a/service/matching/handler/handler.go +++ b/service/matching/handler/handler.go @@ -418,15 +418,51 @@ func (h *handlerImpl) GetTaskListsByDomain( func (h *handlerImpl) UpdateTaskListPartitionConfig( ctx context.Context, request *types.MatchingUpdateTaskListPartitionConfigRequest, -) (*types.MatchingUpdateTaskListPartitionConfigResponse, error) { - return nil, &types.BadRequestError{} +) (resp *types.MatchingUpdateTaskListPartitionConfigResponse, retError error) { + defer func() { log.CapturePanic(recover(), h.logger, &retError) }() + + domainName := h.domainName(request.DomainUUID) + hCtx := h.newHandlerContext( + ctx, + domainName, + request.TaskList, + metrics.MatchingUpdateTaskListPartitionConfigScope, + ) + + sw := hCtx.startProfiling(&h.startWG) + defer sw.Stop() + + if ok := h.userRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { + return nil, hCtx.handleErr(errMatchingHostThrottle) + } + + response, err := h.engine.UpdateTaskListPartitionConfig(hCtx, request) + return response, hCtx.handleErr(err) } func (h *handlerImpl) RefreshTaskListPartitionConfig( ctx context.Context, request *types.MatchingRefreshTaskListPartitionConfigRequest, -) (*types.MatchingRefreshTaskListPartitionConfigResponse, error) { - return nil, &types.BadRequestError{} +) (resp *types.MatchingRefreshTaskListPartitionConfigResponse, retError error) { + defer func() { log.CapturePanic(recover(), h.logger, &retError) }() + + domainName := h.domainName(request.DomainUUID) + hCtx := h.newHandlerContext( + ctx, + domainName, + request.TaskList, + metrics.MatchingRefreshTaskListPartitionConfigScope, + ) + + sw := hCtx.startProfiling(&h.startWG) + defer sw.Stop() + + if ok := h.userRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { + return nil, hCtx.handleErr(errMatchingHostThrottle) + } + + response, err := h.engine.RefreshTaskListPartitionConfig(hCtx, request) + return response, hCtx.handleErr(err) } func (h *handlerImpl) domainName(id string) string { diff --git a/service/matching/handler/handler_test.go b/service/matching/handler/handler_test.go index 642eecd0383..a4c900854d9 100644 --- a/service/matching/handler/handler_test.go +++ b/service/matching/handler/handler_test.go @@ -788,3 +788,116 @@ func (s *handlerSuite) TestDomainName() { } } +func (s *handlerSuite) TestRefreshTaskListPartitionConfig() { + request := types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{Name: "test-task-list"}, + } + + testCases := []struct { + name string + setupMocks func() + want *types.MatchingRefreshTaskListPartitionConfigResponse + err error + }{ + { + name: "Success case", + setupMocks: func() { + s.mockLimiter.EXPECT().Allow().Return(true).Times(1) + s.mockEngine.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &request). + Return(&types.MatchingRefreshTaskListPartitionConfigResponse{}, nil).Times(1) + }, + want: &types.MatchingRefreshTaskListPartitionConfigResponse{}, + }, + { + name: "Error case - rate limiter not allowed", + setupMocks: func() { + s.mockLimiter.EXPECT().Allow().Return(false).Times(1) + }, + err: &types.ServiceBusyError{Message: "Matching host rps exceeded"}, + }, + { + name: "Error case - RefreshTaskListPartitionConfig failed", + setupMocks: func() { + s.mockLimiter.EXPECT().Allow().Return(true).Times(1) + s.mockEngine.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &request). + Return(nil, errors.New("refresh-tasklist-error")).Times(1) + }, + err: &types.InternalServiceError{Message: "refresh-tasklist-error"}, + }, + } + + for _, tc := range testCases { + s.T().Run(tc.name, func(t *testing.T) { + tc.setupMocks() + s.mockDomainCache.EXPECT().GetDomainName(request.DomainUUID).Return(s.testDomain, nil).Times(1) + + resp, err := s.handler.RefreshTaskListPartitionConfig(context.Background(), &request) + + if tc.err != nil { + s.Error(err) + s.Equal(tc.err, err) + } else { + s.NoError(err) + s.Equal(tc.want, resp) + } + }) + } +} + +func (s *handlerSuite) TestUpdateTaskListPartitionConfig() { + request := types.MatchingUpdateTaskListPartitionConfigRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{Name: "test-task-list"}, + } + + testCases := []struct { + name string + setupMocks func() + want *types.MatchingUpdateTaskListPartitionConfigResponse + err error + }{ + { + name: "Success case", + setupMocks: func() { + s.mockLimiter.EXPECT().Allow().Return(true).Times(1) + s.mockEngine.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), &request). + Return(&types.MatchingUpdateTaskListPartitionConfigResponse{}, nil).Times(1) + }, + want: &types.MatchingUpdateTaskListPartitionConfigResponse{}, + }, + { + name: "Error case - rate limiter not allowed", + setupMocks: func() { + s.mockLimiter.EXPECT().Allow().Return(false).Times(1) + }, + err: &types.ServiceBusyError{Message: "Matching host rps exceeded"}, + }, + { + name: "Error case - UpdateTaskListPartitionConfig failed", + setupMocks: func() { + s.mockLimiter.EXPECT().Allow().Return(true).Times(1) + s.mockEngine.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), &request). + Return(nil, errors.New("update-tasklist-error")).Times(1) + }, + err: &types.InternalServiceError{Message: "update-tasklist-error"}, + }, + } + + for _, tc := range testCases { + s.T().Run(tc.name, func(t *testing.T) { + tc.setupMocks() + s.mockDomainCache.EXPECT().GetDomainName(request.DomainUUID).Return(s.testDomain, nil).Times(1) + + resp, err := s.handler.UpdateTaskListPartitionConfig(context.Background(), &request) + + if tc.err != nil { + s.Error(err) + s.Equal(tc.err, err) + } else { + s.NoError(err) + s.Equal(tc.want, resp) + } + }) + } +} diff --git a/service/matching/handler/interfaces.go b/service/matching/handler/interfaces.go index ed612dfee13..ec1846aa3a2 100644 --- a/service/matching/handler/interfaces.go +++ b/service/matching/handler/interfaces.go @@ -46,6 +46,8 @@ type ( DescribeTaskList(hCtx *handlerContext, request *types.MatchingDescribeTaskListRequest) (*types.DescribeTaskListResponse, error) ListTaskListPartitions(hCtx *handlerContext, request *types.MatchingListTaskListPartitionsRequest) (*types.ListTaskListPartitionsResponse, error) GetTaskListsByDomain(hCtx *handlerContext, request *types.GetTaskListsByDomainRequest) (*types.GetTaskListsByDomainResponse, error) + UpdateTaskListPartitionConfig(hCtx *handlerContext, request *types.MatchingUpdateTaskListPartitionConfigRequest) (*types.MatchingUpdateTaskListPartitionConfigResponse, error) + RefreshTaskListPartitionConfig(hCtx *handlerContext, request *types.MatchingRefreshTaskListPartitionConfigRequest) (*types.MatchingRefreshTaskListPartitionConfigResponse, error) } // Handler interface for matching service diff --git a/service/matching/handler/interfaces_mock.go b/service/matching/handler/interfaces_mock.go index b08fd2c3434..4d6761445f0 100644 --- a/service/matching/handler/interfaces_mock.go +++ b/service/matching/handler/interfaces_mock.go @@ -192,6 +192,21 @@ func (mr *MockEngineMockRecorder) QueryWorkflow(hCtx, request interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryWorkflow", reflect.TypeOf((*MockEngine)(nil).QueryWorkflow), hCtx, request) } +// RefreshTaskListPartitionConfig mocks base method. +func (m *MockEngine) RefreshTaskListPartitionConfig(hCtx *handlerContext, request *types.MatchingRefreshTaskListPartitionConfigRequest) (*types.MatchingRefreshTaskListPartitionConfigResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RefreshTaskListPartitionConfig", hCtx, request) + ret0, _ := ret[0].(*types.MatchingRefreshTaskListPartitionConfigResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RefreshTaskListPartitionConfig indicates an expected call of RefreshTaskListPartitionConfig. +func (mr *MockEngineMockRecorder) RefreshTaskListPartitionConfig(hCtx, request interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTaskListPartitionConfig", reflect.TypeOf((*MockEngine)(nil).RefreshTaskListPartitionConfig), hCtx, request) +} + // RespondQueryTaskCompleted mocks base method. func (m *MockEngine) RespondQueryTaskCompleted(hCtx *handlerContext, request *types.MatchingRespondQueryTaskCompletedRequest) error { m.ctrl.T.Helper() @@ -230,6 +245,21 @@ func (mr *MockEngineMockRecorder) Stop() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockEngine)(nil).Stop)) } +// UpdateTaskListPartitionConfig mocks base method. +func (m *MockEngine) UpdateTaskListPartitionConfig(hCtx *handlerContext, request *types.MatchingUpdateTaskListPartitionConfigRequest) (*types.MatchingUpdateTaskListPartitionConfigResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateTaskListPartitionConfig", hCtx, request) + ret0, _ := ret[0].(*types.MatchingUpdateTaskListPartitionConfigResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateTaskListPartitionConfig indicates an expected call of UpdateTaskListPartitionConfig. +func (mr *MockEngineMockRecorder) UpdateTaskListPartitionConfig(hCtx, request interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskListPartitionConfig", reflect.TypeOf((*MockEngine)(nil).UpdateTaskListPartitionConfig), hCtx, request) +} + // MockHandler is a mock of Handler interface. type MockHandler struct { ctrl *gomock.Controller diff --git a/service/matching/tasklist/db.go b/service/matching/tasklist/db.go index a29330e515a..446407b0fe8 100644 --- a/service/matching/tasklist/db.go +++ b/service/matching/tasklist/db.go @@ -136,6 +136,28 @@ func (db *taskListDB) UpdateState(ackLevel int64) error { return nil } +func (db *taskListDB) UpdateTaskListPartitionConfig(partitionConfig *persistence.TaskListPartitionConfig) error { + db.Lock() + defer db.Unlock() + _, err := db.store.UpdateTaskList(context.Background(), &persistence.UpdateTaskListRequest{ + TaskListInfo: &persistence.TaskListInfo{ + DomainID: db.domainID, + Name: db.taskListName, + TaskType: db.taskType, + AckLevel: db.ackLevel, + RangeID: db.rangeID, + Kind: db.taskListKind, + AdaptivePartitionConfig: partitionConfig, + }, + DomainName: db.domainName, + }) + if err != nil { + return err + } + db.partitionConfig = partitionConfig + return nil +} + // CreateTasks creates a batch of given tasks for this task list func (db *taskListDB) CreateTasks(tasks []*persistence.CreateTaskInfo) (*persistence.CreateTasksResponse, error) { db.Lock() diff --git a/service/matching/tasklist/identifier.go b/service/matching/tasklist/identifier.go index c1e866ab0fc..4f62c40ba10 100644 --- a/service/matching/tasklist/identifier.go +++ b/service/matching/tasklist/identifier.go @@ -99,10 +99,10 @@ func (tn *qualifiedTaskListName) Parent(degree int) string { return "" } pid := (tn.partition+degree-1)/degree - 1 - return tn.mkName(pid) + return tn.GetPartition(pid) } -func (tn *qualifiedTaskListName) mkName(partition int) string { +func (tn *qualifiedTaskListName) GetPartition(partition int) string { if partition == 0 { return tn.baseName } diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go index 1b2596bef01..12e8218a69b 100644 --- a/service/matching/tasklist/interfaces.go +++ b/service/matching/tasklist/interfaces.go @@ -60,6 +60,8 @@ type ( GetTaskListKind() types.TaskListKind TaskListID() *Identifier TaskListPartitionConfig() *types.TaskListPartitionConfig + UpdateTaskListPartitionConfig(context.Context, *types.TaskListPartitionConfig) error + RefreshTaskListPartitionConfig(context.Context, *types.TaskListPartitionConfig) error } TaskMatcher interface { diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go index 9ef04d8edd3..0892d65d3d9 100644 --- a/service/matching/tasklist/interfaces_mock.go +++ b/service/matching/tasklist/interfaces_mock.go @@ -186,6 +186,20 @@ func (mr *MockManagerMockRecorder) HasPollerAfter(accessTime interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPollerAfter", reflect.TypeOf((*MockManager)(nil).HasPollerAfter), accessTime) } +// RefreshTaskListPartitionConfig mocks base method. +func (m *MockManager) RefreshTaskListPartitionConfig(arg0 context.Context, arg1 *types.TaskListPartitionConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RefreshTaskListPartitionConfig", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RefreshTaskListPartitionConfig indicates an expected call of RefreshTaskListPartitionConfig. +func (mr *MockManagerMockRecorder) RefreshTaskListPartitionConfig(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTaskListPartitionConfig", reflect.TypeOf((*MockManager)(nil).RefreshTaskListPartitionConfig), arg0, arg1) +} + // Start mocks base method. func (m *MockManager) Start() error { m.ctrl.T.Helper() @@ -254,6 +268,20 @@ func (mr *MockManagerMockRecorder) TaskListPartitionConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TaskListPartitionConfig", reflect.TypeOf((*MockManager)(nil).TaskListPartitionConfig)) } +// UpdateTaskListPartitionConfig mocks base method. +func (m *MockManager) UpdateTaskListPartitionConfig(arg0 context.Context, arg1 *types.TaskListPartitionConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateTaskListPartitionConfig", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateTaskListPartitionConfig indicates an expected call of UpdateTaskListPartitionConfig. +func (mr *MockManagerMockRecorder) UpdateTaskListPartitionConfig(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskListPartitionConfig", reflect.TypeOf((*MockManager)(nil).UpdateTaskListPartitionConfig), arg0, arg1) +} + // MockTaskMatcher is a mock of TaskMatcher interface. type MockTaskMatcher struct { ctrl *gomock.Controller diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go index 3531fa36658..4e5576f97ba 100644 --- a/service/matching/tasklist/task_list_manager.go +++ b/service/matching/tasklist/task_list_manager.go @@ -33,6 +33,8 @@ import ( "sync/atomic" "time" + "golang.org/x/sync/errgroup" + "github.com/uber/cadence/client/matching" "github.com/uber/cadence/common" "github.com/uber/cadence/common/backoff" @@ -105,6 +107,7 @@ type ( logger log.Logger scope metrics.Scope timeSource clock.TimeSource + matchingClient matching.Client domainName string // pollerHistory stores poller which poll from this tasklist in last few minutes pollerHistory poller.History @@ -119,6 +122,7 @@ type ( startWG sync.WaitGroup // ensures that background processes do not start until setup is ready stopped int32 closeCallback func(Manager) + throttleRetry *backoff.ThrottleRetry qpsTracker stats.QPSTracker @@ -178,11 +182,16 @@ func NewManager( taskAckManager: messaging.NewAckManager(logger), taskGC: newTaskGC(db, taskListConfig), config: taskListConfig, + matchingClient: matchingClient, outstandingPollsMap: make(map[string]outstandingPollerInfo), domainName: domainName, scope: scope, timeSource: timeSource, closeCallback: closeCallback, + throttleRetry: backoff.NewThrottleRetry( + backoff.WithRetryPolicy(persistenceOperationRetryPolicy), + backoff.WithRetryableError(persistence.IsTransientError), + ), } tlMgr.pollerHistory = poller.NewPollerHistory(func() { @@ -216,20 +225,29 @@ func NewManager( func (c *taskListManagerImpl) Start() error { defer c.startWG.Done() + if !c.taskListID.IsRoot() && c.taskListKind == types.TaskListKindNormal { + var info *persistence.TaskListInfo + err := c.throttleRetry.Do(context.Background(), func() error { + var err error + info, err = c.db.GetTaskListInfo(c.taskListID.GetRoot()) + return err + }) + if err != nil { + var e *types.EntityNotExistsError + if !errors.As(err, &e) { + c.Stop() + return err + } + } else { + c.partitionConfig = info.AdaptivePartitionConfig.ToInternalType() + } + } if err := c.taskWriter.Start(); err != nil { c.Stop() return err } - c.loadTaskListPartitionConfig() - if c.taskListID.IsRoot() && c.taskListKind != types.TaskListKindSticky { + if c.taskListID.IsRoot() && c.taskListKind == types.TaskListKindNormal { c.partitionConfig = c.db.PartitionConfig().ToInternalType() - if c.partitionConfig == nil { - c.partitionConfig = &types.TaskListPartitionConfig{ - Version: 0, - NumReadPartitions: 1, - NumWritePartitions: 1, - } - } c.logger.Info("get task list partition config from db", tag.Dynamic("root-partition", c.taskListID.GetRoot()), tag.Dynamic("config", c.partitionConfig)) } c.liveness.Start() @@ -268,49 +286,112 @@ func (c *taskListManagerImpl) handleErr(err error) error { return err } -func (c *taskListManagerImpl) loadTaskListPartitionConfig() { - if c.taskListID.IsRoot() { - return - } +func (c *taskListManagerImpl) TaskListPartitionConfig() *types.TaskListPartitionConfig { c.partitionConfigLock.RLock() if c.partitionConfig != nil { + config := *c.partitionConfig c.partitionConfigLock.RUnlock() - return + c.logger.Debug("get task list partition config from db", tag.Dynamic("root-partition", c.taskListID.GetRoot()), tag.Dynamic("config", config)) + return &config } c.partitionConfigLock.RUnlock() + return nil +} - c.partitionConfigLock.Lock() - if c.partitionConfig != nil { +func isTaskListPartitionConfigEqual(a types.TaskListPartitionConfig, b types.TaskListPartitionConfig) bool { + return a.NumReadPartitions == b.NumReadPartitions && a.NumWritePartitions == b.NumWritePartitions +} + +func (c *taskListManagerImpl) RefreshTaskListPartitionConfig(ctx context.Context, config *types.TaskListPartitionConfig) error { + c.startWG.Wait() + if config == nil { + // if config is nil, we'll reload it from database + var info *persistence.TaskListInfo + err := c.throttleRetry.Do(ctx, func() error { + var err error + info, err = c.db.GetTaskListInfo(c.taskListID.GetRoot()) + return err + }) + if err != nil { + return err + } + config = info.AdaptivePartitionConfig.ToInternalType() + c.partitionConfigLock.Lock() + c.partitionConfig = config c.partitionConfigLock.Unlock() - return + return nil } + c.partitionConfigLock.Lock() defer c.partitionConfigLock.Unlock() - info, err := c.db.GetTaskListInfo(c.taskListID.GetRoot()) - if err != nil { - // Given current set up, it's possible that the root partition is created after non-root partition - // In this case, we don't fail the start for now, but set the config to nil. - // We'll check if the field is nil, if it is, we'll reload it from database on demand. - c.logger.Error("failed to get tasklist info of root partition", tag.Dynamic("root-partition", c.taskListID.GetRoot()), tag.Error(err)) - return + if c.partitionConfig == nil || c.partitionConfig.Version < config.Version { + c.partitionConfig = config } - c.partitionConfig = info.AdaptivePartitionConfig.ToInternalType() - if c.partitionConfig == nil { - c.partitionConfig = &types.TaskListPartitionConfig{ - Version: 0, - NumReadPartitions: 1, - NumWritePartitions: 1, + return nil +} + +func (c *taskListManagerImpl) UpdateTaskListPartitionConfig(ctx context.Context, config *types.TaskListPartitionConfig) error { + c.startWG.Wait() + var version int64 + originalNumReadPartitions := 1 + c.partitionConfigLock.Lock() + if c.partitionConfig != nil { + originalNumReadPartitions = int(c.partitionConfig.NumReadPartitions) + if isTaskListPartitionConfigEqual(*c.partitionConfig, *config) { + c.partitionConfigLock.Unlock() + return nil } + version = c.partitionConfig.Version + } + err := c.throttleRetry.Do(ctx, func() error { + return c.db.UpdateTaskListPartitionConfig(&persistence.TaskListPartitionConfig{ + Version: version + 1, + NumReadPartitions: int(config.NumReadPartitions), + NumWritePartitions: int(config.NumWritePartitions), + }) + }) + if err != nil { + c.partitionConfigLock.Unlock() + // We're not sure whether the update was persisted or not, + // Stop the tasklist manager and let it be reloaded + c.scope.IncCounter(metrics.TaskListPartitionUpdateFailedCounter) + c.Stop() + return err } - c.logger.Info("get task list partition config from db", tag.Dynamic("root-partition", c.taskListID.GetRoot()), tag.Dynamic("config", c.partitionConfig)) + c.partitionConfig = c.db.PartitionConfig().ToInternalType() + currentConfig := *c.partitionConfig + c.partitionConfigLock.Unlock() + // push update notification to other existing partitions + c.notifyPartitionConfig(ctx, currentConfig, originalNumReadPartitions) + return nil } -func (c *taskListManagerImpl) TaskListPartitionConfig() *types.TaskListPartitionConfig { - c.partitionConfigLock.RLock() - defer c.partitionConfigLock.RUnlock() - if c.partitionConfig != nil { - c.logger.Debug("get task list partition config from db", tag.Dynamic("root-partition", c.taskListID.GetRoot()), tag.Dynamic("config", c.partitionConfig)) +func (c *taskListManagerImpl) notifyPartitionConfig(ctx context.Context, config types.TaskListPartitionConfig, count int) { + taskListType := types.TaskListTypeDecision.Ptr() + if c.taskListID.GetType() == persistence.TaskListTypeActivity { + taskListType = types.TaskListTypeActivity.Ptr() + } + g := &errgroup.Group{} + for i := 1; i < count; i++ { + taskListName := c.taskListID.GetPartition(i) + g.Go(func() (e error) { + defer func() { log.CapturePanic(recover(), c.logger, &e) }() + + _, e = c.matchingClient.RefreshTaskListPartitionConfig(ctx, &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: c.taskListID.GetDomainID(), + TaskList: &types.TaskList{Name: taskListName}, + TaskListType: taskListType, + PartitionConfig: &config, + }) + if e != nil { + c.logger.Error("failed to notify partition", tag.Error(e), tag.Dynamic("task-list-partition-name", taskListName)) + } + return e + }) + } + err := g.Wait() + if err != nil { + c.logger.Error("failed to notify all partitions", tag.Error(err)) } - return c.partitionConfig } // AddTask adds a task to the task list. This method will first attempt a synchronous @@ -323,7 +404,6 @@ func (c *taskListManagerImpl) AddTask(ctx context.Context, params AddTaskParams) c.Stop() return false, errShutdown } - c.loadTaskListPartitionConfig() if params.ForwardedFrom == "" { // request sent by history service c.liveness.MarkAlive() @@ -410,7 +490,6 @@ func (c *taskListManagerImpl) DispatchQueryTask( request *types.MatchingQueryWorkflowRequest, ) (*types.QueryWorkflowResponse, error) { c.startWG.Wait() - c.loadTaskListPartitionConfig() task := newInternalQueryTask(taskID, request) return c.matcher.OfferQuery(ctx, task) } @@ -428,7 +507,6 @@ func (c *taskListManagerImpl) GetTask( return nil, ErrNoTasks } c.liveness.MarkAlive() - c.loadTaskListPartitionConfig() // TODO: consider return early if QPS and backlog count are both 0, // since there is no task to be returned task, err := c.getTask(ctx, maxDispatchPerSecond) diff --git a/service/matching/tasklist/task_list_manager_test.go b/service/matching/tasklist/task_list_manager_test.go index 57163c836e3..a9512d6b3c1 100644 --- a/service/matching/tasklist/task_list_manager_test.go +++ b/service/matching/tasklist/task_list_manager_test.go @@ -33,7 +33,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" + "golang.org/x/sync/errgroup" + "github.com/uber/cadence/client/matching" "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/clock" @@ -50,6 +52,50 @@ import ( "github.com/uber/cadence/service/matching/poller" ) +type mockDeps struct { + mockDomainCache *cache.MockDomainCache + mockTaskManager *persistence.MockTaskManager + mockPartitioner *partition.MockPartitioner + mockMatchingClient *matching.MockClient + mockTimeSource clock.MockedTimeSource + dynamicClient dynamicconfig.Client +} + +func setupMocksForTaskListManager(t *testing.T, taskListID *Identifier, taskListKind types.TaskListKind) (*taskListManagerImpl, *mockDeps) { + ctrl := gomock.NewController(t) + dynamicClient := dynamicconfig.NewInMemoryClient() + logger := testlogger.New(t) + metricsClient := metrics.NewNoopMetricsClient() + clusterMetadata := cluster.GetTestClusterMetadata(true) + deps := &mockDeps{ + mockDomainCache: cache.NewMockDomainCache(ctrl), + mockTaskManager: persistence.NewMockTaskManager(ctrl), + mockPartitioner: partition.NewMockPartitioner(ctrl), + mockMatchingClient: matching.NewMockClient(ctrl), + mockTimeSource: clock.NewMockedTimeSource(), + dynamicClient: dynamicClient, + } + deps.mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("domainName", nil).Times(1) + config := config.NewConfig(dynamicconfig.NewCollection(dynamicClient, logger), "hostname", getIsolationgroupsHelper) + tlm, err := NewManager( + deps.mockDomainCache, + logger, + metricsClient, + deps.mockTaskManager, + clusterMetadata, + deps.mockPartitioner, + deps.mockMatchingClient, + func(Manager) {}, + taskListID, + &taskListKind, + config, + deps.mockTimeSource, + deps.mockTimeSource.Now(), + ) + require.NoError(t, err) + return tlm.(*taskListManagerImpl), deps +} + func defaultTestConfig() *config.Config { config := config.NewConfig(dynamicconfig.NewNopCollection(), "some random hostname", getIsolationgroupsHelper) config.LongPollExpirationInterval = dynamicconfig.GetDurationPropertyFnFilteredByTaskListInfo(100 * time.Millisecond) @@ -929,55 +975,418 @@ func getIsolationgroupsHelper() []string { return []string{"datacenterA", "datacenterB"} } -func TestLoadTaskListPartitionConfig(t *testing.T) { - ctrl := gomock.NewController(t) - mockPartitioner := partition.NewMockPartitioner(ctrl) - mockPartitioner.EXPECT().GetIsolationGroupByDomainID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() - mockDomainCache := cache.NewMockDomainCache(ctrl) - mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(cache.CreateDomainCacheEntry("domainName"), nil).AnyTimes() - mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("domainName", nil).AnyTimes() +func TestRefreshTaskListPartitionConfig(t *testing.T) { + testCases := []struct { + name string + req *types.TaskListPartitionConfig + originalConfig *types.TaskListPartitionConfig + setupMocks func(*mockDeps) + expectedConfig *types.TaskListPartitionConfig + expectError bool + expectedError string + }{ + { + name: "success - refresh from request", + req: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + setupMocks: func(m *mockDeps) {}, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + }, + { + name: "success - ignore older version", + req: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + originalConfig: &types.TaskListPartitionConfig{ + Version: 3, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + setupMocks: func(m *mockDeps) {}, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 3, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, + { + name: "success - refresh from database", + originalConfig: &types.TaskListPartitionConfig{ + Version: 3, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + setupMocks: func(deps *mockDeps) { + deps.mockTaskManager.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ + DomainID: "domain-id", + DomainName: "domainName", + TaskList: "tl", + TaskType: persistence.TaskListTypeDecision, + }).Return(&persistence.GetTaskListResponse{ + TaskListInfo: &persistence.TaskListInfo{ + AdaptivePartitionConfig: &persistence.TaskListPartitionConfig{ + Version: 4, + NumReadPartitions: 10, + NumWritePartitions: 10, + }, + }, + }, nil) + }, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 4, + NumReadPartitions: 10, + NumWritePartitions: 10, + }, + }, + { + name: "failed to refresh from database", + originalConfig: &types.TaskListPartitionConfig{ + Version: 3, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + setupMocks: func(deps *mockDeps) { + deps.mockTaskManager.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ + DomainID: "domain-id", + DomainName: "domainName", + TaskList: "tl", + TaskType: persistence.TaskListTypeDecision, + }).Return(nil, errors.New("some error")) + }, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 3, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + expectError: true, + expectedError: "some error", + }, + } - mockTm := persistence.NewMockTaskManager(ctrl) - mockTm.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ - DomainID: "domain", - DomainName: "domainName", - TaskList: "tasklist", - TaskType: persistence.TaskListTypeActivity, - }).Return(nil, errors.New("error")).Times(1) - mockTm.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ - DomainID: "domain", - DomainName: "domainName", - TaskList: "tasklist", - TaskType: persistence.TaskListTypeActivity, - }).Return(&persistence.GetTaskListResponse{ - TaskListInfo: &persistence.TaskListInfo{ - AdaptivePartitionConfig: nil, + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) + require.NoError(t, err) + tlm, deps := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + tc.setupMocks(deps) + tlm.partitionConfig = tc.originalConfig + tlm.startWG.Done() + + err = tlm.RefreshTaskListPartitionConfig(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + assert.Equal(t, tc.expectedConfig, tlm.TaskListPartitionConfig()) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedConfig, tlm.TaskListPartitionConfig()) + } + }) + } +} + +func TestUpdateTaskListPartitionConfig(t *testing.T) { + testCases := []struct { + name string + req *types.TaskListPartitionConfig + originalConfig *types.TaskListPartitionConfig + setupMocks func(*mockDeps) + expectedConfig *types.TaskListPartitionConfig + expectError bool + expectedError string + }{ + { + name: "success - no op", + req: &types.TaskListPartitionConfig{ + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + originalConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + setupMocks: func(m *mockDeps) {}, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + }, + { + name: "success - update", + req: &types.TaskListPartitionConfig{ + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + originalConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + setupMocks: func(deps *mockDeps) { + deps.mockTaskManager.EXPECT().UpdateTaskList(gomock.Any(), &persistence.UpdateTaskListRequest{ + DomainName: "domainName", + TaskListInfo: &persistence.TaskListInfo{ + DomainID: "domain-id", + Name: "tl", + AckLevel: 0, + RangeID: 0, + Kind: persistence.TaskListKindNormal, + AdaptivePartitionConfig: &persistence.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }, + }).Return(&persistence.UpdateTaskListResponse{}, nil) + deps.mockMatchingClient.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "domain-id", + TaskList: &types.TaskList{Name: "/__cadence_sys/tl/1"}, + TaskListType: types.TaskListTypeDecision.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }).Return(&types.MatchingRefreshTaskListPartitionConfigResponse{}, nil) + deps.mockMatchingClient.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "domain-id", + TaskList: &types.TaskList{Name: "/__cadence_sys/tl/2"}, + TaskListType: types.TaskListTypeDecision.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }).Return(&types.MatchingRefreshTaskListPartitionConfigResponse{}, nil) + }, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }, + { + name: "success - push failures are ignored", + req: &types.TaskListPartitionConfig{ + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + originalConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + setupMocks: func(deps *mockDeps) { + deps.mockTaskManager.EXPECT().UpdateTaskList(gomock.Any(), &persistence.UpdateTaskListRequest{ + DomainName: "domainName", + TaskListInfo: &persistence.TaskListInfo{ + DomainID: "domain-id", + Name: "tl", + AckLevel: 0, + RangeID: 0, + Kind: persistence.TaskListKindNormal, + AdaptivePartitionConfig: &persistence.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }, + }).Return(&persistence.UpdateTaskListResponse{}, nil) + deps.mockMatchingClient.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "domain-id", + TaskList: &types.TaskList{Name: "/__cadence_sys/tl/1"}, + TaskListType: types.TaskListTypeDecision.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }).Return(nil, errors.New("matching client error")) + deps.mockMatchingClient.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.MatchingRefreshTaskListPartitionConfigRequest{ + DomainUUID: "domain-id", + TaskList: &types.TaskList{Name: "/__cadence_sys/tl/2"}, + TaskListType: types.TaskListTypeDecision.Ptr(), + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }).Return(nil, errors.New("matching client error")) + }, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }, + { + name: "failed to update", + req: &types.TaskListPartitionConfig{ + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + originalConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + setupMocks: func(deps *mockDeps) { + deps.mockTaskManager.EXPECT().UpdateTaskList(gomock.Any(), &persistence.UpdateTaskListRequest{ + DomainName: "domainName", + TaskListInfo: &persistence.TaskListInfo{ + DomainID: "domain-id", + Name: "tl", + AckLevel: 0, + RangeID: 0, + Kind: persistence.TaskListKindNormal, + AdaptivePartitionConfig: &persistence.TaskListPartitionConfig{ + Version: 2, + NumReadPartitions: 3, + NumWritePartitions: 1, + }, + }, + }).Return(nil, errors.New("some error")) + }, + expectedConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + expectError: true, + expectedError: "some error", }, - }, nil).Times(1) + } - tlID, err := NewIdentifier("domain", "/__cadence_sys/tasklist/1", persistence.TaskListTypeActivity) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) + require.NoError(t, err) + tlm, deps := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + tc.setupMocks(deps) + tlm.partitionConfig = tc.originalConfig + tlm.startWG.Done() + + err = tlm.UpdateTaskListPartitionConfig(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + assert.Equal(t, int32(1), tlm.stopped) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectedConfig, tlm.TaskListPartitionConfig()) + }) + } +} + +func TestRefreshTaskListPartitionConfigConcurrency(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "/__cadence_sys/tl/1", persistence.TaskListTypeDecision) require.NoError(t, err) + tlm, _ := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + tlm.startWG.Done() + + var g errgroup.Group + for i := 0; i < 100; i++ { + v := i + g.Go(func() error { + return tlm.RefreshTaskListPartitionConfig(context.Background(), &types.TaskListPartitionConfig{Version: int64(v), NumReadPartitions: int32(v), NumWritePartitions: int32(v)}) + }) + } + require.NoError(t, g.Wait()) + assert.Equal(t, int64(99), tlm.TaskListPartitionConfig().Version) +} - tlMgr, err := NewManager( - mockDomainCache, - testlogger.New(t), - metrics.NewClient(tally.NoopScope, metrics.Matching), - mockTm, - cluster.GetTestClusterMetadata(true), - mockPartitioner, - nil, - func(Manager) {}, - tlID, - types.TaskListKindNormal.Ptr(), - defaultTestConfig(), - clock.NewRealTimeSource(), - time.Now()) +func TestUpdateTaskListPartitionConfigConcurrency(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "/__cadence_sys/tl/1", persistence.TaskListTypeDecision) + require.NoError(t, err) + tlm, deps := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + deps.mockTaskManager.EXPECT().UpdateTaskList(gomock.Any(), gomock.Any()).Return(&persistence.UpdateTaskListResponse{}, nil).AnyTimes() + deps.mockMatchingClient.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), gomock.Any()).Return(&types.MatchingRefreshTaskListPartitionConfigResponse{}, nil).AnyTimes() + tlm.startWG.Done() + + var g errgroup.Group + for i := 0; i < 100; i++ { + v := i + g.Go(func() error { + return tlm.UpdateTaskListPartitionConfig(context.Background(), &types.TaskListPartitionConfig{NumReadPartitions: int32(v), NumWritePartitions: int32(v)}) + }) + } + require.NoError(t, g.Wait()) + assert.Equal(t, int64(100), tlm.TaskListPartitionConfig().Version) +} - tlm := tlMgr.(*taskListManagerImpl) - tlm.loadTaskListPartitionConfig() +func TestManagerStart_RootPartition(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) + require.NoError(t, err) + tlm, deps := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + deps.mockTaskManager.EXPECT().LeaseTaskList(gomock.Any(), &persistence.LeaseTaskListRequest{ + DomainID: "domain-id", + DomainName: "domainName", + TaskList: "tl", + TaskType: persistence.TaskListTypeDecision, + }).Return(&persistence.LeaseTaskListResponse{ + TaskListInfo: &persistence.TaskListInfo{ + DomainID: "domain-id", + Name: "tl", + Kind: persistence.TaskListKindNormal, + AckLevel: 0, + RangeID: 0, + }, + }, nil) + assert.NoError(t, tlm.Start()) assert.Nil(t, tlm.TaskListPartitionConfig()) - tlm.loadTaskListPartitionConfig() - assert.Equal(t, &types.TaskListPartitionConfig{NumReadPartitions: 1, NumWritePartitions: 1}, tlm.TaskListPartitionConfig()) - tlm.loadTaskListPartitionConfig() - assert.Equal(t, &types.TaskListPartitionConfig{NumReadPartitions: 1, NumWritePartitions: 1}, tlm.TaskListPartitionConfig()) +} + +func TestManagerStart_NonRootPartition(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "/__cadence_sys/tl/1", persistence.TaskListTypeDecision) + require.NoError(t, err) + tlm, deps := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + deps.mockTaskManager.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ + DomainID: "domain-id", + DomainName: "domainName", + TaskList: "tl", + TaskType: persistence.TaskListTypeDecision, + }).Return(&persistence.GetTaskListResponse{ + TaskListInfo: &persistence.TaskListInfo{ + DomainID: "domain-id", + Name: "tl", + Kind: persistence.TaskListKindNormal, + AckLevel: 0, + RangeID: 0, + AdaptivePartitionConfig: &persistence.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, + }, + }, nil) + deps.mockTaskManager.EXPECT().LeaseTaskList(gomock.Any(), &persistence.LeaseTaskListRequest{ + DomainID: "domain-id", + DomainName: "domainName", + TaskList: "/__cadence_sys/tl/1", + TaskType: persistence.TaskListTypeDecision, + }).Return(&persistence.LeaseTaskListResponse{ + TaskListInfo: &persistence.TaskListInfo{ + DomainID: "domain-id", + Name: "/__cadence_sys/tl/1", + Kind: persistence.TaskListKindNormal, + AckLevel: 0, + RangeID: 0, + }, + }, nil) + assert.NoError(t, tlm.Start()) + assert.Equal(t, &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 3, + NumWritePartitions: 3, + }, tlm.TaskListPartitionConfig()) }