Skip to content

Commit

Permalink
CV2-3575 first pass on adding pydantic schemas to messages
Browse files Browse the repository at this point in the history
  • Loading branch information
DGaffney committed Aug 4, 2023
1 parent d7053f7 commit 599240b
Show file tree
Hide file tree
Showing 17 changed files with 145 additions and 54 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,34 @@ Messages passed to presto input queues must have the following structure, per ea
Input Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"text": "The text to be processed by the vectorizer"
}
```
Output Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"text": "The text to be processed by the vectorizer",
"response": [List of floats representing vectorization results]
}
```

#### Language ID
Input Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"text": "The text whose language is to be identified"
}
```
Output Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"text": "The text whose language was identified",
"response": "The detected language label, e.g. __label__eng_Latn",
Expand All @@ -96,13 +117,15 @@ Output Message:
Input Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"url": "The URL at which the media is located"
}
```
Output Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"url": "The URL at which the media is located",
"bucket": "bucket within which the .tmk file is stored",
Expand All @@ -115,13 +138,15 @@ Output Message:
Input Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"url": "The URL at which the media is located"
}
```
Output Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"url": "The URL at which the media is located",
"hash_value": [pyacoustid output hash value for the audio clip],
Expand All @@ -133,13 +158,15 @@ Output Message:
Input Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"url": "The URL at which the media is located"
}
```
Output Message:
```
{
"id": "A unique string ID that identifies the item being processed",
"callback_url": "A unique URL that will be requested upon completion",
"url": "The URL at which the media is located",
"hash_value": [pdqhasher output hash value for the image],
Expand Down
8 changes: 5 additions & 3 deletions lib/model/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import tempfile

from lib.model.model import Model
import acoustid

from lib.model.model import Model
from lib import schemas

class Model(Model):
def audio_hasher(self, filename: str) -> List[int]:
"""
Expand All @@ -18,8 +20,8 @@ def audio_hasher(self, filename: str) -> List[int]:
except acoustid.FingerprintGenerationError:
return []

def fingerprint(self, audio: Dict[str, str]) -> Dict[str, Union[str, List[int]]]:
temp_file_name = self.get_tempfile_for_url(audio.get("body", {})["url"])
def fingerprint(self, audio: schemas.InputMessage) -> Dict[str, Union[str, List[int]]]:
temp_file_name = self.get_tempfile_for_url(audio.body.url)
try:
hash_value = self.audio_hasher(temp_file_name)
finally:
Expand Down
9 changes: 5 additions & 4 deletions lib/model/fasttext.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Union, Dict, List
from lib.model.model import Model

import fasttext
from huggingface_hub import hf_hub_download

from lib.model.model import Model
from lib import schemas

class FasttextModel(Model):
def __init__(self):
Expand All @@ -14,18 +15,18 @@ def __init__(self):
self.model = fasttext.load_model(model_path)


def respond(self, docs: Union[List[Dict[str, str]], Dict[str, str]]) -> List[Dict[str, str]]:
def respond(self, docs: Union[List[schemas.InputMessage], schemas.InputMessage]) -> List[schemas.TextOutput]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
Respond can probably be genericized across all models.
"""
if not isinstance(docs, list):
docs = [docs]
detectable_texts = [e.get("body", {}).get("text") for e in docs]
detectable_texts = [e.body.text for e in docs]
detected_langs = []
for text in detectable_texts:
detected_langs.append(self.model.predict(text)[0][0])

for doc, detected_lang in zip(docs, detected_langs):
doc["response"] = detected_lang
doc.response = detected_lang
return docs
8 changes: 5 additions & 3 deletions lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from sentence_transformers import SentenceTransformer

from lib.model.model import Model
from lib import schemas

class GenericTransformerModel(Model):
def __init__(self, model_name: str):
"""
Expand All @@ -11,17 +13,17 @@ def __init__(self, model_name: str):
if model_name:
self.model = SentenceTransformer(model_name)

def respond(self, docs: Union[List[Dict[str, str]], Dict[str, str]]) -> List[Dict[str, str]]:
def respond(self, docs: Union[List[schemas.InputMessage], schemas.InputMessage]) -> List[schemas.TextOutput]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
Respond can probably be genericized across all models.
"""
if not isinstance(docs, list):
docs = [docs]
vectorizable_texts = [e.get("body", {}).get("text") for e in docs]
vectorizable_texts = [e.body.text for e in docs]
vectorized = self.vectorize(vectorizable_texts)
for doc, vector in zip(docs, vectorized):
doc["response"] = vector
doc.response = vector
return docs

def vectorize(self, texts: List[str]) -> List[List[float]]:
Expand Down
7 changes: 4 additions & 3 deletions lib/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.model.model import Model

from pdqhashing.hasher.pdq_hasher import PDQHasher
from lib import schemas

class Model(Model):
def compute_pdq(iobytes: io.BytesIO) -> str:
Expand All @@ -16,20 +17,20 @@ def compute_pdq(iobytes: io.BytesIO) -> str:
hash_and_qual = pdq_hasher.fromBufferedImage(iobytes)
return hash_and_qual.getHash().dumpBitsFlat()

def get_iobytes_for_image(self, image: Dict[str, str]) -> io.BytesIO:
def get_iobytes_for_image(self, image: schemas.InputMessage) -> io.BytesIO:
"""
Read file as bytes after requesting based on URL.
"""
return io.BytesIO(
urllib.request.urlopen(
urllib.request.Request(
image.get("body", {})["url"],
image.body.url,
headers={'User-Agent': 'Mozilla/5.0'}
)
).read()
)

def fingerprint(self, image: Dict[str, str]) -> Dict[str, str]:
def fingerprint(self, image: schemas.InputMessage) -> schemas.ImageOutput:
"""
Generic function for returning the actual response.
"""
Expand Down
8 changes: 4 additions & 4 deletions lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.request

from lib.helpers import get_class

from lib import schemas
class Model(ABC):
BATCH_SIZE = 1
def get_tempfile_for_url(self, url: str) -> str:
Expand All @@ -33,17 +33,17 @@ def get_tempfile(self) -> Any:
"""
return tempfile.NamedTemporaryFile()

def fingerprint(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
def fingerprint(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
return []

def respond(self, messages: Union[List[Dict[str, str]], Dict[str, str]]) -> List[Dict[str, str]]:
def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
"""
if not isinstance(messages, list):
messages = [messages]
for message in messages:
message["response"] = self.fingerprint(message)
message.response = self.fingerprint(message)
return messages

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion lib/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import urllib.request
from lib.model.model import Model
from lib import s3
from lib import schemas

class Model(Model):
def __init__(self):
Expand Down Expand Up @@ -40,7 +41,7 @@ def tmk_bucket(self) -> str:
"""
return "presto_tmk_videos"

def fingerprint(self, video: Dict[str, str]) -> Dict[str, str]:
def fingerprint(self, video: schemas.Message) -> schemas.VideoOutput:
"""
Main fingerprinting routine - download video to disk, get short hash,
then calculate larger TMK hash and upload that to S3.
Expand Down
54 changes: 54 additions & 0 deletions lib/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, List, Union
from pydantic import BaseModel, HttpUrl

# Output hash values can be of different types.
HashValue = Union[List[float], str, int]


class TextInput(BaseModel):
    """Body of an input message for a text model."""

    id: str
    callback_url: HttpUrl
    text: str


class TextOutput(BaseModel):
    """Body of an output message from a text model.

    ``response`` is either a list of floats or a single string.
    """

    id: str
    callback_url: HttpUrl
    text: str
    response: Union[List[float], str]


class VideoInput(BaseModel):
    """Body of an input message for the video model."""

    id: str
    callback_url: HttpUrl
    url: HttpUrl


class VideoOutput(BaseModel):
    """Body of an output message from the video model."""

    id: str
    callback_url: HttpUrl
    url: HttpUrl
    bucket: str
    # NOTE(review): presumably the name of the stored output file - confirm
    # against the video model.
    outfile: str
    hash_value: HashValue


class AudioInput(BaseModel):
    """Body of an input message for the audio model."""

    id: str
    callback_url: HttpUrl
    url: HttpUrl


class AudioOutput(BaseModel):
    """Body of an output message from the audio model."""

    id: str
    callback_url: HttpUrl
    url: HttpUrl
    hash_value: HashValue


class ImageInput(BaseModel):
    """Body of an input message for the image model."""

    id: str
    callback_url: HttpUrl
    url: HttpUrl


class ImageOutput(BaseModel):
    """Body of an output message from the image model."""

    id: str
    callback_url: HttpUrl
    url: HttpUrl
    hash_value: HashValue


class Message(BaseModel):
    """Envelope passed to a model: an input ``body`` plus the ``response``.

    This class is declared after the input schemas on purpose: class-body
    annotations are evaluated when the class is created, so every name used
    in ``body`` must already exist at that point.
    """

    # NOTE(review): VideoInput, AudioInput and ImageInput declare identical
    # fields, so a raw dict with those keys cannot be told apart by shape
    # alone - confirm whether a discriminator is needed.
    body: Union[TextInput, VideoInput, AudioInput, ImageInput]
    # Defaults to None so that a message can be built before any model has
    # produced a response.
    response: Any = None


# NOTE(review): the model modules annotate their arguments with
# ``schemas.InputMessage``; alias it to Message so those references resolve.
# Confirm this is the intended type.
InputMessage = Message
10 changes: 5 additions & 5 deletions test/lib/model/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import acoustid
from acoustid import FingerprintGenerationError


from lib import schemas
class TestAudio(unittest.TestCase):
def setUp(self):
self.audio_model = Model()
Expand All @@ -15,9 +15,9 @@ def setUp(self):
def test_fingerprint_audio_success(self, mock_request, mock_urlopen):
mock_request.return_value = mock_request
mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=open("data/test-audio.mp3", 'rb').read()))
audio = {"body": {"url": "https://example.com/audio.mp3"}}
audio = schemas.Message(body=schemas.AudioInput(url="https://example.com/audio.mp3"))
result = self.audio_model.fingerprint(audio)
mock_request.assert_called_once_with(audio["body"]["url"], headers={'User-Agent': 'Mozilla/5.0'})
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
self.assertEqual(list, type(result["hash_value"]))

Expand All @@ -30,9 +30,9 @@ def test_fingerprint_audio_failure(self, mock_decode_fingerprint, mock_fingerpri
mock_fingerprint_file.side_effect = FingerprintGenerationError("Failed to generate fingerprint")
mock_request.return_value = mock_request
mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=open("data/test-audio.mp3", 'rb').read()))
audio = {"body": {"url": "https://example.com/audio.mp3"}}
audio = schemas.Message(body=schemas.AudioInput(url="https://example.com/audio.mp3"))
result = self.audio_model.fingerprint(audio)
mock_request.assert_called_once_with(audio["body"]["url"], headers={'User-Agent': 'Mozilla/5.0'})
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
self.assertEqual([], result["hash_value"])

Expand Down
7 changes: 4 additions & 3 deletions test/lib/model/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
import numpy as np

from lib.model.fasttext import FasttextModel
from lib import schemas

class TestFasttextModel(unittest.TestCase):
def setUp(self):
self.model = FasttextModel()
self.mock_model = MagicMock()

def test_respond(self):
query = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "今天是星期二"}}]
query = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="今天是星期二"))]

response = self.model.respond(query)

self.assertEqual(len(response), 2)
self.assertEqual(response[0]["response"], "__label__eng_Latn")
self.assertEqual(response[1]["response"], "__label__zho_Hans")
self.assertEqual(response[0].response, "__label__eng_Latn")
self.assertEqual(response[1].response, "__label__zho_Hans")

if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions test/lib/model/test_fptg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import numpy as np

from lib.model.generic_transformer import GenericTransformerModel
from lib import schemas

class TestMdebertaFilipino(unittest.TestCase):
def setUp(self):
self.model = GenericTransformerModel(None)
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "I'm doing great, thanks!"}}]
texts = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="I'm doing great, thanks!"))]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -21,12 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = {"body": {"text": "Anong pangalan mo?"}}
query = schemas.Message(body=schemas.TextInput(text="Anong pangalan mo?"))
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertIn("response", response[0])
self.assertEqual(response[0]["response"], [1, 2, 3])
self.assertEqual(response[0].response, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 599240b

Please sign in to comment.