diff --git a/service/matching/tasklist/matcher_test.go b/service/matching/tasklist/matcher_test.go index ba59608df5a..98d9fe4f075 100644 --- a/service/matching/tasklist/matcher_test.go +++ b/service/matching/tasklist/matcher_test.go @@ -637,6 +637,641 @@ func (t *MatcherTestSuite) TestIsolationPollFailure() { t.Nil(task) } +func (t *MatcherTestSuite) TestOffer_RateLimited() { + t.matcher.UpdateRatelimit(common.Float64Ptr(0)) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + ctx := context.Background() + + matched, err := t.matcher.Offer(ctx, task) + + t.ErrorIs(err, ErrTasklistThrottled) + t.False(matched) +} + +func (t *MatcherTestSuite) TestOffer_NoTimeoutSyncMatchedNoError() { + defer goleak.VerifyNone(t.T()) + + t.matcher.config.LocalTaskWaitTime = func() time.Duration { return 0 } + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + t.disableRemoteForwarding("") + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + syncMatched, err := t.matcher.Offer(ctx, task) + cancel() + wait() + + t.NoError(err) + t.True(syncMatched) +} + +func (t *MatcherTestSuite) TestOffer_NoTimeoutSyncMatchedError() { + defer goleak.VerifyNone(t.T()) + + t.matcher.config.LocalTaskWaitTime = func() time.Duration { return 0 } + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + t.disableRemoteForwarding("") + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + task.ResponseC <- errShutdown + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + syncMatched, err := t.matcher.Offer(ctx, task) + cancel() + wait() + + t.Error(err) + t.True(syncMatched) +} + +func (t *MatcherTestSuite) TestOffer_NoTimeoutAsyncMatchedNoError() { + defer goleak.VerifyNone(t.T()) + + t.matcher.config.LocalTaskWaitTime = func() time.Duration { return 0 } + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + t.disableRemoteForwarding("") + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + syncMatched, err := t.matcher.Offer(ctx, task) + cancel() + wait() + + t.NoError(err) + t.False(syncMatched) +} + +func (t *MatcherTestSuite) TestOffer_AsyncMatchedNoError() { + defer goleak.VerifyNone(t.T()) + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + t.disableRemoteForwarding("") + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + syncMatch, err := t.matcher.Offer(ctx, task) + cancel() + wait() + + t.NoError(err) + t.False(syncMatch) +} + +func (t *MatcherTestSuite) TestOfferOrTimeout_SyncMatchTimedOut() { + defer goleak.VerifyNone(t.T()) + + t.disableRemoteForwarding("") + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + cancel() + matched, err := t.matcher.OfferOrTimeout(ctx, time.Now(), task) + wait() + + t.NoError(err) + t.False(matched) +} + +func (t *MatcherTestSuite) TestOfferOrTimeout_AsyncMatchNotMatched() { + defer goleak.VerifyNone(t.T()) + + t.disableRemoteForwarding("") + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + matched, err := t.matcher.OfferOrTimeout(ctx, time.Now(), task) + cancel() + wait() + + t.NoError(err) + t.False(matched) +} + +func (t *MatcherTestSuite) TestOfferOrTimeout_AsyncMatchMatched() { + defer goleak.VerifyNone(t.T()) + + t.disableRemoteForwarding("") + + wait := ensureAsyncReady(time.Second, func(ctx context.Context) { + task, err := t.matcher.Poll(ctx, "") + if err == nil { + task.Finish(nil) + } + }) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, &types.ActivityTaskDispatchInfo{}, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + matched, err := t.matcher.OfferOrTimeout(ctx, time.Now(), task) + cancel() + wait() + + t.NoError(err) + t.True(matched) +} + +func (t *MatcherTestSuite) TestOfferOrTimeout_TimedOut() { + t.disableRemoteForwarding("") + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, &types.ActivityTaskDispatchInfo{}, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + + matched, err := t.matcher.OfferOrTimeout(ctx, time.Now(), task) + + t.NoError(err) + t.False(matched) +} + +func (t *MatcherTestSuite) TestOfferQuery_ForwardError() { + ctx := context.Background() + task := newInternalQueryTask(uuid.New(), &types.MatchingQueryWorkflowRequest{}) + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + fn := func() <-chan *ForwarderReqToken { + c := make(chan *ForwarderReqToken, 1) + c <- &ForwarderReqToken{ + ch: make(chan *ForwarderReqToken, 1), + } + return c + } + + mockForwarder.EXPECT().AddReqTokenC().Return(fn()).Times(1) + mockForwarder.EXPECT().ForwardQueryTask(ctx, task).Return(nil, ErrNoParent).Times(1) + + retTask, err := t.matcher.OfferQuery(ctx, task) + + t.ErrorIs(err, ErrNoParent) + t.Nil(retTask) +} + +func (t *MatcherTestSuite) TestOfferQuery_ContextExpired() { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + cancel() + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + task := newInternalQueryTask(uuid.New(), &types.MatchingQueryWorkflowRequest{}) + + mockForwarder.EXPECT().AddReqTokenC().Times(1) + + retTask, err := t.matcher.OfferQuery(ctx, task) + + t.ErrorIs(err, context.Canceled) + t.Nil(retTask) +} + +func (t *MatcherTestSuite) TestUpdateRatelimit_RateGreaterThanNumberOfPartitions() { + t.matcher.config.NumReadPartitions = func() int { return 10 } + + t.matcher.UpdateRatelimit(common.Float64Ptr(100)) + + t.Equal(rate.Limit(10), t.matcher.limiter.Limit()) +} + +func (t *MatcherTestSuite) TestUpdateRatelimit_RateLessThanOrEqualToNumberOfPartitions() { + t.matcher.config.NumReadPartitions = func() int { return 9 } + + t.matcher.UpdateRatelimit(common.Float64Ptr(5)) + + t.Equal(rate.Limit(5), t.matcher.limiter.Limit()) +} + +func (t *MatcherTestSuite) TestUpdateRatelimit_NilRps() { + rateLimit := t.matcher.limiter.Limit() + + t.matcher.UpdateRatelimit(nil) + + t.Equal(rateLimit, t.matcher.limiter.Limit()) +} + +func (t *MatcherTestSuite) TestRate() { + t.Equal(float64(t.matcher.limiter.Limit()), t.matcher.Rate()) +} + +func (t *MatcherTestSuite) TestMustOffer_RateLimited() { + t.matcher.UpdateRatelimit(common.Float64Ptr(0)) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + ctx := context.Background() + + err := t.matcher.MustOffer(ctx, task) + + t.ErrorIs(err, ErrTasklistThrottled) +} + +func (t *MatcherTestSuite) TestMustOffer_ContextExpiredFirstAttempt() { + ctx, cancel := context.WithTimeout(context.Background(), t.matcher.config.LocalTaskWaitTime()) + defer cancel() + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + err := t.matcher.MustOffer(ctx, task) + + t.ErrorIs(err, context.DeadlineExceeded) +} + +func (t *MatcherTestSuite) TestMustOffer_ContextExpiredAfterFirstAttempt() { + ctx, cancel := context.WithTimeout(context.Background(), 2*t.matcher.config.LocalTaskWaitTime()) + defer cancel() + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + mockForwarder.EXPECT().AddReqTokenC().Times(1) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + err := t.matcher.MustOffer(ctx, task) + + t.ErrorIs(err, context.DeadlineExceeded) +} + +func (t *MatcherTestSuite) TestMustOffer_LocalMatchAfterChildCtxExpired() { + ctx := context.Background() + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + t.disableRemoteForwarding("") + + mockForwarder.EXPECT().AddReqTokenC().AnyTimes() + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + go func() { + // Waits for the child context to expire; by default, it is 10 ms. + // Forces no forwarding and the context background will never expire, therefore guaranteeing that the poller + // will pick up the task after the first attempt. + time.Sleep(200 * time.Millisecond) + retTask, err := t.matcher.Poll(ctx, "") + if err == nil { + retTask.Finish(nil) + } + }() + + err := t.matcher.MustOffer(ctx, task) + + t.NoError(err) +} + +func (t *MatcherTestSuite) TestMustOffer_LocalMatchAfterForwardError() { + ctx := context.Background() + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + forwardToken := &ForwarderReqToken{ + ch: make(chan *ForwarderReqToken, 1), + } + + fn := func() <-chan *ForwarderReqToken { + c := make(chan *ForwarderReqToken, 1) + c <- forwardToken + return c + } + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + mockForwarder.EXPECT().AddReqTokenC().Return(fn()).Times(1) + mockForwarder.EXPECT().ForwardTask(gomock.Any(), task).Return(ErrNoParent).Times(1) + + go func() { + <-forwardToken.ch + retTask, err := t.matcher.Poll(ctx, "") + if err == nil { + retTask.Finish(nil) + } + }() + + err := t.matcher.MustOffer(ctx, task) + + t.NoError(err) +} + +func (t *MatcherTestSuite) TestMustOffer_ContextExpiredAfterForwardError() { + ctx, cancel := context.WithCancel(context.Background()) + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + forwardToken := &ForwarderReqToken{ + ch: make(chan *ForwarderReqToken, 1), + } + + fn := func() <-chan *ForwarderReqToken { + c := make(chan *ForwarderReqToken, 1) + c <- forwardToken + return c + } + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", false, nil, "") + + mockForwarder.EXPECT().AddReqTokenC().Return(fn()).Times(1) + mockForwarder.EXPECT().ForwardTask(gomock.Any(), task).Return(ErrNoParent).Times(1) + + go func() { + <-forwardToken.ch + // using cancel here to simulate the context being expired + cancel() + }() + + err := t.matcher.MustOffer(ctx, task) + + t.Error(err) +} + +func (t *MatcherTestSuite) Test_pollOrForward_PollIsolatedTask() { + ctx := context.Background() + startT := time.Now() + isolationGroup := "dca1" + isolatedTaskC := make(chan *InternalTask, 1) + + t.disableRemoteForwarding("") + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + // Mock the fwdrPollReqTokenC method to return a controlled channel + mockTokenC := make(chan *ForwarderReqToken) + mockForwarder.EXPECT().PollReqTokenC(isolationGroup).Return(mockTokenC).AnyTimes() + + // Test pollOrForward for isolated task - poll + isolatedTask := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, isolationGroup) + isolatedTaskC <- isolatedTask + retTask, err := t.matcher.pollOrForward(ctx, startT, isolationGroup, isolatedTaskC, nil, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(isolatedTask, retTask) +} + +func (t *MatcherTestSuite) Test_pollOrForward_PollTask() { + ctx := context.Background() + startT := time.Now() + taskC := make(chan *InternalTask, 1) + + t.disableRemoteForwarding("") + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + // Mock the fwdrPollReqTokenC method to return a controlled channel + mockTokenC := make(chan *ForwarderReqToken) + mockForwarder.EXPECT().PollReqTokenC("").Return(mockTokenC).AnyTimes() + + // Test pollOrForward for regular task - poll + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + taskC <- task + retTask, err := t.matcher.pollOrForward(ctx, startT, "", nil, taskC, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(task, retTask) +} + +func (t *MatcherTestSuite) Test_pollOrForward_PollQueryTask() { + ctx := context.Background() + startT := time.Now() + queryTaskC := make(chan *InternalTask, 1) + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + // Mock the fwdrPollReqTokenC method to return a controlled channel + mockTokenC := make(chan *ForwarderReqToken) + mockForwarder.EXPECT().PollReqTokenC("").Return(mockTokenC).AnyTimes() + + // Test pollOrForward for query task - poll + queryTask := newInternalQueryTask(uuid.New(), &types.MatchingQueryWorkflowRequest{}) + queryTaskC <- queryTask + retTask, err := t.matcher.pollOrForward(ctx, startT, "", nil, nil, queryTaskC) + t.NoError(err) + t.NotNil(retTask) + t.Equal(queryTask, retTask) +} + +func (t *MatcherTestSuite) Test_pollOrForward_ForwardTask() { + ctx := context.Background() + startT := time.Now() + isolationGroup := "dca1" + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + // Mock the fwdrPollReqTokenC method to return a controlled channel + mockTokenC := make(chan *ForwarderReqToken, 1) + forwardToken := &ForwarderReqToken{ + isolatedCh: map[string]chan *ForwarderReqToken{isolationGroup: make(chan *ForwarderReqToken, 1)}, + } + mockTokenC <- forwardToken + mockForwarder.EXPECT().PollReqTokenC(isolationGroup).Return(mockTokenC).AnyTimes() + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + mockForwarder.EXPECT().ForwardPoll(ctx).Return(task, nil).Times(1) + + retTask, err := t.matcher.pollOrForward(ctx, startT, isolationGroup, nil, nil, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(task, retTask) +} + +func (t *MatcherTestSuite) Test_pollOrForward_ForwardTaskThenPoll() { + ctx := context.Background() + startT := time.Now() + isolationGroup := "dca1" + taskC := make(chan *InternalTask, 1) + + mockForwarder := NewMockForwarder(t.controller) + t.matcher.fwdr = mockForwarder + + // Mock the fwdrPollReqTokenC method to return a controlled channel + mockTokenC := make(chan *ForwarderReqToken, 1) + forwardToken := &ForwarderReqToken{ + isolatedCh: map[string]chan *ForwarderReqToken{isolationGroup: make(chan *ForwarderReqToken, 1)}, + } + mockTokenC <- forwardToken + mockForwarder.EXPECT().PollReqTokenC(isolationGroup).Return(mockTokenC).AnyTimes() + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + mockForwarder.EXPECT().ForwardPoll(ctx).Return(nil, ErrNoParent).Times(1) + + // Add the task after the forwarderReqToken is released + // It's not a race condition because the PollReqTokenC is mocked and does not use the isolatedCh + go func() { + select { + case <-forwardToken.isolatedCh[isolationGroup]: + taskC <- task + } + }() + + retTask, err := t.matcher.pollOrForward(ctx, startT, isolationGroup, nil, taskC, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(task, retTask) +} + +func (t *MatcherTestSuite) Test_poll_IsolatedTask() { + ctx := context.Background() + startT := time.Now() + isolationGroup := "dca1" + isolatedTaskC := make(chan *InternalTask, 1) + + isolatedTask := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, isolationGroup) + isolatedTaskC <- isolatedTask + retTask, err := t.matcher.poll(ctx, startT, isolatedTaskC, nil, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(isolatedTask, retTask) +} + +func (t *MatcherTestSuite) Test_poll_Task() { + ctx := context.Background() + startT := time.Now() + taskC := make(chan *InternalTask, 1) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + taskC <- task + retTask, err := t.matcher.poll(ctx, startT, nil, taskC, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(task, retTask) +} + +func (t *MatcherTestSuite) Test_poll_QueryTask() { + ctx := context.Background() + startT := time.Now() + queryTaskC := make(chan *InternalTask, 1) + + queryTask := newInternalQueryTask(uuid.New(), &types.MatchingQueryWorkflowRequest{}) + queryTaskC <- queryTask + retTask, err := t.matcher.poll(ctx, startT, nil, nil, queryTaskC) + t.NoError(err) + t.NotNil(retTask) + t.Equal(queryTask, retTask) +} + +func (t *MatcherTestSuite) Test_poll_ContextExpired() { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + cancel() + + startT := time.Now() + + retTask, err := t.matcher.poll(ctx, startT, nil, nil, nil) + + t.ErrorIs(err, ErrNoTasks) + t.Nil(retTask) +} + +func (t *MatcherTestSuite) Test_pollNonBlocking_IsolatedTask() { + t.matcher.config.LocalPollWaitTime = func() time.Duration { return 0 } + + ctx := context.Background() + isolationGroup := "dca1" + isolatedTaskC := make(chan *InternalTask, 1) + + isolatedTask := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, isolationGroup) + isolatedTaskC <- isolatedTask + retTask, err := t.matcher.pollNonBlocking(ctx, isolatedTaskC, nil, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(isolatedTask, retTask) +} + +func (t *MatcherTestSuite) Test_pollNonBlocking_Task() { + t.matcher.config.LocalPollWaitTime = func() time.Duration { return 0 } + + ctx := context.Background() + taskC := make(chan *InternalTask, 1) + + task := newInternalTask(t.newTaskInfo(), nil, types.TaskSourceHistory, "", true, nil, "") + taskC <- task + retTask, err := t.matcher.pollNonBlocking(ctx, nil, taskC, nil) + t.NoError(err) + t.NotNil(retTask) + t.Equal(task, retTask) +} + +func (t *MatcherTestSuite) Test_pollNonBlocking_QueryTask() { + t.matcher.config.LocalPollWaitTime = func() time.Duration { return 0 } + + ctx := context.Background() + queryTaskC := make(chan *InternalTask, 1) + + queryTask := newInternalQueryTask(uuid.New(), &types.MatchingQueryWorkflowRequest{}) + queryTaskC <- queryTask + retTask, err := t.matcher.pollNonBlocking(ctx, nil, nil, queryTaskC) + t.NoError(err) + t.NotNil(retTask) + t.Equal(queryTask, retTask) +} + +func (t *MatcherTestSuite) Test_pollNonBlocking_NoTasks() { + t.matcher.config.LocalPollWaitTime = func() time.Duration { return 0 } + + ctx := context.Background() + + retTask, err := t.matcher.pollNonBlocking(ctx, nil, nil, nil) + + t.ErrorIs(err, ErrNoTasks) + t.Nil(retTask) +} + +func (t *MatcherTestSuite) Test_fwdrPollReqTokenC() { + t.matcher.fwdr = nil + t.Equal(noopForwarderTokenC, t.matcher.fwdrPollReqTokenC("")) +} + func (t *MatcherTestSuite) disableRemoteForwarding(isolationGroup string) { // force disable remote forwarding for i := 0; i < len(t.isolationGroups)+1; i++ {