Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CV2-4606 Add retry logic and dead letter message dumping #82

Merged
merged 2 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions lib/queue/queue.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pdb
import json
from typing import List, Dict, Tuple
import os
Expand All @@ -8,7 +9,10 @@
from lib.helpers import get_environment_setting
from lib.logger import logger
from lib import schemas
SQS_MAX_BATCH_SIZE = 10

SQS_MAX_BATCH_SIZE = int(os.getenv("SQS_MAX_BATCH_SIZE", "10"))
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))

class Queue:
def __init__(self):
"""
Expand All @@ -34,6 +38,11 @@ def get_output_queue_name(model_name=None):
name = model_name or get_environment_setting("MODEL_NAME").replace(".", "__")
return Queue.get_queue_prefix()+name+"_output"+Queue.get_queue_suffix()

@staticmethod
def get_dead_letter_queue_name(model_name=None):
name = model_name or get_environment_setting("MODEL_NAME").replace(".", "__")
return Queue.get_queue_prefix()+name+"_dlq"+Queue.get_queue_suffix()

def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]) -> Dict[str, boto3.resources.base.ServiceResource]:
"""
Store a quick lookup so that we dont loop through this over and over in other places.
Expand All @@ -42,13 +51,13 @@ def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]
for queue in all_queues:
queue_map[self.queue_name(queue)] = queue
return queue_map

def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str:
"""
Pull queue name from a given queue
"""
return queue.url.split('/')[-1]

def restrict_queues_to_suffix(self, queues: List[boto3.resources.base.ServiceResource], suffix: str) -> List[boto3.resources.base.ServiceResource]:
"""
When plucking input queues, we want to omit any queues that are our paired suffix queues..
Expand Down Expand Up @@ -122,7 +131,7 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto
queue_to_messages[queue] = []
queue_to_messages[queue].append(message)
return queue_to_messages

def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None:
"""
Delete messages as we process them so other processes don't pick them up.
Expand All @@ -132,7 +141,6 @@ def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto
for queue, messages in self.group_deletions(messages_with_queues).items():
self.delete_messages_from_queue(queue, messages)


def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource, messages: List[schemas.Message]) -> None:
"""
Helper function to delete a group of messages from a specific queue.
Expand Down Expand Up @@ -170,7 +178,7 @@ def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceRes
Search through queues to find the right one
"""
return self.all_queues.get(queue_name)

def push_message(self, queue_name: str, message: schemas.Message) -> schemas.Message:
"""
Actual SQS logic for pushing a message to a queue
Expand All @@ -180,3 +188,10 @@ def push_message(self, queue_name: str, message: schemas.Message) -> schemas.Mes
message_data["MessageGroupId"] = message.body.id
self.find_queue_by_name(queue_name).send_message(**message_data)
return message

def push_to_dead_letter_queue(self, message: schemas.Message):
"""
Push a message to the dead letter queue.
"""
dlq_name = Queue.get_dead_letter_queue_name()
self.push_message(dlq_name, message)
44 changes: 36 additions & 8 deletions lib/queue/worker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import pdb
import os
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import json
from typing import List, Tuple
from lib import schemas
from lib.logger import logger
from lib.queue.queue import Queue
from lib.queue.queue import Queue, MAX_RETRIES
from lib.model.model import Model
from lib.sentry import capture_custom_message

TIMEOUT_SECONDS = int(os.getenv("WORK_TIMEOUT_SECONDS", "60"))

class QueueWorker(Queue):
@classmethod
def create(cls, model_name: str = None):
Expand All @@ -19,18 +23,20 @@ def create(cls, model_name: str = None):
logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}')")
return QueueWorker(input_queue_name, output_queue_name)

def __init__(self, input_queue_name: str, output_queue_name: str = None):
def __init__(self, input_queue_name: str, output_queue_name: str = None, dlq_queue_name: str = None):
"""
Start a specific queue - must pass input_queue_name - optionally pass output_queue_name.
Start a specific queue - must pass input_queue_name, optionally pass output_queue_name, dlq_queue_name.
"""
super().__init__()
self.input_queue_name = input_queue_name
self.output_queue_name = output_queue_name or Queue.get_output_queue_name()
self.dlq_queue_name = dlq_queue_name or Queue.get_dead_letter_queue_name()
q_suffix = f"_output" + Queue.get_queue_suffix()
self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix)
if output_queue_name:
self.output_queue_name = Queue.get_output_queue_name()
self.output_queues = self.get_or_create_queues(output_queue_name)
self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues] for item in row])
dlq_suffix = f"_dlq" + Queue.get_queue_suffix()
self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix)
self.output_queues = self.get_or_create_queues(self.output_queue_name)
self.dead_letter_queues = self.get_or_create_queues(self.dlq_queue_name)
self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues, self.dead_letter_queues] for item in row])
logger.info(f"Worker listening to queues of {self.all_queues}")

def process(self, model: Model):
Expand All @@ -57,6 +63,8 @@ def safely_respond(self, model: Model) -> List[schemas.Message]:
responses, success = self.execute_with_timeout(model.respond, messages, timeout_seconds=TIMEOUT_SECONDS)
if success:
self.delete_processed_messages(messages_with_queues)
else:
self.increment_message_error_counts(messages_with_queues) # Add the new functionality here
return responses

@staticmethod
Expand Down Expand Up @@ -117,3 +125,23 @@ def delete_processed_messages(self, messages_with_queues: List[Tuple]):
- messages_with_queues (List[Tuple]): A list of tuples, each containing a message and its corresponding queue, to be deleted.
"""
self.delete_messages(messages_with_queues)

