From 599240bce260ab834a4f2a512caca8df9e9fda8b Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Fri, 4 Aug 2023 10:17:04 -0700 Subject: [PATCH] CV2-3575 first pass on adding pydantic schemas to messages --- README.md | 27 +++++++++++++++ lib/model/audio.py | 8 +++-- lib/model/fasttext.py | 9 ++--- lib/model/generic_transformer.py | 8 +++-- lib/model/image.py | 7 ++-- lib/model/model.py | 8 ++--- lib/model/video.py | 3 +- lib/schemas.py | 54 +++++++++++++++++++++++++++++ test/lib/model/test_audio.py | 10 +++--- test/lib/model/test_fasttext.py | 7 ++-- test/lib/model/test_fptg.py | 8 ++--- test/lib/model/test_generic.py | 8 ++--- test/lib/model/test_image.py | 13 +++---- test/lib/model/test_indian_sbert.py | 8 ++--- test/lib/model/test_meantokens.py | 8 ++--- test/lib/model/test_model.py | 6 ++-- test/lib/model/test_video.py | 7 ++-- 17 files changed, 145 insertions(+), 54 deletions(-) create mode 100644 lib/schemas.py diff --git a/README.md b/README.md index 1deb1ec..1a186c3 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ Messages passed to presto input queues must have the following structure, per ea Input Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "text": "The text to be processed by the vectorizer" } @@ -86,6 +87,26 @@ Input Message: Output Message: ``` { + "id": "A unique string ID that identifies the item being processed", + "callback_url": "A unique URL that will be requested upon completion", + "text": "The text to be processed by the vectorizer", + "response": [List of floats representing vectorization results], +} +``` + +#### Language ID +Input Message: +``` +{ + "id": "A unique string ID that identifies the item being processed", + "callback_url": "A unique URL that will be requested upon completion", + "text": "The text to be processed by the vectorizer" +} +``` +Output Message: +``` +{ + "id": "A unique string ID that identifies 
the item being processed", "callback_url": "A unique URL that will be requested upon completion", "text": "The text to be processed by the vectorizer", "response": [List of floats representing vectorization results], @@ -96,6 +117,7 @@ Output Message: Input Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "url": "The URL at which the media is located", } @@ -103,6 +125,7 @@ Input Message: Output Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "url": "The URL at which the media is located", "bucket": "bucket within which the .tmk file is stored", @@ -115,6 +138,7 @@ Output Message: Input Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "url": "The URL at which the media is located" } @@ -122,6 +146,7 @@ Input Message: Output Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "url": "The URL at which the media is located", "hash_value": [pyacoustid output hash value for the audio clip], @@ -133,6 +158,7 @@ Output Message: Input Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "url": "The URL at which the media is located", } @@ -140,6 +166,7 @@ Input Message: Output Message: ``` { + "id": "A unique string ID that identifies the item being processed", "callback_url": "A unique URL that will be requested upon completion", "url": "The URL at which the media is located", "hash_value": [pdqhasher output hash value for the image], diff --git a/lib/model/audio.py b/lib/model/audio.py index 94de5a3..4011ad4 100644 --- a/lib/model/audio.py +++ 
b/lib/model/audio.py @@ -2,9 +2,11 @@ import os import tempfile -from lib.model.model import Model import acoustid +from lib.model.model import Model +from lib import schemas + class Model(Model): def audio_hasher(self, filename: str) -> List[int]: """ @@ -18,8 +20,8 @@ def audio_hasher(self, filename: str) -> List[int]: except acoustid.FingerprintGenerationError: return [] - def fingerprint(self, audio: Dict[str, str]) -> Dict[str, Union[str, List[int]]]: - temp_file_name = self.get_tempfile_for_url(audio.get("body", {})["url"]) + def fingerprint(self, audio: schemas.InputMessage) -> Dict[str, Union[str, List[int]]]: + temp_file_name = self.get_tempfile_for_url(audio.body.url) try: hash_value = self.audio_hasher(temp_file_name) finally: diff --git a/lib/model/fasttext.py b/lib/model/fasttext.py index e4ad0f3..86165bb 100644 --- a/lib/model/fasttext.py +++ b/lib/model/fasttext.py @@ -1,9 +1,10 @@ from typing import Union, Dict, List -from lib.model.model import Model import fasttext from huggingface_hub import hf_hub_download +from lib.model.model import Model +from lib import schemas class FasttextModel(Model): def __init__(self): @@ -14,18 +15,18 @@ def __init__(self): self.model = fasttext.load_model(model_path) - def respond(self, docs: Union[List[Dict[str, str]], Dict[str, str]]) -> List[Dict[str, str]]: + def respond(self, docs: Union[List[schemas.InputMessage], schemas.InputMessage]) -> List[schemas.TextOutput]: """ Force messages as list of messages in case we get a singular item. Then, run fingerprint routine. Respond can probably be genericized across all models. 
""" if not isinstance(docs, list): docs = [docs] - detectable_texts = [e.get("body", {}).get("text") for e in docs] + detectable_texts = [e.body.text for e in docs] detected_langs = [] for text in detectable_texts: detected_langs.append(self.model.predict(text)[0][0]) for doc, detected_lang in zip(docs, detected_langs): - doc["response"] = detected_lang + doc.response = detected_lang return docs diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index 21d18eb..01c56ee 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -2,6 +2,8 @@ from sentence_transformers import SentenceTransformer from lib.model.model import Model +from lib import schemas + class GenericTransformerModel(Model): def __init__(self, model_name: str): """ @@ -11,17 +13,17 @@ def __init__(self, model_name: str): if model_name: self.model = SentenceTransformer(model_name) - def respond(self, docs: Union[List[Dict[str, str]], Dict[str, str]]) -> List[Dict[str, str]]: + def respond(self, docs: Union[List[schemas.InputMessage], schemas.InputMessage]) -> List[schemas.TextOutput]: """ Force messages as list of messages in case we get a singular item. Then, run fingerprint routine. Respond can probably be genericized across all models. 
""" if not isinstance(docs, list): docs = [docs] - vectorizable_texts = [e.get("body", {}).get("text") for e in docs] + vectorizable_texts = [e.body.text for e in docs] vectorized = self.vectorize(vectorizable_texts) for doc, vector in zip(docs, vectorized): - doc["response"] = vector + doc.response = vector return docs def vectorize(self, texts: List[str]) -> List[List[float]]: diff --git a/lib/model/image.py b/lib/model/image.py index 296a3c9..0f51b9c 100644 --- a/lib/model/image.py +++ b/lib/model/image.py @@ -5,6 +5,7 @@ from lib.model.model import Model from pdqhashing.hasher.pdq_hasher import PDQHasher +from lib import schemas class Model(Model): def compute_pdq(iobytes: io.BytesIO) -> str: @@ -16,20 +17,20 @@ def compute_pdq(iobytes: io.BytesIO) -> str: hash_and_qual = pdq_hasher.fromBufferedImage(iobytes) return hash_and_qual.getHash().dumpBitsFlat() - def get_iobytes_for_image(self, image: Dict[str, str]) -> io.BytesIO: + def get_iobytes_for_image(self, image: schemas.InputMessage) -> io.BytesIO: """ Read file as bytes after requesting based on URL. """ return io.BytesIO( urllib.request.urlopen( urllib.request.Request( - image.get("body", {})["url"], + image.body.url, headers={'User-Agent': 'Mozilla/5.0'} ) ).read() ) - def fingerprint(self, image: Dict[str, str]) -> Dict[str, str]: + def fingerprint(self, image: schemas.InputMessage) -> schemas.ImageOutput: """ Generic function for returning the actual response. 
""" diff --git a/lib/model/model.py b/lib/model/model.py index 2900dec..9140fbe 100644 --- a/lib/model/model.py +++ b/lib/model/model.py @@ -6,7 +6,7 @@ import urllib.request from lib.helpers import get_class - +from lib import schemas class Model(ABC): BATCH_SIZE = 1 def get_tempfile_for_url(self, url: str) -> str: @@ -33,17 +33,17 @@ def get_tempfile(self) -> Any: """ return tempfile.NamedTemporaryFile() - def fingerprint(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: + def fingerprint(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]: return [] - def respond(self, messages: Union[List[Dict[str, str]], Dict[str, str]]) -> List[Dict[str, str]]: + def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]: """ Force messages as list of messages in case we get a singular item. Then, run fingerprint routine. """ if not isinstance(messages, list): messages = [messages] for message in messages: - message["response"] = self.fingerprint(message) + message.response = self.fingerprint(message) return messages @classmethod diff --git a/lib/model/video.py b/lib/model/video.py index 8332779..422b13b 100644 --- a/lib/model/video.py +++ b/lib/model/video.py @@ -9,6 +9,7 @@ import urllib.request from lib.model.model import Model from lib import s3 +from lib import schemas class Model(Model): def __init__(self): @@ -40,7 +41,7 @@ def tmk_bucket(self) -> str: """ return "presto_tmk_videos" - def fingerprint(self, video: Dict[str, str]) -> Dict[str, str]: + def fingerprint(self, video: schemas.Message) -> schemas.VideoOutput: """ Main fingerprinting routine - download video to disk, get short hash, then calculate larger TMK hash and upload that to S3. 
diff --git a/lib/schemas.py b/lib/schemas.py
new file mode 100644
index 0000000..1746ecd
--- /dev/null
+++ b/lib/schemas.py
@@ -0,0 +1,70 @@
+from typing import Any, List, Union
+from pydantic import BaseModel, HttpUrl
+
+# Output hash values can be of different types.
+HashValue = Union[List[float], str, int]
+
+class TextInput(BaseModel):
+    """Body of a message asking for a piece of text to be processed."""
+    id: str
+    callback_url: HttpUrl
+    text: str
+
+class TextOutput(BaseModel):
+    """Text input fields plus the response produced for that text."""
+    id: str
+    callback_url: HttpUrl
+    text: str
+    response: Union[List[float], str]
+
+class VideoInput(BaseModel):
+    """Body of a message pointing at a video by URL."""
+    id: str
+    callback_url: HttpUrl
+    url: HttpUrl
+
+class VideoOutput(BaseModel):
+    """Video input fields plus the bucket/outfile location and hash value."""
+    id: str
+    callback_url: HttpUrl
+    url: HttpUrl
+    bucket: str
+    outfile: str
+    hash_value: HashValue
+
+class AudioInput(BaseModel):
+    """Body of a message pointing at an audio clip by URL."""
+    id: str
+    callback_url: HttpUrl
+    url: HttpUrl
+
+class AudioOutput(BaseModel):
+    """Audio input fields plus the computed hash value."""
+    id: str
+    callback_url: HttpUrl
+    url: HttpUrl
+    hash_value: HashValue
+
+class ImageInput(BaseModel):
+    """Body of a message pointing at an image by URL."""
+    id: str
+    callback_url: HttpUrl
+    url: HttpUrl
+
+class ImageOutput(BaseModel):
+    """Image input fields plus the computed hash value."""
+    id: str
+    callback_url: HttpUrl
+    url: HttpUrl
+    hash_value: HashValue
+
+# Declared after the input models: Python evaluates these annotations while
+# the class body executes, so every name in the Union must already exist.
+class Message(BaseModel):
+    """Envelope passed to the models: an input body plus the model's response."""
+    body: Union[TextInput, VideoInput, AudioInput, ImageInput]
+    # Defaulted to None: a Message is built from the body alone and the
+    # response attribute is assigned afterwards.
+    response: Any = None
+
+# The model modules annotate their arguments as schemas.InputMessage.
+InputMessage = Message
\ No newline at end of file
diff --git a/test/lib/model/test_audio.py b/test/lib/model/test_audio.py index b31093a..e3d9cc0 100644 --- a/test/lib/model/test_audio.py +++ b/test/lib/model/test_audio.py @@ -5,7 +5,7 @@ import acoustid from acoustid import FingerprintGenerationError - +from lib import schemas class TestAudio(unittest.TestCase): def setUp(self): self.audio_model = Model() @@ -15,9 +15,9 @@ def setUp(self): def test_fingerprint_audio_success(self, mock_request, mock_urlopen): mock_request.return_value = mock_request mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=open("data/test-audio.mp3", 'rb').read())) - audio = {"body": {"url": "https://example.com/audio.mp3"}} + audio = schemas.Message(body=schemas.AudioInput(url="https://example.com/audio.mp3")) result =
self.audio_model.fingerprint(audio) - mock_request.assert_called_once_with(audio["body"]["url"], headers={'User-Agent': 'Mozilla/5.0'}) + mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) mock_urlopen.assert_called_once_with(mock_request) self.assertEqual(list, type(result["hash_value"])) @@ -30,9 +30,9 @@ def test_fingerprint_audio_failure(self, mock_decode_fingerprint, mock_fingerpri mock_fingerprint_file.side_effect = FingerprintGenerationError("Failed to generate fingerprint") mock_request.return_value = mock_request mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=open("data/test-audio.mp3", 'rb').read())) - audio = {"body": {"url": "https://example.com/audio.mp3"}} + audio = schemas.Message(body=schemas.AudioInput(url="https://example.com/audio.mp3")) result = self.audio_model.fingerprint(audio) - mock_request.assert_called_once_with(audio["body"]["url"], headers={'User-Agent': 'Mozilla/5.0'}) + mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) mock_urlopen.assert_called_once_with(mock_request) self.assertEqual([], result["hash_value"]) diff --git a/test/lib/model/test_fasttext.py b/test/lib/model/test_fasttext.py index ef386fd..6a92bd2 100644 --- a/test/lib/model/test_fasttext.py +++ b/test/lib/model/test_fasttext.py @@ -5,6 +5,7 @@ import numpy as np from lib.model.fasttext import FasttextModel +from lib import schemas class TestFasttextModel(unittest.TestCase): def setUp(self): @@ -12,13 +13,13 @@ def setUp(self): self.mock_model = MagicMock() def test_respond(self): - query = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "今天是星期二"}}] + query = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="今天是星期二"))] response = self.model.respond(query) self.assertEqual(len(response), 2) - self.assertEqual(response[0]["response"], "__label__eng_Latn") - 
self.assertEqual(response[1]["response"], "__label__zho_Hans") + self.assertEqual(response[0].response, "__label__eng_Latn") + self.assertEqual(response[1].response, "__label__zho_Hans") if __name__ == '__main__': unittest.main() diff --git a/test/lib/model/test_fptg.py b/test/lib/model/test_fptg.py index 874e016..63f2aca 100644 --- a/test/lib/model/test_fptg.py +++ b/test/lib/model/test_fptg.py @@ -5,6 +5,7 @@ import numpy as np from lib.model.generic_transformer import GenericTransformerModel +from lib import schemas class TestMdebertaFilipino(unittest.TestCase): def setUp(self): @@ -12,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "I'm doing great, thanks!"}}] + texts = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="I'm doing great, thanks!"))] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts) @@ -21,12 +22,11 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = {"body": {"text": "Anong pangalan mo?"}} + query = schemas.Message(body=schemas.TextInput(text="Anong pangalan mo?")) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) - self.assertIn("response", response[0]) - self.assertEqual(response[0]["response"], [1, 2, 3]) + self.assertEqual(response[0].response, [1, 2, 3]) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/lib/model/test_generic.py b/test/lib/model/test_generic.py index 945f943..3fab472 100644 --- a/test/lib/model/test_generic.py +++ b/test/lib/model/test_generic.py @@ -5,6 +5,7 @@ import numpy as np from lib.model.generic_transformer import GenericTransformerModel +from lib import schemas class 
TestGenericTransformerModel(unittest.TestCase): def setUp(self): @@ -12,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "I'm doing great, thanks!"}}] + texts = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="I'm doing great, thanks!"))] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts) @@ -21,12 +22,11 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = {"body": {"text": "Anong pangalan mo?"}} + query = schemas.Message(body=schemas.TextInput(text="Anong pangalan mo?")) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) - self.assertIn("response", response[0]) - self.assertEqual(response[0]["response"], [1, 2, 3]) + self.assertEqual(response[0].response, [1, 2, 3]) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/lib/model/test_image.py b/test/lib/model/test_image.py index d61caa0..5ca1285 100644 --- a/test/lib/model/test_image.py +++ b/test/lib/model/test_image.py @@ -5,6 +5,7 @@ from typing import Dict from lib.model.image import Model +from lib import schemas class TestModel(unittest.TestCase): @@ -20,25 +21,25 @@ def test_get_iobytes_for_image(self, mock_urlopen): mock_response = Mock() mock_response.read.return_value = open("img/presto_flowchart.png", "rb").read() mock_urlopen.return_value = mock_response - image_dict = {"body": {"url": "http://example.com/image.jpg"}} - result = Model().get_iobytes_for_image(image_dict) + image = schemas.Message(body=schemas.ImageInput(url="http://example.com/image.jpg")) + result = Model().get_iobytes_for_image(image) self.assertIsInstance(result, io.BytesIO) 
self.assertEqual(result.read(), open("img/presto_flowchart.png", "rb").read()) @patch("urllib.request.urlopen") def test_get_iobytes_for_image_raises_error(self, mock_urlopen): mock_urlopen.side_effect = URLError('test error') - image_dict = {"body": {"url": "http://example.com/image.jpg"}} + image = schemas.Message(body=schemas.ImageInput(url="http://example.com/image.jpg")) with self.assertRaises(URLError): - Model().get_iobytes_for_image(image_dict) + Model().get_iobytes_for_image(image) @patch.object(Model, "get_iobytes_for_image") @patch.object(Model, "compute_pdq") def test_fingerprint(self, mock_compute_pdq, mock_get_iobytes_for_image): mock_compute_pdq.return_value = "1001" mock_get_iobytes_for_image.return_value = io.BytesIO(b"image_bytes") - image_dict = {"body": {"url": "http://example.com/image.jpg"}} - result = Model().fingerprint(image_dict) + image = schemas.Message(body=schemas.ImageInput(url="http://example.com/image.jpg")) + result = Model().fingerprint(image) self.assertEqual(result, {"hash_value": "1001"}) diff --git a/test/lib/model/test_indian_sbert.py b/test/lib/model/test_indian_sbert.py index 3a5ebba..8858280 100644 --- a/test/lib/model/test_indian_sbert.py +++ b/test/lib/model/test_indian_sbert.py @@ -5,6 +5,7 @@ import numpy as np from lib.model.generic_transformer import GenericTransformerModel +from lib import schemas class TestIndianSbert(unittest.TestCase): def setUp(self): @@ -12,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "I'm doing great, thanks!"}}] + texts = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="I'm doing great, thanks!"))] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts) @@ -21,12 +22,11 @@ def test_vectorize(self): 
self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = {"body": {"text": "What is the capital of India?"}} + query = schemas.Message(body=schemas.TextInput(text="What is the capital of India?")) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) - self.assertIn("response", response[0]) - self.assertEqual(response[0]["response"], [1, 2, 3]) + self.assertEqual(response[0].response, [1, 2, 3]) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/lib/model/test_meantokens.py b/test/lib/model/test_meantokens.py index 7a95b96..20e5b96 100644 --- a/test/lib/model/test_meantokens.py +++ b/test/lib/model/test_meantokens.py @@ -5,6 +5,7 @@ import numpy as np from lib.model.generic_transformer import GenericTransformerModel +from lib import schemas class TestXlmRBertBaseNliStsbMeanTokens(unittest.TestCase): def setUp(self): @@ -12,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [{"body": {"text": "Hello, how are you?"}}, {"body": {"text": "I'm doing great, thanks!"}}] + texts = [schemas.Message(body=schemas.TextInput(text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(text="I'm doing great, thanks!"))] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts) @@ -21,12 +22,11 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = {"body": {"text": "What is the capital of France?"}} + query = schemas.Message(body=schemas.TextInput(text="What is the capital of France?")) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) - self.assertIn("response", response[0]) - self.assertEqual(response[0]["response"], [1, 2, 3]) + self.assertEqual(response[0].response, [1, 2, 
3]) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/lib/model/test_model.py b/test/lib/model/test_model.py index 3f76c06..217c3af 100644 --- a/test/lib/model/test_model.py +++ b/test/lib/model/test_model.py @@ -1,12 +1,12 @@ import os import unittest from lib.model.model import Model - +from lib import schemas class TestModel(unittest.TestCase): def test_respond(self): model = Model() - self.assertEqual(model.respond({"body": {"text": "hello"}}), [{'response': [], 'body': {'text': 'hello'}}]) - + self.assertEqual(model.respond(schemas.Message(body=schemas.TextInput(text="hello"))), [{'response': [], 'body': {'text': 'hello'}}]) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/lib/model/test_video.py b/test/lib/model/test_video.py index 6c67982..67a4932 100644 --- a/test/lib/model/test_video.py +++ b/test/lib/model/test_video.py @@ -6,9 +6,10 @@ import uuid import pathlib import urllib.request +import tmkpy from lib.model.video import Model from lib import s3 -import tmkpy +from lib import schemas class TestVideoModel(unittest.TestCase): def setUp(self): @@ -29,7 +30,7 @@ def test_fingerprint_video(self, mock_pathlib, mock_upload_file_to_s3, mock_hash_video_output.getPureAverageFeature.return_value = "hash_value" mock_hash_video.return_value = mock_hash_video_output mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=open("data/test-video.mp4", "rb").read())) - self.video_model.fingerprint({"body": {"url": "http://example.com/video.mp4"}}) + self.video_model.fingerprint(schemas.Message(body=schemas.VideoInput(url="http://example.com/video.mp4"))) mock_urlopen.assert_called_once() mock_hash_video.assert_called_once_with(ANY, "/usr/local/bin/ffmpeg") @@ -54,7 +55,7 @@ def test_respond_with_single_video(self): self.assertEqual(result, [video]) def test_respond_with_multiple_videos(self): - videos = [{"body": {"url": "http://example.com/video1.mp4"}}, {"body": {"url": 
"http://example.com/video2.mp4"}}] + videos = [schemas.Message(body=schemas.VideoInput(url="http://example.com/video1.mp4")), schemas.Message(body=schemas.VideoInput(url="http://example.com/video2.mp4"))] mock_fingerprint = MagicMock() self.video_model.fingerprint = mock_fingerprint result = self.video_model.respond(videos)