From 2594652f8837e42431960fae47d0372370409a7f Mon Sep 17 00:00:00 2001
From: Dominic Black
Date: Wed, 19 Jul 2023 19:13:28 +0100
Subject: [PATCH] Fix possible deadlock in AWS pubsub

---
 runtime/pubsub/internal/aws/topic.go          | 46 +++++++++--
 runtime/pubsub/internal/utils/workers.go      | 66 ++++++++++------
 runtime/pubsub/internal/utils/workers_test.go | 76 +++++++++++++++++++
 3 files changed, 160 insertions(+), 28 deletions(-)

diff --git a/runtime/pubsub/internal/aws/topic.go b/runtime/pubsub/internal/aws/topic.go
index ebc80489b1..95378411b9 100644
--- a/runtime/pubsub/internal/aws/topic.go
+++ b/runtime/pubsub/internal/aws/topic.go
@@ -87,12 +87,28 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
     }

     go func() {
+        defer func() {
+            if r := recover(); r != nil {
+                logger.Error().Interface("panic", r).Msg("panic in subscriber, no longer processing messages")
+            } else {
+                logger.Info().Msg("subscriber stopped due to context cancellation")
+            }
+        }()
+
         for t.ctx.Err() == nil {
             err := utils.WorkConcurrently(
                 t.ctx, maxConcurrency, 10,
                 func(ctx context.Context, maxToFetch int) ([]sqsTypes.Message, error) {
-                    resp, err := t.sqsClient.ReceiveMessage(t.ctx, &sqs.ReceiveMessageInput{
+                    // We only long poll for 20 seconds, so if this call takes more than
+                    // 30 seconds we cancel the context and try again.
+                    //
+                    // We do this in case the ReceiveMessage call gets stuck on the server
+                    // and doesn't return.
+                    ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+                    defer cancel()
+
+                    resp, err := t.sqsClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
                         QueueUrl:            aws.String(implCfg.ProviderName),
                         AttributeNames:      []sqsTypes.QueueAttributeName{"ApproximateReceiveCount"},
                         MaxNumberOfMessages: int32(maxToFetch),
@@ -132,13 +148,29 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
                     }

                     // Call the callback, and if there was no error, then we can delete the message
-                    msgCtx, cancel := context.WithTimeout(t.ctx, ackDeadline)
+                    msgCtx, cancel := context.WithTimeout(ctx, ackDeadline)
+                    defer cancel()
                     err = f(msgCtx, msgWrapper.MessageId, msgWrapper.Timestamp, int(deliveryAttempt), attributes, []byte(msgWrapper.Message))
                     cancel()
+
+                    // Check if the context has been cancelled, and if so, return the error
+                    if ctx.Err() != nil {
+                        return ctx.Err()
+                    }
+
+                    // We want to wait a maximum of 30 seconds for the follow-up SQS calls to
+                    // complete, otherwise we assume they have failed and the message should be retried.
+                    //
+                    // We do this in case an SQS call gets stuck and doesn't return.
+                    ctx, responseCancel := context.WithTimeout(ctx, 30*time.Second)
+                    defer responseCancel()
+
                     if err != nil {
+                        logger.Err(err).Str("msg_id", msgWrapper.MessageId).Msg("unable to process message")
+
                         // If there was an error processing the message, apply the backoff policy
                         _, delay := utils.GetDelay(retryPolicy.MaxRetries, retryPolicy.MinBackoff, retryPolicy.MaxBackoff, uint16(deliveryAttempt))
-                        _, visibilityChangeErr := t.sqsClient.ChangeMessageVisibility(t.ctx, &sqs.ChangeMessageVisibilityInput{
+                        _, visibilityChangeErr := t.sqsClient.ChangeMessageVisibility(ctx, &sqs.ChangeMessageVisibilityInput{
                             QueueUrl:          aws.String(implCfg.ProviderName),
                             ReceiptHandle:     msg.ReceiptHandle,
                             VisibilityTimeout: int32(utils.Clamp(delay, time.Second, 12*time.Hour).Seconds()),
@@ -146,9 +178,9 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
                         if visibilityChangeErr != nil {
                             log.Warn().Err(visibilityChangeErr).Str("msg_id", msgWrapper.MessageId).Msg("unable to change message visibility to apply backoff rules")
                         }
-                    }
-                    if err == nil {
-                        _, err = t.sqsClient.DeleteMessage(t.ctx, &sqs.DeleteMessageInput{
+                    } else {
+                        // If the message was processed successfully, delete it from the queue
+                        _, err = t.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{
                             QueueUrl:      aws.String(implCfg.ProviderName),
                             ReceiptHandle: msg.ReceiptHandle,
                         })
@@ -156,9 +188,11 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
                             logger.Err(err).Str("msg_id", msgWrapper.MessageId).Msg("unable to delete message from SQS queue")
                         }
                     }
+
+                    return nil
                 },
             )
+
             if err != nil && t.ctx.Err() == nil {
                 logger.Warn().Err(err).Msg("pubsub subscription failed, retrying in 5 seconds")
                 time.Sleep(5 * time.Second)
diff --git a/runtime/pubsub/internal/utils/workers.go b/runtime/pubsub/internal/utils/workers.go
index 4e641e7ff4..45f0b0b9a1 100644
--- a/runtime/pubsub/internal/utils/workers.go
+++ b/runtime/pubsub/internal/utils/workers.go
@@ -6,6 +6,8 @@ import (
     "sync"
     "sync/atomic"
     "time"
+
+    "encore.dev/beta/errs"
 )

 const (
@@ -42,19 +44,37 @@ type WorkProcessor[Work any] func(ctx context.Context, work Work) error
 // this again you could end up with 2x maxConcurrency workers running at the same time. (1x from the original run who
 // are still processing work and 1x from the new run).
 func WorkConcurrently[Work any](ctx context.Context, maxConcurrency int, maxBatchSize int, fetch WorkFetcher[Work], worker WorkProcessor[Work]) error {
+    fetchWithPanicHandling := func(ctx context.Context, maxToFetch int) (work []Work, err error) {
+        defer func() {
+            if r := recover(); r != nil {
+                err = errs.B().Msgf("panic: %v", r).Err()
+            }
+        }()
+        return fetch(ctx, maxToFetch)
+    }
+
+    workWithPanicHandling := func(ctx context.Context, work Work) (err error) {
+        defer func() {
+            if r := recover(); r != nil {
+                err = errs.B().Msgf("panic: %v", r).Err()
+            }
+        }()
+        return worker(ctx, work)
+    }
+
     if maxConcurrency == 1 {
         // If there's no concurrency, we can just do everything synchronously within this goroutine
         // This avoids the overhead of creating mutexes being used
-        return workInSingleRoutine(ctx, fetch, worker)
+        return workInSingleRoutine(ctx, fetchWithPanicHandling, workWithPanicHandling)
     } else if maxConcurrency <= 0 {
         // If there's infinite concurrency, we can just do everything by spawning goroutines
         // for each work item
-        return workInInfiniteRoutines(ctx, maxBatchSize, fetch, worker)
+        return workInInfiniteRoutines(ctx, maxBatchSize, fetchWithPanicHandling, workWithPanicHandling)
     } else {
         // Else there's a cap on concurrency, we need to use channels to communicate between the fetcher and the workers
-        return workInWorkPool(ctx, maxConcurrency, maxBatchSize, fetch, worker)
+        return workInWorkPool(ctx, maxConcurrency, maxBatchSize, fetchWithPanicHandling, workWithPanicHandling)
     }
 }
@@ -139,7 +159,12 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
     workProcessor := func(work Work) {
         inFlight.Add(1)
         defer inFlight.Add(-1)
-        defer func() { workDone <- struct{}{} }()
+        defer func() {
+            select {
+            case workDone <- struct{}{}:
+            case <-fetchCtx.Done():
+            }
+        }()

         // We use the parent context here, such that if we have a fetch error, the existing workers will
         // continue to run until they finish processing any work already have started on
@@ -150,14 +175,17 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS

     // fetchProcessor is a small wrapper around the fetcher function that passes the fetched work to the workers
     // it will fetch upto maxConcurrency items at a time in batches of maxBatchSize items
-    var lastFetch time.Time
+    var lastFetch atomic.Pointer[time.Time]
+    var epoch time.Time
+    lastFetch.Store(&epoch)
+
     var debounceTimer *time.Timer
     var fetchLock sync.Mutex
     fetchProcessor := func() {
         // Lock the fetcher so that we don't have multiple fetchers running at the same time
         fetchLock.Lock()
         defer fetchLock.Unlock()
-        defer func() { lastFetch = time.Now() }()
+        defer func() { now := time.Now(); lastFetch.Store(&now) }()

         // Work out how many items we need to fetch
         need := maxConcurrency - int(inFlight.Load())
@@ -179,7 +207,12 @@
         // Pass work to workers
         for _, w := range work {
-            workChan <- w
+            select {
+            case workChan <- w:
+                // success, we passed the work to a worker
+            case <-fetchCtx.Done():
+                return
+            }
         }

         // Update the number of items we need to fetch
@@ -191,16 +224,8 @@
     // Start the workers
     for i := 0; i < maxConcurrency; i++ {
         go func() {
-            for {
-                select {
-                case work, more := <-workChan:
-                    if !more {
-                        // the workChan has been closed, we can stop
-                        return
-                    }
-
-                    workProcessor(work)
-                }
+            for work := range workChan {
+                workProcessor(work)
             }
         }()
     }
@@ -210,12 +235,9 @@
     workDone <- struct{}{}

     // Start fetching work
-fetchLoop:
-    for {
+    for fetchCtx.Err() == nil {
         select {
         case <-fetchCtx.Done():
-            // If the context is cancelled, we need to stop fetching work
-            break fetchLoop
         case <-workDone:
             if debounceTimer != nil {
                 debounceTimer = nil
             }

-            if time.Since(lastFetch) > maxFetchDebounce {
+            if time.Since(*lastFetch.Load()) > maxFetchDebounce {
                 fetchProcessor()
             } else {
                 debounceTimer = time.AfterFunc(workFetchDebounce, fetchProcessor)
diff --git a/runtime/pubsub/internal/utils/workers_test.go b/runtime/pubsub/internal/utils/workers_test.go
index 4e03ecdbcf..fedc860b3a 100644
--- a/runtime/pubsub/internal/utils/workers_test.go
+++ b/runtime/pubsub/internal/utils/workers_test.go
@@ -175,6 +175,9 @@ func TestWorkConcurrently(t *testing.T) {

             err := WorkConcurrently(ctx, tt.concurrency, tt.maxBatchSize, fetcher, processor)

+            workMu.Lock()
+            defer workMu.Unlock()
+
             // Run assertions on the exit conditions
             c.Assert(timeoutCtx.Err(), qt.IsNil, qt.Commentf("test timed out - not all work was fetched within the timeout"))
             switch {
@@ -220,3 +223,76 @@ func TestWorkConcurrently(t *testing.T) {
         })
     }
 }
+
+func TestWorkConcurrentlyLoad(t *testing.T) {
+    t.Skip()
+
+    const load = 20_000
+    msg := make([]string, load)
+    for i := 0; i < load; i++ {
+        msg[i] = fmt.Sprintf("msg %d", i)
+    }
+
+    var (
+        mu  sync.Mutex
+        idx int
+
+        wMu sync.Mutex
+        wrk []string
+    )
+
+    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+    defer cancel()
+
+    var err error
+
+    for ctx.Err() == nil {
+        err = WorkConcurrently(ctx, 25, 10, func(ctx context.Context, maxToFetch int) ([]string, error) {
+            mu.Lock()
+            defer mu.Unlock()
+
+            if idx >= load {
+                return nil, nil
+            }
+
+            toFetch := maxToFetch
+            if toFetch > load-idx {
+                toFetch = load - idx
+            }
+            if toFetch == 0 {
+                return nil, nil
+            }
+
+            rtn := make([]string, toFetch)
+            copy(rtn, msg[idx:idx+toFetch])
+            idx += toFetch
+
+            return rtn, nil
+        }, func(ctx context.Context, work string) error {
+            time.Sleep(100 * time.Millisecond)
+            wMu.Lock()
+            defer wMu.Unlock()
+
+            wrk = append(wrk, work)
+
+            if len(wrk)%250 == 0 {
+                panic("too much work")
+            }
+
+            if len(wrk) == load {
+                cancel()
+            }
+            return nil
+        })
+
+        fmt.Printf("err (worked %d/%d - sent %d): %v\n", len(wrk), load, idx, err)
+    }
+
+    if err != nil {
+        t.Fatalf("err (worked %d/%d - sent %d): %v", len(wrk), load, idx, err)
+    }
+
+    if len(wrk) != load {
+        t.Fatalf("expected %d work items, got %d", load, len(wrk))
+    }
+}
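The pattern at the heart of this fix: a bare send on an unbuffered channel (workChan <- w, or the deferred workDone <- struct{}{}) blocks forever if every receiver has already exited, which is how the previous worker pool could deadlock. The change pairs every such send with a <-fetchCtx.Done() case in a select. The Go program below is a minimal, self-contained sketch of that technique; the runPool and job names are illustrative only and are not part of the runtime package.

package main

import (
    "context"
    "fmt"
    "time"
)

// runPool fans jobs out to nWorkers goroutines. Every channel send is wrapped
// in a select that also watches ctx.Done(), so neither the producer nor the
// workers can block forever once the pool is shutting down.
func runPool(ctx context.Context, nWorkers int, jobs []string) {
    workChan := make(chan string)
    workDone := make(chan struct{})

    // Workers: receive until workChan is closed or the context is cancelled.
    for i := 0; i < nWorkers; i++ {
        go func() {
            for job := range workChan {
                fmt.Println("processed", job)
                select {
                case workDone <- struct{}{}: // report completion
                case <-ctx.Done(): // nobody is listening any more; give up
                    return
                }
            }
        }()
    }

    // Producer: hand out jobs, but never block on a send once ctx is cancelled.
    go func() {
        defer close(workChan)
        for _, j := range jobs {
            select {
            case workChan <- j:
            case <-ctx.Done():
                return
            }
        }
    }()

    // Wait for all jobs to be reported done, or for the context to end.
    for done := 0; done < len(jobs); {
        select {
        case <-workDone:
            done++
        case <-ctx.Done():
            return
        }
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    runPool(ctx, 3, []string{"a", "b", "c", "d"})
}

Compare this with the removed lines above, where the unconditional workChan <- w and the deferred workDone <- struct{}{} could both block indefinitely once the fetch loop had stopped receiving.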