Fix possible deadlock in AWS pubsub
DomBlack committed Jul 19, 2023
1 parent 5f096ea commit 68c47b5
Showing 3 changed files with 160 additions and 28 deletions.
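At a high level, the deadlock guarded against here is the familiar Go pattern of a completion signal sent on an unbuffered channel whose receiver may already have stopped. The snippet below is a minimal, self-contained sketch of the problem and of the select-based escape hatch the diff applies; the workDone name mirrors workers.go, everything else is simplified illustration rather than the actual runtime code.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	workDone := make(chan struct{}) // unbuffered completion signal, as in workers.go

	go func() {
		time.Sleep(50 * time.Millisecond) // simulate processing one work item

		// A bare `workDone <- struct{}{}` would block forever once the receiver
		// below has stopped listening. Selecting on ctx.Done() as well gives the
		// send an exit path, which is the shape of the fix in this commit.
		select {
		case workDone <- struct{}{}:
			fmt.Println("completion delivered")
		case <-ctx.Done():
			fmt.Println("receiver gone; worker exits instead of deadlocking")
		}
	}()

	// The fetch loop stops early, e.g. because fetching work failed.
	cancel()
	time.Sleep(100 * time.Millisecond)
}

The same guard is applied to the workChan sends further down, and the fetch loop's lastFetch timestamp moves behind an atomic pointer so the debounce check stays race-free.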
46 changes: 40 additions & 6 deletions runtime/pubsub/internal/aws/topic.go
@@ -87,12 +87,28 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
}

go func() {
defer func() {
if r := recover(); r != nil {
logger.Error().Interface("panic", r).Msg("panic in subscriber, no longer processing messages")
} else {
logger.Info().Msg("subscriber stopped due to context cancellation")
}
}()

for t.ctx.Err() == nil {
err := utils.WorkConcurrently(
t.ctx,
maxConcurrency, 10,
func(ctx context.Context, maxToFetch int) ([]sqsTypes.Message, error) {
- resp, err := t.sqsClient.ReceiveMessage(t.ctx, &sqs.ReceiveMessageInput{
+ // We should only long poll for 20 seconds, so if this takes more than
+ // 30 seconds we should cancel the context and try again
+ //
+ // We do this in case the ReceiveMessage call gets stuck on the server
+ // and doesn't return
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+
+ resp, err := t.sqsClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
QueueUrl: aws.String(implCfg.ProviderName),
AttributeNames: []sqsTypes.QueueAttributeName{"ApproximateReceiveCount"},
MaxNumberOfMessages: int32(maxToFetch),
@@ -132,33 +148,51 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
}

// Call the callback, and if there was no error, then we can delete the message
- msgCtx, cancel := context.WithTimeout(t.ctx, ackDeadline)
+ msgCtx, cancel := context.WithTimeout(ctx, ackDeadline)
defer cancel()
err = f(msgCtx, msgWrapper.MessageId, msgWrapper.Timestamp, int(deliveryAttempt), attributes, []byte(msgWrapper.Message))
cancel()

// Check if the context has been cancelled, and if so, return the error
if ctx.Err() != nil {
return ctx.Err()
}

// We want to wait a maximum of 30 seconds for the callback to complete
// otherwise we assume it has failed and we should retry
//
// We do this in case the callback gets stuck and doesn't return
ctx, responseCancel := context.WithTimeout(ctx, 30*time.Second)
defer responseCancel()

if err != nil {
logger.Err(err).Str("msg_id", msgWrapper.MessageId).Msg("unable to process message")

// If there was an error processing the message, apply the backoff policy
_, delay := utils.GetDelay(retryPolicy.MaxRetries, retryPolicy.MinBackoff, retryPolicy.MaxBackoff, uint16(deliveryAttempt))
- _, visibilityChangeErr := t.sqsClient.ChangeMessageVisibility(t.ctx, &sqs.ChangeMessageVisibilityInput{
+ _, visibilityChangeErr := t.sqsClient.ChangeMessageVisibility(ctx, &sqs.ChangeMessageVisibilityInput{
QueueUrl: aws.String(implCfg.ProviderName),
ReceiptHandle: msg.ReceiptHandle,
VisibilityTimeout: int32(utils.Clamp(delay, time.Second, 12*time.Hour).Seconds()),
})
if visibilityChangeErr != nil {
log.Warn().Err(visibilityChangeErr).Str("msg_id", msgWrapper.MessageId).Msg("unable to change message visibility to apply backoff rules")
}
}
- if err == nil {
- _, err = t.sqsClient.DeleteMessage(t.ctx, &sqs.DeleteMessageInput{
+ } else {
+ // If the message was processed successfully, delete it from the queue
+ _, err = t.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{
QueueUrl: aws.String(implCfg.ProviderName),
ReceiptHandle: msg.ReceiptHandle,
})
if err != nil {
logger.Err(err).Str("msg_id", msgWrapper.MessageId).Msg("unable to delete message from SQS queue")
}
}

return nil
},
)

if err != nil && t.ctx.Err() == nil {
logger.Warn().Err(err).Msg("pubsub subscription failed, retrying in 5 seconds")
time.Sleep(5 * time.Second)
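The 30-second guard added around ReceiveMessage above is the standard context.WithTimeout pattern for bounding a call that might never return. A rough, self-contained sketch of that pattern follows; longPoll is a hypothetical stand-in for the SQS client call, and the durations are shortened for illustration.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// longPoll stands in for a call like sqsClient.ReceiveMessage: it is expected
// to return within the long-poll window, but here it simulates getting stuck.
func longPoll(ctx context.Context) ([]string, error) {
	select {
	case <-time.After(5 * time.Second): // pretend the server never answers in time
		return []string{"msg"}, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func main() {
	// Bound the call so a hung response cannot wedge the fetch loop forever,
	// mirroring the guard added around ReceiveMessage in this commit.
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	msgs, err := longPoll(ctx)
	if errors.Is(err, context.DeadlineExceeded) {
		fmt.Println("receive timed out; the outer loop would retry")
		return
	}
	fmt.Println("received:", msgs)
}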
66 changes: 44 additions & 22 deletions runtime/pubsub/internal/utils/workers.go
@@ -6,6 +6,8 @@ import (
"sync"
"sync/atomic"
"time"

"encore.dev/beta/errs"
)

const (
@@ -42,19 +44,37 @@ type WorkProcessor[Work any] func(ctx context.Context, work Work) error
this again you could end up with 2x maxConcurrency workers running at the same time (1x from the original run that
are still processing work and 1x from the new run).
func WorkConcurrently[Work any](ctx context.Context, maxConcurrency int, maxBatchSize int, fetch WorkFetcher[Work], worker WorkProcessor[Work]) error {
fetchWithPanicHandling := func(ctx context.Context, maxToFetch int) (work []Work, err error) {
defer func() {
if r := recover(); r != nil {
err = errs.B().Msgf("panic: %v", r).Err()
}
}()
return fetch(ctx, maxToFetch)
}

workWithPanicHandling := func(ctx context.Context, work Work) (err error) {
defer func() {
if r := recover(); r != nil {
err = errs.B().Msgf("panic: %v", r).Err()
}
}()
return worker(ctx, work)
}

if maxConcurrency == 1 {
// If there's no concurrency, we can just do everything synchronously within this goroutine
// This avoids the overhead of creating and using mutexes
- return workInSingleRoutine(ctx, fetch, worker)
+ return workInSingleRoutine(ctx, fetchWithPanicHandling, workWithPanicHandling)

} else if maxConcurrency <= 0 {
// If there's infinite concurrency, we can just do everything by spawning goroutines
// for each work item
- return workInInfiniteRoutines(ctx, maxBatchSize, fetch, worker)
+ return workInInfiniteRoutines(ctx, maxBatchSize, fetchWithPanicHandling, workWithPanicHandling)

} else {
// Otherwise there's a cap on concurrency, so we need to use channels to communicate between the fetcher and the workers
- return workInWorkPool(ctx, maxConcurrency, maxBatchSize, fetch, worker)
+ return workInWorkPool(ctx, maxConcurrency, maxBatchSize, fetchWithPanicHandling, workWithPanicHandling)
}
}

@@ -139,7 +159,12 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
workProcessor := func(work Work) {
inFlight.Add(1)
defer inFlight.Add(-1)
- defer func() { workDone <- struct{}{} }()
+ defer func() {
+ select {
+ case workDone <- struct{}{}:
+ case <-fetchCtx.Done():
+ }
+ }()

// We use the parent context here, such that if we have a fetch error, the existing workers will
// continue to run until they finish processing any work they have already started on
@@ -150,14 +175,17 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS

// fetchProcessor is a small wrapper around the fetcher function that passes the fetched work to the workers
// it will fetch up to maxConcurrency items at a time in batches of maxBatchSize items
- var lastFetch time.Time
+ var lastFetch atomic.Pointer[time.Time]
+ var epoch time.Time
+ lastFetch.Store(&epoch)

var debounceTimer *time.Timer
var fetchLock sync.Mutex
fetchProcessor := func() {
// Lock the fetcher so that we don't have multiple fetchers running at the same time
fetchLock.Lock()
defer fetchLock.Unlock()
- defer func() { lastFetch = time.Now() }()
+ defer func() { now := time.Now(); lastFetch.Store(&now) }()

// Work out how many items we need to fetch
need := maxConcurrency - int(inFlight.Load())
@@ -179,7 +207,12 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS

// Pass work to workers
for _, w := range work {
- workChan <- w
+ select {
+ case workChan <- w:
+ // success, we passed the work to a worker
+ case <-fetchCtx.Done():
+ return
+ }
}

// Update the number of items we need to fetch
@@ -191,16 +224,8 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
// Start the workers
for i := 0; i < maxConcurrency; i++ {
go func() {
- for {
- select {
- case work, more := <-workChan:
- if !more {
- // the workChan has been closed, we can stop
- return
- }
-
- workProcessor(work)
- }
+ for work := range workChan {
+ workProcessor(work)
}
}()
}
@@ -210,20 +235,17 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
workDone <- struct{}{}

// Start fetching work
- fetchLoop:
- for {
+ for fetchCtx.Err() == nil {
select {
case <-fetchCtx.Done():
// If the context is cancelled, we need to stop fetching work
- break fetchLoop

case <-workDone:
if debounceTimer != nil {
debounceTimer.Stop()
debounceTimer = nil
}

- if time.Since(lastFetch) > maxFetchDebounce {
+ if time.Since(*lastFetch.Load()) > maxFetchDebounce {
fetchProcessor()
} else {
debounceTimer = time.AfterFunc(workFetchDebounce, fetchProcessor)
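The fetchWithPanicHandling and workWithPanicHandling closures above turn a panic in a user-supplied callback into an ordinary error, so a panicking worker can no longer strand the pool. A stripped-down sketch of the same idea, using fmt.Errorf in place of Encore's errs builder and a hypothetical wrapPanics helper:

package main

import "fmt"

// wrapPanics runs fn and turns any panic into a returned error, the same idea
// as the fetchWithPanicHandling and workWithPanicHandling closures above.
func wrapPanics(fn func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic: %v", r)
		}
	}()
	return fn()
}

func main() {
	err := wrapPanics(func() error {
		panic("boom")
	})
	fmt.Println(err) // prints: panic: boom
}

The load test below exercises this path: it deliberately panics every 250th work item and still expects all 20,000 items to be processed.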
76 changes: 76 additions & 0 deletions runtime/pubsub/internal/utils/workers_test.go
@@ -175,6 +175,9 @@ func TestWorkConcurrently(t *testing.T) {

err := WorkConcurrently(ctx, tt.concurrency, tt.maxBatchSize, fetcher, processor)

workMu.Lock()
defer workMu.Unlock()

// Run assertions on the exit conditions
c.Assert(timeoutCtx.Err(), qt.IsNil, qt.Commentf("test timed out - not all work was fetched within the timeout"))
switch {
@@ -220,3 +223,76 @@ func TestWorkConcurrently(t *testing.T) {
})
}
}

func TestWorkConcurrentlyLoad(t *testing.T) {
// t.Skipped()

const load = 20_000
msg := make([]string, load)
for i := 0; i < load; i++ {
msg[i] = fmt.Sprintf("msg %d", i)
}

var (
mu sync.Mutex
idx int

wMu sync.Mutex
wrk []string
)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

var err error

for ctx.Err() == nil {
err = WorkConcurrently(ctx, 25, 10, func(ctx context.Context, maxToFetch int) ([]string, error) {
mu.Lock()
defer mu.Unlock()

if idx >= load {
return nil, nil
}

toFetch := maxToFetch
if toFetch > load-idx {
toFetch = load - idx
}
if toFetch == 0 {
return nil, nil
}

rtn := make([]string, toFetch)
copy(rtn, msg[idx:idx+toFetch])
idx += toFetch

return rtn, nil
}, func(ctx context.Context, work string) error {
time.Sleep(100 * time.Millisecond)
wMu.Lock()
defer wMu.Unlock()

wrk = append(wrk, work)

if len(wrk)%250 == 0 {
panic("too much work")
}

if len(wrk) == load {
cancel()
}
return nil
})

fmt.Printf("err (worked %d/%d - sent %d): %v\n", len(wrk), load, idx, err)
}

if err != nil {
t.Fatalf("err (worked %d/%d - sent %d): %v", len(wrk), load, idx, err)
}

if len(wrk) != load {
t.Fatalf("expected %d work items, got %d", load, len(wrk))
}
}
