CV2-3551 add local queue consumption and re-work a ton of the startup flow to accommodate
DGaffney committed Aug 29, 2023
1 parent 764df76 commit dd26eb7
Showing 15 changed files with 225 additions and 95 deletions.
9 changes: 6 additions & 3 deletions Makefile
@@ -1,13 +1,16 @@
.PHONY: run run_http run_worker run_test

run:
./start_healthcheck_and_model_engine.sh
./start_all.sh

run_http:
uvicorn main:app --host 0.0.0.0 --reload

run_worker:
python run.py
python run_worker.py

run_processor:
python run_processor.py

run_test:
python -m unittest discover .
python -m pytest test
6 changes: 3 additions & 3 deletions lib/http.py
@@ -5,7 +5,7 @@
from httpx import HTTPStatusError
from fastapi import FastAPI, Request
from pydantic import BaseModel
from lib.queue.queue import Queue
from lib.queue.worker import QueueWorker
from lib.logger import logger
from lib import schemas

@@ -29,8 +29,8 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:

@app.post("/fingerprint_item/{fingerprinter}")
def fingerprint_item(fingerprinter: str, message: Dict[str, Any]):
queue = Queue.create(fingerprinter)
queue.push_message(fingerprinter, schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
queue = QueueWorker.create(fingerprinter)
queue.push_message(fingerprinter, schemas.Message(body=message))
return {"message": "Message pushed successfully"}

@app.post("/trigger_callback")
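As a quick sanity check of the reworked `/fingerprint_item` endpoint above, a minimal sketch using `httpx` (already in requirements.txt). The fingerprinter name, port, and payload fields are placeholders modeled on the `AudioInput` fixtures in the tests below, not values taken from this repo's configuration.

```python
# Hypothetical usage sketch: push one item onto a fingerprinting queue via the
# HTTP API. Assumes the FastAPI app is running locally (e.g. `make run_http`)
# and that "example_model" is a valid fingerprinter/queue name in your deployment.
import httpx

payload = {
    "id": "123",
    "callback_url": "http://example.com/callback",
    "url": "https://example.com/audio.mp3",
}

resp = httpx.post("http://localhost:8000/fingerprint_item/example_model", json=payload)
resp.raise_for_status()
print(resp.json())  # expected: {"message": "Message pushed successfully"}
```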
55 changes: 55 additions & 0 deletions lib/queue/processor.py
@@ -0,0 +1,55 @@
from typing import List
import json

import requests

from lib import schemas
from lib.logger import logger
from lib.helpers import get_setting
from lib.queue.queue import Queue
class QueueProcessor(Queue):
    @classmethod
    def create(cls, input_queue_name: str = None, batch_size: int = 10):
        """
        Instantiate a queue processor. Must pass input_queue_name and batch_size.
        Pulls settings and then inits instance.
        """
        input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__")
        logger.info(f"Starting queue with: ('{input_queue_name}', {batch_size})")
        return QueueProcessor(input_queue_name, batch_size)

    def __init__(self, input_queue_name: str, batch_size: int = 1):
        """
        Start a specific queue - must pass input_queue_name - optionally pass batch_size.
        """
        super().__init__()
        self.input_queue_name = input_queue_name
        self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output")
        self.all_queues = self.store_queue_map(self.input_queues)
        self.batch_size = batch_size

    def send_callbacks(self) -> None:
        """
        Main routine. In a loop, read messages from the queue at batch_size depth,
        POST each one back to its callback_url, then delete the batch from the queue.
        Failures are logged per message rather than raised.
        """
        messages_with_queues = self.receive_messages(self.batch_size)
        if messages_with_queues:
            logger.debug(f"About to respond to: ({messages_with_queues})")
            bodies = [schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]
            for body in bodies:
                self.send_callback(body)
            self.delete_messages(messages_with_queues)

    def send_callback(self, body: schemas.Message) -> None:
        """
        POST a processed message back to its callback_url. Failures are logged
        rather than raised so one bad callback does not halt the loop.
        """
        callback_url = body.body.callback_url
        try:
            requests.post(callback_url, json=body.dict())
        except Exception as e:
            logger.error(f"Callback fail! Failed with {e} on {callback_url} with body of {body}")
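To see the new callback path in isolation, a hedged sketch that builds a `schemas.Message` the same way `send_callbacks` does and verifies the outgoing POST with `requests.post` mocked. The `AudioInput` fields mirror the test fixtures below; the patch target assumes `requests` is used as imported in `lib/queue/processor.py`.

```python
# Hedged sketch: exercise QueueProcessor.send_callback without hitting the network.
# Assumes an SQS endpoint (e.g. a local elasticmq) is reachable, since
# QueueProcessor.create() creates/looks up real queues on init.
from unittest.mock import patch

from lib import schemas
from lib.queue.processor import QueueProcessor

processor = QueueProcessor.create("example_model", batch_size=5)

message = schemas.Message(
    body=schemas.AudioInput(
        id="123",
        callback_url="http://example.com/callback",
        url="https://example.com/audio.mp3",
    )
)

with patch("lib.queue.processor.requests.post") as mock_post:
    processor.send_callback(message)
    mock_post.assert_called_once()  # posted to http://example.com/callback
```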
73 changes: 11 additions & 62 deletions lib/queue/queue.py
@@ -1,48 +1,28 @@
import json
from typing import Any, List, Dict, Tuple, Union
from typing import List, Dict, Tuple
import os

import boto3
import botocore

from lib.helpers import get_class, get_setting, get_environment_setting
from lib.model.model import Model
from lib.helpers import get_environment_setting
from lib.logger import logger
from lib import schemas
SQS_MAX_BATCH_SIZE = 10
class Queue:
@classmethod
def create(cls, input_queue_name: str = None, output_queue_name: str = None, batch_size: int = 10):
def __init__(self):
"""
Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size.
Pulls settings and then inits instance.
"""
input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__")
output_queue_name = output_queue_name or f"{input_queue_name}_output"
logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}', {batch_size})")
return Queue(input_queue_name, output_queue_name, batch_size)

def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1):
"""
Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size.
Start a specific queue - must pass input_queue_name.
"""
self.sqs = self.get_sqs()
self.input_queue_name = input_queue_name
self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output")
if output_queue_name:
self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name)
self.output_queues = self.get_or_create_queues(output_queue_name)
self.all_queues = self.store_queue_map()
self.batch_size = batch_size


def store_queue_map(self) -> Dict[str, boto3.resources.base.ServiceResource]:
def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]) -> Dict[str, boto3.resources.base.ServiceResource]:
"""
Store a quick lookup so that we don't loop through this over and over in other places.
"""
queue_map = {}
for group in [self.input_queues, self.output_queues]:
for q in group:
queue_map[self.queue_name(q)] = q
for queue in all_queues:
queue_map[self.queue_name(queue)] = queue
return queue_map

def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str:
@@ -101,7 +81,7 @@ def get_output_queue_name(self, input_queue_name: str, output_queue_name: str =
If output_queue_name was empty or None, set name for queue.
"""
if not output_queue_name:
output_queue_name = f'{input_queue_name}-output'
output_queue_name = f'{input_queue_name}_output'
return output_queue_name

def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]:
@@ -115,7 +95,7 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto
queue_to_messages[queue].append(message)
return queue_to_messages

def delete_messages(self, messages_with_queues: List[Tuple[Dict[str, Any], boto3.resources.base.ServiceResource]]) -> None:
def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None:
"""
Delete messages as we process them so other processes don't pick them up.
SQS deals in max batches of 10, so break up messages into groups of 10
@@ -141,31 +121,6 @@ def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource
entries.append(entry)
queue.delete_messages(Entries=entries)

def safely_respond(self, model: Model) -> List[schemas.Message]:
"""
Rescue against failures when attempting to respond (i.e. fingerprint) from models.
Return responses if no failure.
"""
messages_with_queues = self.receive_messages(model.BATCH_SIZE)
responses = []
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues])
self.delete_messages(messages_with_queues)
return responses

def fingerprint(self, model: Model):
"""
Main routine. Given a model, in a loop, read tasks from input_queue_name at batch_size depth,
pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue.
If failures happen at any point, resend failed messages to input queue.
"""
responses = self.safely_respond(model)
if responses:
for response in responses:
logger.info(f"Processing message of: ({response})")
self.return_response(response)

def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]:
"""
Pull batch_size messages from input queue.
@@ -175,19 +130,13 @@ def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, b
for queue in self.input_queues:
if batch_size <= 0:
break
batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, self.batch_size))
batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE))
for message in batch_messages:
if batch_size > 0:
messages_with_queues.append((message, queue))
batch_size -= 1
return messages_with_queues

def return_response(self, message: schemas.Message):
"""
Send message to output queue
"""
return self.push_message(self.output_queue_name, message)

def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource:
"""
Search through queues to find the right one
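The deletion path above leans on SQS's 10-message batch limit (now captured by `SQS_MAX_BATCH_SIZE`), but the chunking itself is collapsed in this diff view. A hedged, self-contained sketch of the general idea, not the repo's exact implementation:

```python
# Sketch of batched SQS deletion: group received messages by the queue they came
# from, then delete them in chunks of at most 10 (the SQS per-call limit).
# The objects here stand in for boto3 Message/Queue resources.
from collections import defaultdict
from typing import Any, Dict, List, Tuple

SQS_MAX_BATCH_SIZE = 10

def delete_in_batches(messages_with_queues: List[Tuple[Any, Any]]) -> None:
    queue_to_messages: Dict[Any, List[Any]] = defaultdict(list)
    for message, queue in messages_with_queues:
        queue_to_messages[queue].append(message)
    for queue, messages in queue_to_messages.items():
        for start in range(0, len(messages), SQS_MAX_BATCH_SIZE):
            chunk = messages[start:start + SQS_MAX_BATCH_SIZE]
            entries = [
                {"Id": str(i), "ReceiptHandle": m.receipt_handle}
                for i, m in enumerate(chunk)
            ]
            queue.delete_messages(Entries=entries)
```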
56 changes: 56 additions & 0 deletions lib/queue/worker.py
@@ -0,0 +1,56 @@
import json
from typing import List
from lib import schemas
from lib.logger import logger
from lib.queue.queue import Queue
from lib.model.model import Model
from lib.helpers import get_setting
class QueueWorker(Queue):
    @classmethod
    def create(cls, input_queue_name: str = None):
        """
        Instantiate a queue worker. Must pass input_queue_name.
        Pulls settings and then inits instance.
        """
        input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__")
        output_queue_name = f"{input_queue_name}_output"
        logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}')")
        return QueueWorker(input_queue_name, output_queue_name)

    def __init__(self, input_queue_name: str, output_queue_name: str = None):
        """
        Start a specific queue - must pass input_queue_name - optionally pass output_queue_name.
        """
        super().__init__()
        self.input_queue_name = input_queue_name
        self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output")
        if output_queue_name:
            self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name)
            self.output_queues = self.get_or_create_queues(output_queue_name)
        self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues] for item in row])

    def fingerprint(self, model: Model):
        """
        Main routine. Given a model, in a loop, read tasks from input_queue_name,
        pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue.
        If failures happen at any point, resend failed messages to input queue.
        """
        responses = self.safely_respond(model)
        if responses:
            for response in responses:
                logger.info(f"Processing message of: ({response})")
                self.return_response(response)

    def safely_respond(self, model: Model) -> List[schemas.Message]:
        """
        Rescue against failures when attempting to respond (i.e. fingerprint) from models.
        Return responses if no failure.
        """
        messages_with_queues = self.receive_messages(model.BATCH_SIZE)
        responses = []
        if messages_with_queues:
            logger.debug(f"About to respond to: ({messages_with_queues})")
            responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues])
            self.delete_messages(messages_with_queues)
        return responses
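For context on what `fingerprint` expects from a model, a hedged sketch of the `Model` contract as used here: a `BATCH_SIZE` attribute plus a `respond` method that takes and returns `schemas.Message` objects. The echo behaviour is purely illustrative, not the real base class.

```python
# Illustrative stand-in for the Model interface QueueWorker drives: safely_respond
# pulls up to BATCH_SIZE messages, hands them to respond() as schemas.Message
# objects, and whatever respond() returns is pushed to the output queue.
from typing import List

from lib import schemas


class EchoModel:
    BATCH_SIZE = 5  # how many messages safely_respond pulls per iteration

    def respond(self, messages: List[schemas.Message]) -> List[schemas.Message]:
        for message in messages:
            message.response = {"echo": True}  # real models attach fingerprints here
        return messages
```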

3 changes: 1 addition & 2 deletions lib/schemas.py
@@ -50,5 +50,4 @@ class ImageOutput(BaseModel):

class Message(BaseModel):
body: Union[TextInput, VideoInput, AudioInput, ImageInput]
response: Any

response: Any
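Since the HTTP layer now builds `Message(body=message)` straight from a request dict, a short hedged example of how pydantic resolves the `body` union; the field names follow the `AudioInput` fixtures in the tests, and the exact member chosen depends on how the input models are declared.

```python
# Hedged sketch: pydantic coerces a plain dict into one of the body types in the
# Union (TextInput/VideoInput/AudioInput/ImageInput) when Message is constructed.
from lib import schemas

raw = {
    "id": "123",
    "callback_url": "http://example.com/callback",
    "url": "https://example.com/audio.mp3",
}

message = schemas.Message(body=raw)
print(type(message.body).__name__)  # whichever input type validates first
print(message.response)             # None until a model fills it in
```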
2 changes: 2 additions & 0 deletions requirements.txt
@@ -9,3 +9,5 @@ uvicorn[standard]==0.19.0
httpx==0.23.1
huggingface-hub==0.11.0
fasttext==0.9.2
requests==2.31.0
pytest==7.4.0
11 changes: 11 additions & 0 deletions run_processor.py
@@ -0,0 +1,11 @@
import time
import os
import importlib
from lib.queue.processor import QueueProcessor
from lib.model.model import Model
from lib.logger import logger
queue = QueueProcessor.create()

logger.info("Beginning callback loop...")
while True:
    queue.send_callbacks()
4 changes: 2 additions & 2 deletions run.py → run_worker.py
@@ -1,10 +1,10 @@
import time
import os
import importlib
from lib.queue.queue import Queue
from lib.queue.worker import QueueWorker
from lib.model.model import Model
from lib.logger import logger
queue = Queue.create()
queue = QueueWorker.create()

model = Model.create()

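The tail of run_worker.py is collapsed above; as a hedged sketch (not necessarily the file's exact contents), the worker process presumably wires `QueueWorker` and `Model` together roughly like this:

```python
# Hypothetical reconstruction of the worker loop hidden by the diff view:
# create the queue worker and model once, then poll and fingerprint indefinitely.
from lib.queue.worker import QueueWorker
from lib.model.model import Model

queue = QueueWorker.create()
model = Model.create()

while True:
    queue.fingerprint(model)
```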
3 changes: 2 additions & 1 deletion start_healthcheck_and_model_engine.sh → start_all.sh
@@ -5,4 +5,5 @@ uvicorn main:app --host 0.0.0.0 --reload &

# Start the second process in the foreground
# This will ensure the script won't exit until this process does
python run.py
python run_worker.py &
python run_processor.py
8 changes: 4 additions & 4 deletions test/lib/model/test_audio.py
@@ -15,13 +15,13 @@ def setUp(self):
@patch('urllib.request.Request')
def test_fingerprint_audio_success(self, mock_request, mock_urlopen):
mock_request.return_value = mock_request

# Use the `with` statement for proper file handling
with open("data/test-audio.mp3", 'rb') as f:
contents = f.read()

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
result = self.audio_model.fingerprint(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
@@ -36,13 +36,13 @@ def test_fingerprint_audio_failure(self, mock_decode_fingerprint, mock_fingerpri
mock_request, mock_urlopen):
mock_fingerprint_file.side_effect = FingerprintGenerationError("Failed to generate fingerprint")
mock_request.return_value = mock_request

# Use the `with` statement for proper file handling
with open("data/test-audio.mp3", 'rb') as f:
contents = f.read()

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
result = self.audio_model.fingerprint(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
5 changes: 5 additions & 0 deletions test/lib/queue/fake_sqs_message.py
@@ -0,0 +1,5 @@
from pydantic import BaseModel
class FakeSQSMessage(BaseModel):
    body: str
    receipt_handle: str
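A hedged sketch of how this helper could be used in queue tests: `body` carries a JSON-serialized `schemas.Message` (what the worker and processor `json.loads` off the wire), and `receipt_handle` stands in for the field the deletion path reads from real SQS messages. The import path assumes `test/` is importable as a package.

```python
# Hedged usage sketch for FakeSQSMessage: its body must be the JSON form of
# schemas.Message, since the queue code calls json.loads(message.body).
import json

from lib import schemas
from test.lib.queue.fake_sqs_message import FakeSQSMessage

message = schemas.Message(
    body=schemas.AudioInput(
        id="123",
        callback_url="http://example.com/callback",
        url="https://example.com/audio.mp3",
    )
)

fake = FakeSQSMessage(body=json.dumps(message.dict()), receipt_handle="fake-receipt-handle")
parsed = schemas.Message(**json.loads(fake.body))
assert parsed.body.id == "123"
```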