def increment_message_error_counts(self, messages_with_queues: List[Tuple]):
"""
Increment the error count for messages and push them back to the queue or to the dead letter queue if retries exceed the limit.

Parameters:
- messages_with_queues (List[Tuple]): A list of tuples, each containing a message and its corresponding queue.
"""
for message, queue in messages_with_queues:
message_body = json.loads(message.body)
retry_count = message_body.get('retry_count', 0) + 1

if retry_count > MAX_RETRIES:
logger.info(f"Message {message_body} exceeded max retries. Moving to DLQ.")
computermacgyver marked this conversation as resolved.
Show resolved Hide resolved
capture_custom_message("Message exceeded max retries. Moving to DLQ.", 'info', {"message_body": message_body})
self.push_to_dead_letter_queue(schemas.parse_message(message_body))
else:
message_body['retry_count'] = retry_count
updated_message = schemas.parse_message(message_body)
self.push_message(self.input_queue_name, updated_message)
7 changes: 7 additions & 0 deletions lib/sentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,10 @@
environment=get_environment_setting("DEPLOY_ENV"),
traces_sample_rate=1.0,
)

def capture_custom_message(message, level='info', extra=None):
    """
    Send a message event to Sentry with optional extra context.

    Parameters:
    - message (str): The message to record in Sentry.
    - level (str): Sentry severity level (e.g. 'info', 'warning', 'error').
    - extra (dict | None): Optional key/value context attached to this event only.
    """
    # Use push_scope so the extras apply only to this one event. configure_scope
    # mutates the ambient scope, which would leak these extras into every
    # subsequent event reported by this process.
    with sentry_sdk.push_scope() as scope:
        if extra:
            for key, value in extra.items():
                scope.set_extra(key, value)
        sentry_sdk.capture_message(message, level=level)
12 changes: 6 additions & 6 deletions test/lib/queue/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from lib import schemas
from test.lib.queue.fake_sqs_message import FakeSQSMessage
class TestQueueProcessor(unittest.TestCase):

@patch('lib.queue.queue.boto3.resource')
@patch('lib.helpers.get_environment_setting', return_value='us-west-1')
def setUp(self, mock_get_env_setting, mock_boto_resource):
self.queue_name_input = 'mean_tokens__Model'

# Mock the SQS resource and the queues
self.mock_sqs_resource = MagicMock()
self.mock_input_queue = MagicMock()
self.mock_input_queue.url = "http://queue/mean_tokens__Model"
self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue]
mock_boto_resource.return_value = self.mock_sqs_resource

# Initialize the QueueProcessor instance
self.queue_processor = QueueProcessor(self.queue_name_input, batch_size=2)

Expand All @@ -29,9 +29,9 @@ def test_send_callbacks(self):
)
self.queue_processor.send_callback = MagicMock(return_value=None)
self.queue_processor.delete_messages = MagicMock(return_value=None)

responses = self.queue_processor.send_callbacks()

self.queue_processor.receive_messages.assert_called_once_with(2)
self.queue_processor.send_callback.assert_called()
self.queue_processor.delete_messages.assert_called()
Expand All @@ -49,6 +49,6 @@ def test_send_callback_failure(self, mock_post):
with self.assertLogs(level='ERROR') as cm:
self.queue_processor.send_callback(message_body)
self.assertIn("Failed with Request Failed! on http://example.com with message of", cm.output[0])

