Skip to content

Commit

Permalink
Merge pull request #34 from meedan/cv2-3435-add-presto-to-alegre
Browse files Browse the repository at this point in the history
CV2-3435 minor tweaks as per alegre audio replacement work
  • Loading branch information
DGaffney authored Sep 12, 2023
2 parents ebb029a + 6d107c0 commit 6809ca0
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 42 deletions.
8 changes: 3 additions & 5 deletions lib/http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# from lib import schemas
# from lib.queue.worker import QueueWorker
# queue = QueueWorker.create("mean_tokens__Model")
# queue.push_message("mean_tokens__Model", schemas.Message(body={"callback_url": "http://0.0.0.0:8000/echo", "id": 123, "text": "Some text to vectorize"}))
import json
import datetime
from typing import Any, Dict
Expand All @@ -12,6 +8,7 @@
from lib.queue.worker import QueueWorker
from lib.logger import logger
from lib import schemas
from lib.sentry import sentry_sdk

app = FastAPI()

Expand All @@ -33,9 +30,10 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:

@app.post("/process_item/{process_name}")
def process_item(process_name: str, message: Dict[str, Any]):
logger.info(message)
queue = QueueWorker.create(process_name)
queue.push_message(process_name, schemas.Message(body=message))
return {"message": "Message pushed successfully"}
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
async def process_item(message: Dict[str, Any]):
Expand Down
5 changes: 3 additions & 2 deletions lib/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.Ser
"""
try:
found_queues = [q for q in self.sqs.queues.filter(QueueNamePrefix=queue_name)]
if found_queues:
return found_queues
exact_match_queues = [queue for queue in found_queues if queue.attributes['QueueArn'].split(':')[-1] == queue_name]
if exact_match_queues:
return exact_match_queues
else:
return [self.create_queue(queue_name)]
except botocore.exceptions.ClientError as e:
Expand Down
5 changes: 4 additions & 1 deletion lib/queue/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def safely_respond(self, model: Model) -> List[schemas.Message]:
responses = []
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues])
try:
responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues])
except Exception as e:
logger.error(e)
self.delete_messages(messages_with_queues)
return responses

41 changes: 24 additions & 17 deletions lib/schemas.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,60 @@
from typing import Any, List, Union
from pydantic import BaseModel, HttpUrl
from typing import Any, List, Optional, Union
from pydantic import BaseModel

# Output hash values can be of different types.
HashValue = Union[List[float], str, int]
class TextInput(BaseModel):
id: str
callback_url: HttpUrl
callback_url: str
text: str

class TextOutput(BaseModel):
id: str
callback_url: HttpUrl
callback_url: str
text: str

class VideoInput(BaseModel):
id: str
callback_url: HttpUrl
url: HttpUrl
callback_url: str
url: str

class VideoOutput(BaseModel):
id: str
callback_url: HttpUrl
url: HttpUrl
callback_url: str
url: str
bucket: str
outfile: str
hash_value: HashValue

class AudioInput(BaseModel):
id: str
callback_url: HttpUrl
url: HttpUrl
callback_url: str
url: str

class AudioOutput(BaseModel):
id: str
callback_url: HttpUrl
url: HttpUrl
callback_url: str
url: str
hash_value: HashValue

class ImageInput(BaseModel):
id: str
callback_url: HttpUrl
url: HttpUrl
callback_url: str
url: str

class ImageOutput(BaseModel):
id: str
callback_url: HttpUrl
url: HttpUrl
callback_url: str
url: str
hash_value: HashValue

class GenericInput(BaseModel):
id: str
callback_url: str
url: Optional[str] = None
text: Optional[str] = None
raw: Optional[dict] = {}

class Message(BaseModel):
body: Union[TextInput, VideoInput, AudioInput, ImageInput]
body: GenericInput
response: Any
9 changes: 9 additions & 0 deletions lib/sentry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os
import sentry_sdk
from lib.helpers import get_environment_setting

sentry_sdk.init(
dsn=get_environment_setting('sentry_sdk_dsn'),
environment=get_environment_setting("DEPLOY_ENV"),
traces_sample_rate=1.0,
)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ httpx==0.23.1
huggingface-hub==0.11.0
fasttext==0.9.2
requests==2.31.0
pytest==7.4.0
pytest==7.4.0
sentry-sdk==1.30.0
1 change: 1 addition & 0 deletions run_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.queue.processor import QueueProcessor
from lib.model.model import Model
from lib.logger import logger
from lib.sentry import sentry_sdk
queue = QueueProcessor.create()

logger.info("Beginning callback loop...")
Expand Down
3 changes: 2 additions & 1 deletion run_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from lib.queue.worker import QueueWorker
from lib.model.model import Model
from lib.logger import logger
from lib.sentry import sentry_sdk
queue = QueueWorker.create()

model = Model.create()

logger.info("Beginning work loop...")
while True:
queue.work(model)
queue.process(model)
6 changes: 4 additions & 2 deletions test/lib/model/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from acoustid import FingerprintGenerationError

from lib import schemas

FINGERPRINT_RESPONSE = (170.6, b'AQAA3VFYJYrGJMj74EOZUCfCHGqYLBZO8UiX5bie47sCV0xwBTe49IiVHHrQnIImJyP-44rxI2cYiuiHMCMDPcqJrBcwnYeryBX6rccR_4Iy_YhfXESzqELJ5ASTLwhvNM94KDp9_IB_6NqDZ5I9_IWYvDiNCc1z8IeuHXkYpfhSg8su3M2K5lkrFM-PK3mQH8lznEpidLEoNAeLyWispQpqvfgRZjp0lHaENAmzBeamoRIZMrha5IsyHM6H7-jRhJlSBU1FLgiv4xlKUQmNptGOU3jzIj80Jk5xsQp0UegxJtmSCpeS5PiDozz0MAb5BG5z9MEPIcy0HeWD58M_4sotlNOF8UeuLJEgJt4xkUee4cflI1nMI4uciBLeGu9z9NjH4x9iSXoELYs04pqCSCvx5ei1Tzi3NMFRmsa2DD2POxVCR4IPMSfySC-u0EKuE6IOqz_6zJh8BzZlgc1IQkyTGdeLa4cT7bi2E30e_OgTI4xDPCGLJ_gvZHlwT7EgJc2XIBY_4fnBPENC_YilsGjDJzhJoeyCJn9A1kaeDUw4VA_-41uDGycO8w_eWlCU66iio0eYL8hVK_gD5QlyMR7hzzh-vDm6JE_hcTpq5cFTdFcKZfHxRMTZCS2VHKdOfDve5Hh0hCV9JEtMSbhxSSMuHU9y4kaTx5guHIGsoEAAwoASjmDlkSAEOCSoQEw4IDgghiguAEZAAMaAAYYAhBhACBEiiAGAIUCUUUgSESjgSBlKjZEEEIAFUEIBBRBinAAplFJKAIYQEAQSA4ywACkjgBFMAEoAQgYQwARB1gFmBCAECAAIMYYIoBxBBAAAFCKAAEgIBAQgAghgihIWBACEIUEIJEZIZIBRACGAGAEEIAGAUIBIhBCgRkI')
class TestAudio(unittest.TestCase):
def setUp(self):
self.audio_model = Model()

@patch('urllib.request.urlopen')
@patch('urllib.request.Request')
def test_process_audio_success(self, mock_request, mock_urlopen):
@patch('acoustid.fingerprint_file')
def test_process_audio_success(self, mock_fingerprint_file, mock_request, mock_urlopen):
mock_fingerprint_file.return_value = FINGERPRINT_RESPONSE
mock_request.return_value = mock_request

# Use the `with` statement for proper file handling
Expand Down
24 changes: 13 additions & 11 deletions test/lib/model/test_fasttext.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import os
import unittest
from unittest.mock import MagicMock

import numpy as np

from unittest.mock import patch, MagicMock
from lib.model.fasttext import FasttextModel
from lib import schemas

class TestFasttextModel(unittest.TestCase):
def setUp(self):
self.model = FasttextModel()
self.mock_model = MagicMock()

def test_respond(self):
@patch('lib.model.fasttext.hf_hub_download')
@patch('lib.model.fasttext.fasttext.load_model')
def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download):
mock_hf_hub_download.return_value = 'mocked_path'
mock_fasttext_load_model.return_value = self.mock_model
self.mock_model.predict.return_value = (['__label__eng_Latn'], [0.9])

model = FasttextModel() # Now it uses mocked functions
query = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="今天是星期二"))]

response = self.model.respond(query)
response = model.respond(query)

self.assertEqual(len(response), 2)
self.assertEqual(response[0].response, "__label__eng_Latn")
self.assertEqual(response[1].response, "__label__zho_Hans")
self.assertEqual(response[0].response, '__label__eng_Latn')
self.assertEqual(response[1].response, '__label__eng_Latn') # Mocked, so it will be the same

if __name__ == '__main__':
unittest.main()
4 changes: 3 additions & 1 deletion test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue
self.mock_sqs_resource = MagicMock()
self.mock_input_queue = MagicMock()
self.mock_input_queue.url = "http://queue/mean_tokens__Model"
self.mock_input_queue.attributes = {"QueueArn": "queue:mean_tokens__Model"}
self.mock_output_queue = MagicMock()
self.mock_output_queue.url = "http://queue/mean_tokens__Model_output"
self.mock_output_queue.attributes = {"QueueArn": "queue:mean_tokens__Model_output"}
self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue]
mock_boto_resource.return_value = self.mock_sqs_resource

Expand Down Expand Up @@ -102,7 +104,7 @@ def test_push_message(self):
# Call push_message
returned_message = self.queue.push_message(self.queue_name_output, message_to_push)
# Check if the message was correctly serialized and sent
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "text": "This is a test"}, "response": null}')
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}}, "response": null}')
self.assertEqual(returned_message, message_to_push)

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion test/lib/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_process_item(self, mock_push_message, mock_create):
response = self.client.post("/process_item/test_process", json=test_data)
mock_create.assert_called_once_with("test_process")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Message pushed successfully"})
self.assertEqual(response.json(), {"message": "Message pushed successfully", "queue": "test_process", "body": test_data})


@patch('lib.http.post_url')
Expand Down

0 comments on commit 6809ca0

Please sign in to comment.