Fix possible deadlock in AWS pubsub #804

Merged (1 commit) on Jul 19, 2023
46 changes: 40 additions & 6 deletions runtime/pubsub/internal/aws/topic.go
@@ -87,12 +87,28 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
}

go func() {
defer func() {
if r := recover(); r != nil {
logger.Error().Interface("panic", r).Msg("panic in subscriber, no longer processing messages")
} else {
logger.Info().Msg("subscriber stopped due to context cancellation")
}
}()

for t.ctx.Err() == nil {
err := utils.WorkConcurrently(
t.ctx,
maxConcurrency, 10,
func(ctx context.Context, maxToFetch int) ([]sqsTypes.Message, error) {
resp, err := t.sqsClient.ReceiveMessage(t.ctx, &sqs.ReceiveMessageInput{
// We should only long poll for 20 seconds, so if this takes more than
// 30 seconds we should cancel the context and try again
//
// We do this incase the ReceiveMessage call gets stuck on the server
// and doesn't return
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
Comment on lines +103 to +109

Contributor Author:

One theory I have is that the AWS library might be stalling and blocking under high load, so I've introduced a smaller timeout to try to force a context-cancelled error.
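For illustration only, a minimal sketch of that pattern (not the PR's exact code; the 20-second long poll comes from the comment above, and setting it via WaitTimeSeconds here is an assumption):

package awsexample

import (
    "context"
    "time"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/sqs"
)

// receiveOnce long-polls SQS for up to 20 seconds, but cancels the whole call
// after 30 seconds in case the request stalls on the server and never returns.
func receiveOnce(parent context.Context, client *sqs.Client, queueURL string) (*sqs.ReceiveMessageOutput, error) {
    ctx, cancel := context.WithTimeout(parent, 30*time.Second)
    defer cancel()

    return client.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
        QueueUrl:            aws.String(queueURL),
        MaxNumberOfMessages: 10,
        WaitTimeSeconds:     20, // server-side long poll, kept shorter than the client-side guard
    })
}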


resp, err := t.sqsClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
QueueUrl: aws.String(implCfg.ProviderName),
AttributeNames: []sqsTypes.QueueAttributeName{"ApproximateReceiveCount"},
MaxNumberOfMessages: int32(maxToFetch),
@@ -132,33 +148,51 @@ func (t *topic) Subscribe(logger *zerolog.Logger, maxConcurrency int, ackDeadlin
}

// Call the callback, and if there was no error, then we can delete the message
msgCtx, cancel := context.WithTimeout(t.ctx, ackDeadline)
msgCtx, cancel := context.WithTimeout(ctx, ackDeadline)
defer cancel()
err = f(msgCtx, msgWrapper.MessageId, msgWrapper.Timestamp, int(deliveryAttempt), attributes, []byte(msgWrapper.Message))
cancel()

// Check if the context has been cancelled, and if so, return the error
if ctx.Err() != nil {
return ctx.Err()
Member:

Why? This might hide other errors from the func if it fails for some other reason.

}

// We want to wait a maximum of 30 seconds for the callback to complete
// otherwise we assume it has failed and we should retry
//
// We do this incase the callback gets stuck and doesn't return
ctx, responseCancel := context.WithTimeout(ctx, 30*time.Second)
defer responseCancel()

if err != nil {
logger.Err(err).Str("msg_id", msgWrapper.MessageId).Msg("unable to process message")

// If there was an error processing the message, apply the backoff policy
_, delay := utils.GetDelay(retryPolicy.MaxRetries, retryPolicy.MinBackoff, retryPolicy.MaxBackoff, uint16(deliveryAttempt))
_, visibilityChangeErr := t.sqsClient.ChangeMessageVisibility(t.ctx, &sqs.ChangeMessageVisibilityInput{
_, visibilityChangeErr := t.sqsClient.ChangeMessageVisibility(ctx, &sqs.ChangeMessageVisibilityInput{
Member:

Shouldn't we do this with a context not derived from the input, so we do this even if the input context is canceled?

Contributor Author:

No, I deliberately wanted both of these API calls to be released when the fetch context is cancelled, as we immediately go into a loop to fetch again.

I think them being based on t.ctx, rather than the fetch context, could have been an issue: we never cancel t.ctx, but the fetchCtx is cancelled when we want to exit the WorkConcurrently code.
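As a sketch of the context tree being described (hypothetical names, not the PR's code), the SDK calls hang off the per-run fetch context rather than the long-lived topic context, so cancelling the run releases them:

package awsexample

import (
    "context"
    "time"
)

// runOnce illustrates the derivation: topicCtx lives for the whole subscription,
// fetchCtx lives for one WorkConcurrently run, and each SDK call gets a short
// timeout layered on top of fetchCtx.
func runOnce(topicCtx context.Context, callSQS func(context.Context) error) error {
    fetchCtx, cancelFetch := context.WithCancel(topicCtx)
    defer cancelFetch() // cancelling this unwinds every in-flight call below

    callCtx, cancelCall := context.WithTimeout(fetchCtx, 30*time.Second)
    defer cancelCall()

    // Had the call been derived from topicCtx instead, cancelFetch would not
    // interrupt it, which is the situation described above.
    return callSQS(callCtx)
}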

QueueUrl: aws.String(implCfg.ProviderName),
ReceiptHandle: msg.ReceiptHandle,
VisibilityTimeout: int32(utils.Clamp(delay, time.Second, 12*time.Hour).Seconds()),
})
if visibilityChangeErr != nil {
log.Warn().Err(visibilityChangeErr).Str("msg_id", msgWrapper.MessageId).Msg("unable to change message visibility to apply backoff rules")
}
}
if err == nil {
_, err = t.sqsClient.DeleteMessage(t.ctx, &sqs.DeleteMessageInput{
} else {
// If the message was processed successfully, delete it from the queue
_, err = t.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{
Member:

Same here.

QueueUrl: aws.String(implCfg.ProviderName),
ReceiptHandle: msg.ReceiptHandle,
})
if err != nil {
logger.Err(err).Str("msg_id", msgWrapper.MessageId).Msg("unable to delete message from SQS queue")
}
}

return nil
},
)

if err != nil && t.ctx.Err() == nil {
logger.Warn().Err(err).Msg("pubsub subscription failed, retrying in 5 seconds")
time.Sleep(5 * time.Second)
66 changes: 44 additions & 22 deletions runtime/pubsub/internal/utils/workers.go
@@ -6,6 +6,8 @@ import (
"sync"
"sync/atomic"
"time"

"encore.dev/beta/errs"
)

const (
@@ -42,19 +44,37 @@ type WorkProcessor[Work any] func(ctx context.Context, work Work) error
// this again you could end up with 2x maxConcurrency workers running at the same time. (1x from the original run who
// are still processing work and 1x from the new run).
func WorkConcurrently[Work any](ctx context.Context, maxConcurrency int, maxBatchSize int, fetch WorkFetcher[Work], worker WorkProcessor[Work]) error {
fetchWithPanicHandling := func(ctx context.Context, maxToFetch int) (work []Work, err error) {
defer func() {
if r := recover(); r != nil {
err = errs.B().Msgf("panic: %v", r).Err()
Member:

Can we include the stack like we do elsewhere?

Contributor Author:

errs.B() will build in a stack, no?
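For reference, a generic recover wrapper that attaches the stack explicitly, sketched with plain errors and runtime/debug rather than the errs package (only relevant if errs.B() does not already capture it, which is an assumption):

package workers

import (
    "context"
    "fmt"
    "runtime/debug"
)

// wrapPanic converts a panic inside fn into an error carrying the stack trace.
func wrapPanic(fn func(ctx context.Context) error) func(ctx context.Context) error {
    return func(ctx context.Context) (err error) {
        defer func() {
            if r := recover(); r != nil {
                err = fmt.Errorf("panic: %v\n%s", r, debug.Stack())
            }
        }()
        return fn(ctx)
    }
}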

}
}()
return fetch(ctx, maxToFetch)
}

workWithPanicHandling := func(ctx context.Context, work Work) (err error) {
defer func() {
if r := recover(); r != nil {
err = errs.B().Msgf("panic: %v", r).Err()
DomBlack marked this conversation as resolved.
}
}()
return worker(ctx, work)
}

if maxConcurrency == 1 {
// If there's no concurrency, we can just do everything synchronously within this goroutine
// This avoids the overhead of creating mutexes being used
return workInSingleRoutine(ctx, fetch, worker)
return workInSingleRoutine(ctx, fetchWithPanicHandling, workWithPanicHandling)

} else if maxConcurrency <= 0 {
// If there's infinite concurrency, we can just do everything by spawning goroutines
// for each work item
return workInInfiniteRoutines(ctx, maxBatchSize, fetch, worker)
return workInInfiniteRoutines(ctx, maxBatchSize, fetchWithPanicHandling, workWithPanicHandling)

} else {
// Else there's a cap on concurrency, we need to use channels to communicate between the fetcher and the workers
return workInWorkPool(ctx, maxConcurrency, maxBatchSize, fetch, worker)
return workInWorkPool(ctx, maxConcurrency, maxBatchSize, fetchWithPanicHandling, workWithPanicHandling)
}
}

@@ -139,7 +159,12 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
workProcessor := func(work Work) {
inFlight.Add(1)
defer inFlight.Add(-1)
defer func() { workDone <- struct{}{} }()
defer func() {
select {
case workDone <- struct{}{}:
case <-fetchCtx.Done():
}
}()

// We use the parent context here, such that if we have a fetch error, the existing workers will
// continue to run until they finish processing any work they have already started on
@@ -150,14 +175,17 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS

// fetchProcessor is a small wrapper around the fetcher function that passes the fetched work to the workers
// it will fetch upto maxConcurrency items at a time in batches of maxBatchSize items
var lastFetch time.Time
var lastFetch atomic.Pointer[time.Time]
var epoch time.Time
lastFetch.Store(&epoch)
Comment on lines +178 to +180

Contributor Author:

There was a race on this, so I've moved it into an atomic.
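The race fix in isolation, as a minimal sketch (helper names hypothetical): reads and writes of the shared timestamp go through an atomic.Pointer instead of a plain variable.

package workers

import (
    "sync/atomic"
    "time"
)

// lastFetchClock records when work was last fetched, safely across goroutines.
type lastFetchClock struct {
    at atomic.Pointer[time.Time]
}

func newLastFetchClock() *lastFetchClock {
    c := &lastFetchClock{}
    epoch := time.Time{} // zero time, so the first debounce check always passes
    c.at.Store(&epoch)
    return c
}

// mark records the current time as the most recent fetch.
func (c *lastFetchClock) mark() {
    now := time.Now()
    c.at.Store(&now)
}

// olderThan reports whether the last fetch happened more than d ago.
func (c *lastFetchClock) olderThan(d time.Duration) bool {
    return time.Since(*c.at.Load()) > d
}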


var debounceTimer *time.Timer
var fetchLock sync.Mutex
fetchProcessor := func() {
// Lock the fetcher so that we don't have multiple fetchers running at the same time
fetchLock.Lock()
defer fetchLock.Unlock()
defer func() { lastFetch = time.Now() }()
defer func() { now := time.Now(); lastFetch.Store(&now) }()

// Work out how many items we need to fetch
need := maxConcurrency - int(inFlight.Load())
@@ -179,7 +207,12 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS

// Pass work to workers
for _, w := range work {
workChan <- w
select {
case workChan <- w:
// success, we passed the work to a worker
case <-fetchCtx.Done():
return
Comment on lines +213 to +214

Contributor Author:

I think this was the deadlock:

If all the workers panicked and quit, that could have resulted in the fetch processor trying to write onto workChan with nothing left to read from it.

I've solved this by adding this new select, so if the ctx is done we don't even try to push to the channel, and I've also added additional panic recovery wrappers at other points in the code.
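A minimal, self-contained sketch of the failure mode and the fix (names hypothetical): a bare send on an unbuffered channel blocks forever once every receiver has exited, so the send is raced against ctx.Done().

package workers

import "context"

// dispatch hands work items to a pool of receivers reading from workChan.
// A bare send (workChan <- w) never returns if all receivers have already
// exited (for example after panicking), because nothing will read from the
// channel again. Selecting on ctx.Done() bounds the wait.
func dispatch[Work any](ctx context.Context, workChan chan<- Work, items []Work) error {
    for _, w := range items {
        select {
        case workChan <- w:
            // a worker accepted the item
        case <-ctx.Done():
            // no worker is available and the run has been cancelled
            return ctx.Err()
        }
    }
    return nil
}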

}
}

// Update the number of items we need to fetch
@@ -191,16 +224,8 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
// Start the workers
for i := 0; i < maxConcurrency; i++ {
go func() {
for {
select {
case work, more := <-workChan:
if !more {
// the workChan has been closed, we can stop
return
}

workProcessor(work)
}
for work := range workChan {
workProcessor(work)
}
}()
}
@@ -210,20 +235,17 @@ func workInWorkPool[Work any](ctx context.Context, maxConcurrency int, maxBatchS
workDone <- struct{}{}

// Start fetching work
fetchLoop:
for {
for fetchCtx.Err() == nil {
select {
case <-fetchCtx.Done():
// If the context is cancelled, we need to stop fetching work
break fetchLoop

case <-workDone:
if debounceTimer != nil {
debounceTimer.Stop()
debounceTimer = nil
}

if time.Since(lastFetch) > maxFetchDebounce {
if time.Since(*lastFetch.Load()) > maxFetchDebounce {
fetchProcessor()
} else {
debounceTimer = time.AfterFunc(workFetchDebounce, fetchProcessor)
76 changes: 76 additions & 0 deletions runtime/pubsub/internal/utils/workers_test.go
@@ -175,6 +175,9 @@ func TestWorkConcurrently(t *testing.T) {

err := WorkConcurrently(ctx, tt.concurrency, tt.maxBatchSize, fetcher, processor)

workMu.Lock()
defer workMu.Unlock()

// Run assertions on the exit conditions
c.Assert(timeoutCtx.Err(), qt.IsNil, qt.Commentf("test timed out - not all work was fetched within the timeout"))
switch {
@@ -220,3 +223,76 @@ func TestWorkConcurrently(t *testing.T) {
})
}
}

func TestWorkConcurrentlyLoad(t *testing.T) {
t.Skipped()

const load = 20_000
msg := make([]string, load)
for i := 0; i < load; i++ {
msg[i] = fmt.Sprintf("msg %d", i)
}

var (
mu sync.Mutex
idx int

wMu sync.Mutex
wrk []string
)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

var err error

for ctx.Err() == nil {
err = WorkConcurrently(ctx, 25, 10, func(ctx context.Context, maxToFetch int) ([]string, error) {
mu.Lock()
defer mu.Unlock()

if idx >= load {
return nil, nil
}

toFetch := maxToFetch
if toFetch > load-idx {
toFetch = load - idx
}
if toFetch == 0 {
return nil, nil
}

rtn := make([]string, toFetch)
copy(rtn, msg[idx:idx+toFetch])
idx += toFetch

return rtn, nil
}, func(ctx context.Context, work string) error {
time.Sleep(100 * time.Millisecond)
wMu.Lock()
defer wMu.Unlock()

wrk = append(wrk, work)

if len(wrk)%250 == 0 {
panic("too much work")
}

if len(wrk) == load {
cancel()
}
return nil
})

fmt.Printf("err (worked %d/%d - sent %d): %v\n", len(wrk), load, idx, err)
}

if err != nil {
t.Fatalf("err (worked %d/%d - sent %d): %v", len(wrk), load, idx, err)
}

if len(wrk) != load {
t.Fatalf("expected %d work items, got %d", load, len(wrk))
}
}