if __name__ == '__main__':
unittest.main()
60 changes: 54 additions & 6 deletions test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue
self.mock_model = MagicMock()
self.queue_name_input = Queue.get_input_queue_name()
self.queue_name_output = Queue.get_output_queue_name()
self.queue_name_dlq = Queue.get_dead_letter_queue_name()

# Mock the SQS resource and the queues
self.mock_sqs_resource = MagicMock()
Expand All @@ -29,15 +30,21 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue
self.mock_output_queue = MagicMock()
self.mock_output_queue.url = f"http://queue/{self.queue_name_output}"
self.mock_output_queue.attributes = {"QueueArn": f"queue:{self.queue_name_output}"}
self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue]
self.mock_dlq_queue = MagicMock()
self.mock_dlq_queue.url = f"http://queue/{self.queue_name_dlq}"
self.mock_dlq_queue.attributes = {"QueueArn": f"queue:{self.queue_name_dlq}"}
self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue, self.mock_dlq_queue]
mock_boto_resource.return_value = self.mock_sqs_resource

# Initialize the SQSQueue instance
self.queue = QueueWorker(self.queue_name_input, self.queue_name_output)
self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq)

def test_get_output_queue_name(self):
self.assertEqual(self.queue.get_output_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_output').replace(".fifo", ""))

def test_get_dead_letter_queue_name(self):
self.assertEqual(self.queue.get_dead_letter_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_dlq').replace(".fifo", ""))

def test_process(self):
self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), self.mock_input_queue)])
self.queue.input_queue = MagicMock(return_value=None)
Expand All @@ -55,7 +62,7 @@ def test_receive_messages(self):
mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))]
self.queue.input_queues = [mock_queue1, mock_queue2]
received_messages = self.queue.receive_messages(5)

# Check if the right number of messages were received and the content is correct
self.assertEqual(len(received_messages), 2)
self.assertIn("a test", received_messages[0][0].body)
Expand Down Expand Up @@ -109,6 +116,47 @@ def test_push_message(self):
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": 1, "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}, "parameters": {}, "result": {"hash_value": null}}, "model_name": "mean_tokens__Model"}')
self.assertEqual(returned_message, message_to_push)

def test_push_to_dead_letter_queue(self):
message_to_push = schemas.parse_message({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"})
# Call push_to_dead_letter_queue
self.queue.push_to_dead_letter_queue(message_to_push)
# Check if the message was correctly serialized and sent to the DLQ
self.mock_dlq_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": 1, "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}, "parameters": {}, "result": {"hash_value": null}}, "model_name": "mean_tokens__Model"}')

def test_increment_message_error_counts_exceed_max_retries(self):
message_body = {
"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"},
"retry_count": 5, # Already at max retries
"model_name": "mean_tokens__Model"
}
fake_message = FakeSQSMessage(receipt_handle="blah", body=json.dumps(message_body))
messages_with_queues = [(fake_message, self.mock_input_queue)]

self.queue.push_to_dead_letter_queue = MagicMock()
self.queue.push_message = MagicMock()

self.queue.increment_message_error_counts(messages_with_queues)

self.queue.push_to_dead_letter_queue.assert_called_once()
self.queue.push_message.assert_not_called()

def test_increment_message_error_counts_increment(self):
message_body = {
"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"},
"retry_count": 2, # Less than max retries
"model_name": "mean_tokens__Model"
}
fake_message = FakeSQSMessage(receipt_handle="blah", body=json.dumps(message_body))
messages_with_queues = [(fake_message, self.mock_input_queue)]

self.queue.push_to_dead_letter_queue = MagicMock()
self.queue.push_message = MagicMock()

self.queue.increment_message_error_counts(messages_with_queues)

self.queue.push_to_dead_letter_queue.assert_not_called()
self.queue.push_message.assert_called_once()

def test_extract_messages(self):
messages_with_queues = [
(FakeSQSMessage(receipt_handle="blah", body=json.dumps({
Expand All @@ -131,7 +179,7 @@ def test_extract_messages(self):
def test_execute_with_timeout_success(self, mock_log_error):
def test_func(args):
return ["response"]

responses, success = QueueWorker.execute_with_timeout(test_func, [], timeout_seconds=1)
self.assertEqual(responses, ["response"])
self.assertTrue(success)
Expand All @@ -141,7 +189,7 @@ def test_func(args):
def test_execute_with_timeout_failure(self, mock_log_error):
def test_func(args):
raise TimeoutError

responses, success = QueueWorker.execute_with_timeout(test_func, [], timeout_seconds=1)
self.assertEqual(responses, [])
self.assertFalse(success)
Expand Down
Loading