Skip to content

Commit

Permalink
CV2-3551 simplify accessors for queues, allow for a specific queue to…
Browse files Browse the repository at this point in the history
… be selected by http endpoint in lieu of just plucking first
  • Loading branch information
DGaffney committed Aug 23, 2023
1 parent 95681d1 commit 400c332
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 20 deletions.
7 changes: 5 additions & 2 deletions lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from lib.queue.queue import Queue
from lib.logger import logger
from lib import schemas
# fingerprinter = "mean_tokens__Model"
# queue = Queue.create(fingerprinter, f"{fingerprinter}-output")
# queue.push_message(fingerprinter, schemas.Message(body={"id": 123, "callback_url": "http://example.com", "text": "Some Text"}, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
app = FastAPI()

@app.middleware("http")
Expand All @@ -28,8 +31,8 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:

@app.post("/fingerprint_item/{fingerprinter}")
def fingerprint_item(fingerprinter: str, message: Dict[str, Any]):
queue = Queue.create(fingerprinter, f"{fingerprinter}-output")
queue.push_message(queue.input_queues[0], schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
queue = Queue.create(fingerprinter, f"{fingerprinter}_output")
queue.push_message(fingerprinter, schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
return {"message": "Message pushed successfully"}

@app.post("/trigger_callback")
Expand Down
64 changes: 53 additions & 11 deletions lib/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,50 @@ def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_s
self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name)
self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output")
self.output_queues = self.get_or_create_queues(output_queue_name)
self.default_input_queue = self.input_queues[0]
self.default_output_queue = self.output_queues[0]
self.batch_size = batch_size

def restrict_queues_by_suffix(self, queues, suffix):
def get_queue_by_name(self, queue_list, queue_name):
"""
Provide access point for directly selecting the right queue by a name
"""
candidates = [e for e in queue_list if queue_name == e.name]
if candidates:
return candidates[0]

def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str:
"""
Pull queue name from a given queue
"""
return queue.url.split('/')[-1]

def restrict_queues_by_suffix(self, queues: List[boto3.resources.base.ServiceResource], suffix: str) -> List[boto3.resources.base.ServiceResource]:
"""
When plucking input queues, we want to omit any queues that are our paired suffix queues..
"""
return [queue for queue in queues if not queue.url.split('/')[-1].endswith(suffix)]
return [queue for queue in queues if not self.queue_name(queue).endswith(suffix)]

def get_or_create_queues(self, queue_name):
def create_queue(self, queue_name: str) -> boto3.resources.base.ServiceResource:
"""
Create queue by name - may not work in production owing to permissions - mostly a local convenience function
"""
logger.info(f"Queue {queue_name} doesn't exist - creating")
return self.sqs.create_queue(QueueName=queue_name)

def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]:
"""
Initialize all queues for the given worker - try to create them if they are not found by name for whatever reason
"""
try:
return self.sqs.queues.filter(QueueNamePrefix=queue_name)
found_queues = [q for q in self.sqs.queues.filter(QueueNamePrefix=queue_name)]
if found_queues:
return found_queues
else:
return [self.create_queue(queue_name)]
except botocore.exceptions.ClientError as e:
logger.info(f"Queue {queue_name} doesn't exist - creating")
if e.response['Error']['Code'] == "AWS.SimpleQueueService.NonExistentQueue":
return [self.sqs.create_queue(QueueName=queue_name)]
return [self.create_queue(queue_name)]
else:
raise

Expand Down Expand Up @@ -140,8 +169,7 @@ def receive_messages(self, batch_size: int = 1) -> List[Tuple[Dict[str, Any], bo
for queue in self.input_queues:
if batch_size <= 0:
break
this_batch_size = min(batch_size, self.batch_size)
batch_messages = queue.receive_messages(MaxNumberOfMessages=this_batch_size)
batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, self.batch_size))
for message in batch_messages:
if batch_size > 0:
messages_with_queues.append((message, queue))
Expand All @@ -152,11 +180,25 @@ def return_response(self, message: Dict[str, Any]):
"""
Send message to output queue
"""
return self.push_message(self.output_queue, message)
return self.push_message(self.output_queue_name, message)

def push_message(self, queue: boto3.resources.base.ServiceResource, message: Dict[str, Any]) -> Dict[str, Any]:
def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource:
"""
Search through queues to find the right one
"""
print(f"Searching for {queue_name}")
all_queues = [self.input_queues, self.output_queues]
print(f"All queues are {all_queues}")
for group in [self.input_queues, self.output_queues]:
for q in group:
qq = self.queue_name(q)
print(f"Comparing {queue_name} against {qq}")
if queue_name == self.queue_name(q):
return q

def push_message(self, queue_name: str, message: Dict[str, Any]) -> Dict[str, Any]:
"""
Actual SQS logic for pushing a message to a queue
"""
queue.send_message(MessageBody=json.dumps(message.dict()))
self.find_queue_by_name(queue_name).send_message(MessageBody=json.dumps(message.dict()))
return message
17 changes: 11 additions & 6 deletions test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,26 @@ class FakeSQSMessage(BaseModel):
receipt_handle: str

class TestQueue(unittest.TestCase):
# def overwrite_restrict_queues_by_suffix(queues, suffix):
# return [MagicMock()]
#
# @patch('lib.queue.queue.Queue.restrict_queues_by_suffix', side_effect=overwrite_restrict_queues_by_suffix)
@patch('lib.queue.queue.boto3.resource')
@patch('lib.helpers.get_environment_setting', return_value='us-west-1')
def setUp(self, mock_get_env_setting, mock_boto_resource):
def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queues_by_suffix):
self.model = GenericTransformerModel(None)
self.mock_model = MagicMock()
self.queue_name_input = 'test-input-queue'
self.queue_name_output = 'test-output-queue'
self.queue_name_input = 'mean_tokens__Model'
self.queue_name_output = 'mean_tokens__Model_output'
self.batch_size = 5

# Mock the SQS resource and the queues
self.mock_sqs_resource = MagicMock()
self.mock_input_queue = MagicMock()
self.mock_input_queue.url = "http://queue/mean_tokens__Model"
self.mock_output_queue = MagicMock()

self.mock_sqs_resource.get_queue_by_name.side_effect = [self.mock_input_queue, self.mock_output_queue]
self.mock_output_queue.url = "http://queue/mean_tokens__Model_output"
self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue]
mock_boto_resource.return_value = self.mock_sqs_resource

# Initialize the SQSQueue instance
Expand Down Expand Up @@ -94,7 +99,7 @@ def test_delete_messages_from_queue(self, mock_logger):
def test_push_message(self):
message_to_push = schemas.Message(body={"id": 1, "callback_url": "http://example.com", "text": "This is a test"})
# Call push_message
returned_message = self.queue.push_message(self.mock_output_queue, message_to_push)
returned_message = self.queue.push_message(self.queue_name_output, message_to_push)
# Check if the message was correctly serialized and sent
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "text": "This is a test"}, "response": null}')
self.assertEqual(returned_message, message_to_push)
Expand Down
2 changes: 1 addition & 1 deletion test/lib/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_fingerprint_item(self, mock_push_message, mock_create):
test_data = {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}

response = self.client.post("/fingerprint_item/test_fingerprinter", json=test_data)
mock_create.assert_called_once_with("test_fingerprinter", "test_fingerprinter-output")
mock_create.assert_called_once_with("test_fingerprinter", "test_fingerprinter_output")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Message pushed successfully"})

Expand Down

0 comments on commit 400c332

Please sign in to comment.