diff --git a/lib/queue/queue.py b/lib/queue/queue.py index 6409562..c1c99bd 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -8,7 +8,10 @@ from lib.helpers import get_environment_setting from lib.logger import logger from lib import schemas -SQS_MAX_BATCH_SIZE = 10 + +SQS_MAX_BATCH_SIZE = int(os.getenv("SQS_MAX_BATCH_SIZE", "10")) +MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5")) + class Queue: def __init__(self): """ @@ -34,6 +37,11 @@ def get_output_queue_name(model_name=None): name = model_name or get_environment_setting("MODEL_NAME").replace(".", "__") return Queue.get_queue_prefix()+name+"_output"+Queue.get_queue_suffix() + @staticmethod + def get_dead_letter_queue_name(model_name=None): + name = model_name or get_environment_setting("MODEL_NAME").replace(".", "__") + return Queue.get_queue_prefix()+name+"_dlq"+Queue.get_queue_suffix() + def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]) -> Dict[str, boto3.resources.base.ServiceResource]: """ Store a quick lookup so that we dont loop through this over and over in other places. @@ -42,13 +50,13 @@ def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource] for queue in all_queues: queue_map[self.queue_name(queue)] = queue return queue_map - + def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str: """ Pull queue name from a given queue """ return queue.url.split('/')[-1] - + def restrict_queues_to_suffix(self, queues: List[boto3.resources.base.ServiceResource], suffix: str) -> List[boto3.resources.base.ServiceResource]: """ When plucking input queues, we want to omit any queues that are our paired suffix queues.. 
@@ -122,7 +130,7 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto queue_to_messages[queue] = [] queue_to_messages[queue].append(message) return queue_to_messages - + def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None: """ Delete messages as we process them so other processes don't pick them up. @@ -132,7 +140,6 @@ def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto for queue, messages in self.group_deletions(messages_with_queues).items(): self.delete_messages_from_queue(queue, messages) - def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource, messages: List[schemas.Message]) -> None: """ Helper function to delete a group of messages from a specific queue. @@ -170,7 +177,7 @@ def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceRes Search through queues to find the right one """ return self.all_queues.get(queue_name) - + def push_message(self, queue_name: str, message: schemas.Message) -> schemas.Message: """ Actual SQS logic for pushing a message to a queue @@ -180,3 +187,10 @@ def push_message(self, queue_name: str, message: schemas.Mes message_data["MessageGroupId"] = message.body.id self.find_queue_by_name(queue_name).send_message(**message_data) return message + + def push_to_dead_letter_queue(self, message: schemas.Message): + """ + Push a message to the dead letter queue. 
+ """ + dlq_name = Queue.get_dead_letter_queue_name() + self.push_message(dlq_name, message) \ No newline at end of file diff --git a/lib/queue/worker.py b/lib/queue/worker.py index f69ba65..d1ffd15 100644 --- a/lib/queue/worker.py +++ b/lib/queue/worker.py @@ -1,12 +1,15 @@ import os from concurrent.futures import ThreadPoolExecutor, TimeoutError import json from typing import List, Tuple from lib import schemas from lib.logger import logger -from lib.queue.queue import Queue +from lib.queue.queue import Queue, MAX_RETRIES from lib.model.model import Model +from lib.sentry import capture_custom_message + TIMEOUT_SECONDS = int(os.getenv("WORK_TIMEOUT_SECONDS", "60")) + class QueueWorker(Queue): @classmethod def create(cls, model_name: str = None): @@ -19,18 +22,20 @@ def create(cls, model_name: str = None): logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}')") return QueueWorker(input_queue_name, output_queue_name) - def __init__(self, input_queue_name: str, output_queue_name: str = None): + def __init__(self, input_queue_name: str, output_queue_name: str = None, dlq_queue_name: str = None): """ - Start a specific queue - must pass input_queue_name - optionally pass output_queue_name. + Start a specific queue - must pass input_queue_name, optionally pass output_queue_name, dlq_queue_name. 
""" super().__init__() self.input_queue_name = input_queue_name + self.output_queue_name = output_queue_name or Queue.get_output_queue_name() + self.dlq_queue_name = dlq_queue_name or Queue.get_dead_letter_queue_name() q_suffix = f"_output" + Queue.get_queue_suffix() - self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix) - if output_queue_name: - self.output_queue_name = Queue.get_output_queue_name() - self.output_queues = self.get_or_create_queues(output_queue_name) - self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues] for item in row]) + dlq_suffix = f"_dlq" + Queue.get_queue_suffix() + self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix) + self.output_queues = self.get_or_create_queues(self.output_queue_name) + self.dead_letter_queues = self.get_or_create_queues(self.dlq_queue_name) + self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues, self.dead_letter_queues] for item in row]) logger.info(f"Worker listening to queues of {self.all_queues}") def process(self, model: Model): @@ -57,6 +62,8 @@ def safely_respond(self, model: Model) -> List[schemas.Message]: responses, success = self.execute_with_timeout(model.respond, messages, timeout_seconds=TIMEOUT_SECONDS) if success: self.delete_processed_messages(messages_with_queues) + else: + self.increment_message_error_counts(messages_with_queues) return responses @staticmethod @@ -117,3 +124,23 @@ def delete_processed_messages(self, messages_with_queues: List[Tuple]): - messages_with_queues (List[Tuple]): A list of tuples, each containing a message and its corresponding queue, to be deleted. 
""" self.delete_messages(messages_with_queues) + + def increment_message_error_counts(self, messages_with_queues: List[Tuple]): + """ + Increment the error count for messages and push them back to the queue or to the dead letter queue if retries exceed the limit. + + Parameters: + - messages_with_queues (List[Tuple]): A list of tuples, each containing a message and its corresponding queue. + """ + for message, queue in messages_with_queues: + message_body = json.loads(message.body) + retry_count = message_body.get('retry_count', 0) + 1 + + if retry_count > MAX_RETRIES: + logger.info(f"Message {message_body} exceeded max retries. Moving to DLQ.") + capture_custom_message("Message exceeded max retries. Moving to DLQ.", 'info', {"message_body": message_body}) + self.push_to_dead_letter_queue(schemas.parse_message(message_body)) + else: + message_body['retry_count'] = retry_count + updated_message = schemas.parse_message(message_body) + self.push_message(self.input_queue_name, updated_message)  # NOTE(review): the original SQS message is not deleted here; it will reappear after its visibility timeout alongside this re-pushed copy — confirm deletion is handled upstream \ No newline at end of file diff --git a/lib/sentry.py b/lib/sentry.py index 8804a1e..09f733b 100644 --- a/lib/sentry.py +++ b/lib/sentry.py @@ -7,3 +7,10 @@ environment=get_environment_setting("DEPLOY_ENV"), traces_sample_rate=1.0, ) + +def capture_custom_message(message, level='info', extra=None): + with sentry_sdk.configure_scope() as scope:  # FIXME(review): configure_scope mutates the global scope, so these extras leak into every later event; use push_scope() with capture_message inside the block + if extra: + for key, value in extra.items(): + scope.set_extra(key, value) + sentry_sdk.capture_message(message, level=level) \ No newline at end of file diff --git a/test/lib/queue/test_processor.py b/test/lib/queue/test_processor.py index 567e4b4..9b5700f 100644 --- a/test/lib/queue/test_processor.py +++ b/test/lib/queue/test_processor.py @@ -6,19 +6,19 @@ from lib import schemas from test.lib.queue.fake_sqs_message import FakeSQSMessage class TestQueueProcessor(unittest.TestCase): - + @patch('lib.queue.queue.boto3.resource') @patch('lib.helpers.get_environment_setting', return_value='us-west-1') def setUp(self, mock_get_env_setting, 
mock_boto_resource): self.queue_name_input = 'mean_tokens__Model' - + # Mock the SQS resource and the queues self.mock_sqs_resource = MagicMock() self.mock_input_queue = MagicMock() self.mock_input_queue.url = "http://queue/mean_tokens__Model" self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue] mock_boto_resource.return_value = self.mock_sqs_resource - + # Initialize the QueueProcessor instance self.queue_processor = QueueProcessor(self.queue_name_input, batch_size=2) @@ -29,9 +29,9 @@ def test_send_callbacks(self): ) self.queue_processor.send_callback = MagicMock(return_value=None) self.queue_processor.delete_messages = MagicMock(return_value=None) - + responses = self.queue_processor.send_callbacks() - + self.queue_processor.receive_messages.assert_called_once_with(2) self.queue_processor.send_callback.assert_called() self.queue_processor.delete_messages.assert_called() @@ -49,6 +49,6 @@ def test_send_callback_failure(self, mock_post): with self.assertLogs(level='ERROR') as cm: self.queue_processor.send_callback(message_body) self.assertIn("Failed with Request Failed! 
on http://example.com with message of", cm.output[0]) - + if __name__ == '__main__': unittest.main() diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index 65911f4..57fb044 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -20,6 +20,7 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue self.mock_model = MagicMock() self.queue_name_input = Queue.get_input_queue_name() self.queue_name_output = Queue.get_output_queue_name() + self.queue_name_dlq = Queue.get_dead_letter_queue_name() # Mock the SQS resource and the queues self.mock_sqs_resource = MagicMock() @@ -29,15 +30,21 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue self.mock_output_queue = MagicMock() self.mock_output_queue.url = f"http://queue/{self.queue_name_output}" self.mock_output_queue.attributes = {"QueueArn": f"queue:{self.queue_name_output}"} - self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue] + self.mock_dlq_queue = MagicMock() + self.mock_dlq_queue.url = f"http://queue/{self.queue_name_dlq}" + self.mock_dlq_queue.attributes = {"QueueArn": f"queue:{self.queue_name_dlq}"} + self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue, self.mock_dlq_queue] mock_boto_resource.return_value = self.mock_sqs_resource # Initialize the SQSQueue instance - self.queue = QueueWorker(self.queue_name_input, self.queue_name_output) - + self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq) + def test_get_output_queue_name(self): self.assertEqual(self.queue.get_output_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_output').replace(".fifo", "")) + def test_get_dead_letter_queue_name(self): + self.assertEqual(self.queue.get_dead_letter_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_dlq').replace(".fifo", "")) + def 
test_process(self): self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), self.mock_input_queue)]) self.queue.input_queue = MagicMock(return_value=None) @@ -55,7 +62,7 @@ def test_receive_messages(self): mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))] self.queue.input_queues = [mock_queue1, mock_queue2] received_messages = self.queue.receive_messages(5) - + # Check if the right number of messages were received and the content is correct self.assertEqual(len(received_messages), 2) self.assertIn("a test", received_messages[0][0].body) @@ -109,6 +116,47 @@ def test_push_message(self): self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": 1, "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}, "parameters": {}, "result": {"hash_value": null}}, "model_name": "mean_tokens__Model"}') self.assertEqual(returned_message, message_to_push) + def test_push_to_dead_letter_queue(self): + message_to_push = schemas.parse_message({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"}) + # Call push_to_dead_letter_queue + self.queue.push_to_dead_letter_queue(message_to_push) + # Check if the message was correctly serialized and sent to the DLQ + self.mock_dlq_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": 1, "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}, "parameters": {}, "result": {"hash_value": null}}, "model_name": "mean_tokens__Model"}') + + def test_increment_message_error_counts_exceed_max_retries(self): + message_body = { + "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a 
test"}, + "retry_count": 5, # Already at max retries + "model_name": "mean_tokens__Model" + } + fake_message = FakeSQSMessage(receipt_handle="blah", body=json.dumps(message_body)) + messages_with_queues = [(fake_message, self.mock_input_queue)] + + self.queue.push_to_dead_letter_queue = MagicMock() + self.queue.push_message = MagicMock() + + self.queue.increment_message_error_counts(messages_with_queues) + + self.queue.push_to_dead_letter_queue.assert_called_once() + self.queue.push_message.assert_not_called() + + def test_increment_message_error_counts_increment(self): + message_body = { + "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, + "retry_count": 2, # Less than max retries + "model_name": "mean_tokens__Model" + } + fake_message = FakeSQSMessage(receipt_handle="blah", body=json.dumps(message_body)) + messages_with_queues = [(fake_message, self.mock_input_queue)] + + self.queue.push_to_dead_letter_queue = MagicMock() + self.queue.push_message = MagicMock() + + self.queue.increment_message_error_counts(messages_with_queues) + + self.queue.push_to_dead_letter_queue.assert_not_called() + self.queue.push_message.assert_called_once() + def test_extract_messages(self): messages_with_queues = [ (FakeSQSMessage(receipt_handle="blah", body=json.dumps({ @@ -131,7 +179,7 @@ def test_extract_messages(self): def test_execute_with_timeout_success(self, mock_log_error): def test_func(args): return ["response"] - + responses, success = QueueWorker.execute_with_timeout(test_func, [], timeout_seconds=1) self.assertEqual(responses, ["response"]) self.assertTrue(success) @@ -141,7 +189,7 @@ def test_func(args): def test_execute_with_timeout_failure(self, mock_log_error): def test_func(args): raise TimeoutError - + responses, success = QueueWorker.execute_with_timeout(test_func, [], timeout_seconds=1) self.assertEqual(responses, []) self.assertFalse(success)