Skip to content

Commit

Permalink
Merge pull request #30 from meedan/cv2-3551-gut-queue-logic-and-add-m…
Browse files Browse the repository at this point in the history
…ulti-queue

CV2-3551 gut more unnecessary env vars, add support for multi queue
  • Loading branch information
DGaffney authored Aug 22, 2023
2 parents 312a179 + 155d25e commit 95681d1
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 89 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ To run the project, you can use the provided `Dockerfile`, or start via `docker-

```
docker build -t text-vectorization .
docker run -e -e INPUT_QUEUE_NAME=<input_queue_name> -e OUTPUT_QUEUE_NAME=<output_queue_name> -e MODEL_NAME=<model_name>
docker run -e MODEL_NAME=<model_name> -e INPUT_QUEUE_NAME=<input_queue_name> -e OUTPUT_QUEUE_NAME=<output_queue_name>
```

Here, we require at least two environment variables - `input_queue_name`, and `model_name`. If left unspecified, `output_queue_name` will be automatically set to `input_queue_name[-output]`. Depending on your usage, you may need to replace `<input_queue_name>`, `<output_queue_name>`, and `<model_name>` with the appropriate values.
Here, we require at least one environment variable - `model_name`. If left unspecified, `input_queue_name` and `output_queue_name` will be automatically set to `{model_name}` and `{model_name}-output`, respectively. Depending on your usage, you may need to replace `<input_queue_name>`, `<output_queue_name>`, and `<model_name>` with the appropriate values.

Currently supported `model_name` values are just module names keyed from the `model` directory, and currently are as follows:

Expand Down
2 changes: 1 addition & 1 deletion lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:
@app.post("/fingerprint_item/{fingerprinter}")
def fingerprint_item(fingerprinter: str, message: Dict[str, Any]):
queue = Queue.create(fingerprinter, f"{fingerprinter}-output")
queue.push_message(queue.input_queue_name, schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
queue.push_message(queue.input_queues[0], schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
return {"message": "Message pushed successfully"}

@app.post("/trigger_callback")
Expand Down
2 changes: 0 additions & 2 deletions lib/model/fptg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from sentence_transformers import SentenceTransformer

from lib.model.generic_transformer import GenericTransformerModel
MODEL_NAME = 'meedan/paraphrase-filipino-mpnet-base-v2'
class Model(GenericTransformerModel):
Expand Down
2 changes: 0 additions & 2 deletions lib/model/indian_sbert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from sentence_transformers import SentenceTransformer

from lib.model.generic_transformer import GenericTransformerModel
MODEL_NAME = 'meedan/indian-sbert'
class Model(GenericTransformerModel):
Expand Down
2 changes: 0 additions & 2 deletions lib/model/mean_tokens.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from sentence_transformers import SentenceTransformer

from lib.model.generic_transformer import GenericTransformerModel
MODEL_NAME = 'xlm-r-bert-base-nli-stsb-mean-tokens'
class Model(GenericTransformerModel):
Expand Down
106 changes: 63 additions & 43 deletions lib/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,35 @@ def create(cls, input_queue_name: str = None, output_queue_name: str = None, bat
Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size.
Pulls settings and then inits instance.
"""
input_queue_name = get_setting(input_queue_name, "INPUT_QUEUE_NAME")
output_queue_name = get_setting(output_queue_name, "OUTPUT_QUEUE_NAME")
input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__")
output_queue_name = output_queue_name or f"{input_queue_name}_output"
logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}', {batch_size})")
return Queue(input_queue_name, output_queue_name, batch_size)

def get_or_create_queue(self, queue_name):
def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1):
    """
    Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size.

    input_queue_name: prefix used to find (or create) the queues we consume from.
    output_queue_name: where responses go; defaults to '<input_queue_name>-output'.
    batch_size: max messages pulled per SQS receive call.
    """
    self.sqs = self.get_sqs()
    self.input_queue_name = input_queue_name
    # Resolve the output name first so a None output_queue_name gets its
    # default instead of being handed straight to SQS below.
    self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name)
    # Input queues are every queue matching the prefix, minus the paired
    # "<name>_output" queues we must not consume from.
    self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output")
    # Bug fix: previously passed the raw output_queue_name parameter, which
    # crashes when callers rely on the default (None); use the resolved name.
    self.output_queues = self.get_or_create_queues(self.output_queue_name)
    self.batch_size = batch_size

def restrict_queues_by_suffix(self, queues, suffix):
    """
    Drop any queue whose name (the final path segment of its URL) ends with
    the given suffix. Used to keep paired "<name>_output" queues out of the
    set of input queues we poll.
    """
    kept = []
    for candidate in queues:
        queue_name = candidate.url.split('/')[-1]
        if not queue_name.endswith(suffix):
            kept.append(candidate)
    return kept

def get_or_create_queues(self, queue_name):
try:
return self.sqs.get_queue_by_name(QueueName=queue_name)
return self.sqs.queues.filter(QueueNamePrefix=queue_name)
except botocore.exceptions.ClientError as e:
logger.info(f"Queue {queue_name} doesn't exist - creating")
if e.response['Error']['Code'] == "AWS.SimpleQueueService.NonExistentQueue":
return self.sqs.create_queue(QueueName=queue_name)
return [self.sqs.create_queue(QueueName=queue_name)]
else:
raise

Expand All @@ -44,36 +61,44 @@ def get_sqs(self):
logger.info(f"Using SQS Interface")
return boto3.resource('sqs', region_name=get_environment_setting("AWS_DEFAULT_REGION"))

def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1):
"""
Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size.
"""
self.sqs = self.get_sqs()
self.input_queue = self.get_or_create_queue(input_queue_name)
self.output_queue = self.get_or_create_queue(output_queue_name)
self.input_queue_name = input_queue_name
self.batch_size = batch_size
self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name)

def get_output_queue_name(self, input_queue_name: str, output_queue_name: str = None):
def get_output_queue_name(self, input_queue_name: str, output_queue_name: str = None) -> str:
    """
    Resolve the output queue name, defaulting to '<input_queue_name>-output'
    when the caller did not provide one.
    """
    return output_queue_name or f'{input_queue_name}-output'

def delete_messages(self, queue, messages):
def group_deletions(self, messages_with_queues: List[Tuple[Any, Any]]) -> Dict[Any, List[Any]]:
    """
    Group (message, queue) pairs by queue so deletions can run as one
    batched call per queue instead of one call per message.

    messages_with_queues: list of (message, queue) tuples; queue must be
        hashable (boto3 Queue resources are).
    Returns a dict mapping each queue to its messages, preserving the
    original message order within each queue.
    """
    queue_to_messages: Dict[Any, List[Any]] = {}
    for message, queue in messages_with_queues:
        # setdefault replaces the manual "if key not in dict" initialization.
        queue_to_messages.setdefault(queue, []).append(message)
    return queue_to_messages

def delete_messages(self, messages_with_queues: List[Tuple[Dict[str, Any], boto3.resources.base.ServiceResource]]) -> None:
    """
    Remove already-processed messages from their queues so no other worker
    picks them up again. Pairs are bucketed per queue and handed to
    delete_messages_from_queue, which honors SQS's 10-message batch limit.
    """
    grouped = self.group_deletions(messages_with_queues)
    for queue in grouped:
        self.delete_messages_from_queue(queue, grouped[queue])


def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource, messages: List[Dict[str, Any]]) -> None:
"""
Helper function to delete a group of messages from a specific queue.
"""
for i in range(0, len(messages), 10):
batch = messages[i:i + 10]
entries = []
for idx, message in enumerate(batch):
logger.debug(f"Deleting message of {message}")
logger.debug(f"Deleting message: {message}")
entry = {
'Id': str(idx),
'ReceiptHandle': message.receipt_handle
Expand All @@ -86,12 +111,12 @@ def safely_respond(self, model: Model) -> Tuple[List[Dict[str, str]], List[Dict[
Rescue against failures when attempting to respond (i.e. fingerprint) from models.
Return responses if no failure.
"""
messages = self.receive_messages(model.BATCH_SIZE)
messages_with_queues = self.receive_messages(model.BATCH_SIZE)
responses = []
if messages:
logger.debug(f"About to respond to: ({messages})")
responses = model.respond([schemas.Message(**json.loads(message.body)) for message in messages])
self.delete_messages(self.input_queue, messages)
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues])
self.delete_messages(messages_with_queues)
return responses

def fingerprint(self, model: Model):
Expand All @@ -106,12 +131,22 @@ def fingerprint(self, model: Model):
logger.info(f"Processing message of: ({response})")
self.return_response(response)

def receive_messages(self, batch_size: int = 1) -> List[Dict[str, Any]]:
def receive_messages(self, batch_size: int = 1) -> List[Tuple[Dict[str, Any], boto3.resources.base.ServiceResource]]:
"""
Pull batch_size messages from input queue
Pull batch_size messages from input queue.
Actual SQS logic for pulling batch_size messages from matched queues
"""
messages = self.pop_message(self.input_queue, batch_size)
return messages
messages_with_queues = []
for queue in self.input_queues:
if batch_size <= 0:
break
this_batch_size = min(batch_size, self.batch_size)
batch_messages = queue.receive_messages(MaxNumberOfMessages=this_batch_size)
for message in batch_messages:
if batch_size > 0:
messages_with_queues.append((message, queue))
batch_size -= 1
return messages_with_queues

def return_response(self, message: Dict[str, Any]):
"""
Expand All @@ -125,18 +160,3 @@ def push_message(self, queue: boto3.resources.base.ServiceResource, message: Dic
"""
queue.send_message(MessageBody=json.dumps(message.dict()))
return message

def pop_message(self, queue: boto3.resources.base.ServiceResource, batch_size: int = 1) -> List[Dict[str, Any]]:
    """
    Pull up to batch_size messages off a single SQS queue, fetching in
    chunks no larger than self.batch_size per receive call.
    """
    collected = []
    logger.info("Grabbing message...")
    remaining = batch_size
    while remaining > 0:
        chunk = min(remaining, self.batch_size)
        collected.extend(queue.receive_messages(MaxNumberOfMessages=chunk))
        remaining -= chunk
    return collected
return messages

2 changes: 0 additions & 2 deletions local.env
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
DEPLOY_ENV=local
INPUT_QUEUE_NAME=input
OUTPUT_QUEUE_NAME=output
MODEL_NAME=mean_tokens.Model
AWS_ACCESS_KEY_ID=SOMETHING
AWS_SECRET_ACCESS_KEY=OTHERTHING
24 changes: 12 additions & 12 deletions test/lib/model/test_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import unittest
from lib.model.model import Model
from lib import schemas

# Exercises the base Model stub directly, without any queue involvement.
class TestModel(unittest.TestCase):
# A message with no explicit response should be treated identically to one
# constructed with an empty response list.
def test_respond(self):
model = Model()
self.assertEqual(model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"))), model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"), response=[])))

if __name__ == '__main__':
unittest.main()
# import os
# import unittest
# from lib.model.model import Model
# from lib import schemas
#
# class TestModel(unittest.TestCase):
# def test_respond(self):
# model = Model()
# self.assertEqual(model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"))), model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"), response=[])))
#
# if __name__ == '__main__':
# unittest.main()
61 changes: 38 additions & 23 deletions test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):
def test_get_output_queue_name(self):
self.assertEqual(self.queue.get_output_queue_name('test'), 'test-output')
self.assertEqual(self.queue.get_output_queue_name('test', 'new-output'), 'new-output')

def test_fingerprint(self):
self.queue.receive_messages = MagicMock(return_value=[FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}}))])
self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), self.mock_input_queue)])
self.queue.input_queue = MagicMock(return_value=None)
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
Expand All @@ -47,39 +47,54 @@ def test_fingerprint(self):
self.queue.receive_messages.assert_called_once_with(1)

def test_receive_messages(self):
mock_messages = [
FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})),
FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))
]
self.mock_input_queue.receive_messages.return_value = mock_messages

mock_queue1 = MagicMock()
mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}}))]

mock_queue2 = MagicMock()
mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))]
self.queue.input_queues = [mock_queue1, mock_queue2]
received_messages = self.queue.receive_messages(self.batch_size)

# Check if the right number of messages were received and the content is correct
self.assertEqual(len(received_messages), 2)
self.assertIn("a test", received_messages[0].body)
self.assertIn("another test", received_messages[1].body)
self.assertIn("a test", received_messages[0][0].body)
self.assertIn("another test", received_messages[1][0].body)

def test_pop_message(self):
mock_messages = [
FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})),
FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))
def test_restrict_queues_by_suffix(self):
    """Queues whose URL ends with the paired suffix are filtered out."""
    candidates = [
        MagicMock(url='http://test.com/test_input'),
        MagicMock(url='http://test.com/test_input_output'),
        MagicMock(url='http://test.com/test_another_input')
    ]
    kept = self.queue.restrict_queues_by_suffix(candidates, "_output")
    # Two of the three URLs do not end in _output and should survive.
    self.assertEqual(len(kept), 2)

self.mock_input_queue.receive_messages.return_value = mock_messages

popped_messages = self.queue.pop_message(self.mock_input_queue, self.batch_size)
def test_group_deletions(self):
    """group_deletions buckets (message, queue) pairs by their queue."""
    pairs = [
        (FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": "msg1"})), self.mock_input_queue),
        (FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": "msg2"})), self.mock_output_queue)
    ]
    grouped = self.queue.group_deletions(pairs)
    # Each queue appears as a key and holds exactly its own single message.
    for expected_queue in (self.mock_input_queue, self.mock_output_queue):
        self.assertIn(expected_queue, grouped)
        self.assertEqual(len(grouped[expected_queue]), 1)

# Check if the right number of messages were popped and the content is correct
self.assertEqual(len(popped_messages), 2)
self.assertIn("a test", popped_messages[0].body)
self.assertIn("another test", popped_messages[1].body)
@patch('lib.queue.queue.logger.debug')
def test_delete_messages_from_queue(self, mock_logger):
    """A batch under the SQS limit triggers exactly one delete call and per-message debug logs."""
    batch = [
        FakeSQSMessage(receipt_handle="r1", body=json.dumps({"body": "msg1"})),
        FakeSQSMessage(receipt_handle="r2", body=json.dumps({"body": "msg2"}))
    ]
    self.queue.delete_messages_from_queue(self.mock_input_queue, batch)
    # Two messages fit inside one batch of <= 10, hence a single delete call.
    self.mock_input_queue.delete_messages.assert_called_once()
    # The final debug line should reference the last message in the batch.
    mock_logger.assert_called_with(f"Deleting message: {batch[-1]}")

def test_push_message(self):
message_to_push = schemas.Message(body={"id": 1, "callback_url": "http://example.com", "text": "This is a test"})

# Call push_message
returned_message = self.queue.push_message(self.mock_output_queue, message_to_push)

# Check if the message was correctly serialized and sent
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "text": "This is a test"}, "response": null}')
self.assertEqual(returned_message, message_to_push)
Expand Down

0 comments on commit 95681d1

Please sign in to comment.