Skip to content

Commit

Permalink
Merge pull request #80 from meedan/CV2-4136-yake-model
Browse files Browse the repository at this point in the history
CV2-4136-yake-model
  • Loading branch information
ahmednasserswe authored Mar 23, 2024
2 parents 5c0763d + ddf51aa commit 200722c
Show file tree
Hide file tree
Showing 26 changed files with 370 additions and 76 deletions.
53 changes: 51 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ version: '3.9'
services:
app:
platform: linux/amd64
build:
build:
context: .
args:
- PRESTO_PORT=${PRESTO_PORT}
Expand All @@ -15,13 +15,62 @@ services:
env_file:
- ./.env_file
depends_on:
- elasticmq
elasticmq:
condition: service_healthy
links:
- elasticmq
volumes:
- ./:/app
ports:
- "8000:8000"
environment:
ROLE: server
elasticmq:
image: softwaremill/elasticmq
hostname: presto-elasticmq
ports:
- "9324:9324"
healthcheck:
test: ["CMD","wget","-q","-S","-O","-","127.0.0.1:9324/?Action=ListQueues"]
interval: 10s
timeout: 5s
retries: 10
image:
build: .
platform: linux/amd64
volumes:
- "./:/app"
env_file:
- ./.env_file
environment:
ROLE: worker
MODEL_NAME: image.Model
depends_on:
elasticmq:
condition: service_healthy
audio:
build: .
platform: linux/amd64
volumes:
- "./:/app"
env_file:
- ./.env_file
environment:
ROLE: worker
MODEL_NAME: audio.Model
depends_on:
elasticmq:
condition: service_healthy
yake:
build: .
platform: linux/amd64
volumes:
- "./:/app"
env_file:
- ./.env_file
environment:
ROLE: worker
MODEL_NAME: yake_keywords.Model
depends_on:
elasticmq:
condition: service_healthy
2 changes: 1 addition & 1 deletion lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def process_item(process_name: str, message: Dict[str, Any]):
queue_prefix = Queue.get_queue_prefix()
queue_suffix = Queue.get_queue_suffix()
queue = QueueWorker.create(process_name)
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.Message(body=message, model_name=process_name))
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_message({"body": message, "model_name": process_name}))
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
Expand Down
2 changes: 1 addition & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
hash_value = self.audio_hasher(temp_file_name)
finally:
os.remove(temp_file_name)
return hash_value
return {"hash_value": hash_value}
2 changes: 1 addition & 1 deletion lib/model/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
detected_langs.append({'language': model_language, 'script': model_script, 'score': model_certainty})

for doc, detected_lang in zip(docs, detected_langs):
doc.body.hash_value = detected_lang
doc.body.result = detected_lang
return docs
4 changes: 2 additions & 2 deletions lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
vectorizable_texts = [e.body.text for e in docs]
vectorized = self.vectorize(vectorizable_texts)
for doc, vector in zip(docs, vectorized):
doc.body.hash_value = vector
doc.body.result = vector
return docs

def vectorize(self, texts: List[str]) -> List[List[float]]:
"""
Vectorize the text! Run as batch.
"""
return self.model.encode(texts).tolist()
return {"hash_value": self.model.encode(texts).tolist()}
2 changes: 1 addition & 1 deletion lib/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ def process(self, image: schemas.Message) -> schemas.GenericItem:
"""
Generic function for returning the actual response.
"""
return self.compute_pdq(self.get_iobytes_for_image(image))
return {"hash_value": self.compute_pdq(self.get_iobytes_for_image(image))}
2 changes: 1 addition & 1 deletion lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
if not isinstance(messages, list):
messages = [messages]
for message in messages:
message.body.hash_value = self.process(message)
message.body.result = self.process(message)
return messages

@classmethod
Expand Down
52 changes: 52 additions & 0 deletions lib/model/yake_keywords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Dict
import io
import urllib.request

from lib.model.model import Model

from lib import schemas

import yake

class Model(Model):
    def run_yake(self, text: str,
                 language: str,
                 max_ngram_size: int,
                 deduplication_threshold: float,
                 deduplication_algo: str,
                 window_size: int,
                 num_of_keywords: int) -> Dict:
        """Run keyword/keyphrase extraction using the YAKE library
        (reference: https://github.com/LIAAD/yake).

        :param text: input text to extract keywords from
        :param language: two-letter language code (e.g. "en")
        :param max_ngram_size: maximum number of words per keyphrase
        :param deduplication_threshold: similarity limit above which candidate
            keyphrases are treated as duplicates
        :param deduplication_algo: deduplication function name (e.g. 'seqm')
        :param window_size: sliding-window size used by YAKE's co-occurrence stats
        :param num_of_keywords: maximum number of keywords to return
        :returns: dict of the form {"keywords": [(keyword, score), ...]}
        """
        custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold,
                                                    dedupFunc=deduplication_algo, windowsSize=window_size,
                                                    top=num_of_keywords, features=None)
        return {"keywords": custom_kw_extractor.extract_keywords(text)}

    def get_params(self, message: schemas.Message) -> dict:
        """Collect YAKE parameters from the message, falling back to defaults.

        :param message: incoming message; `body.text` is required,
            `body.parameters` may override any YAKE setting
        :returns: kwargs dict suitable for `run_yake`
        :raises AssertionError: if `body.text` is missing
        """
        params = {
            "text": message.body.text,
            "language": message.body.parameters.get("language", "en"),
            "max_ngram_size": message.body.parameters.get("max_ngram_size", 3),
            "deduplication_threshold": message.body.parameters.get("deduplication_threshold", 0.25),
            "deduplication_algo": message.body.parameters.get("deduplication_algo", 'seqm'),
            "window_size": message.body.parameters.get("window_size", 0),
            "num_of_keywords": message.body.parameters.get("num_of_keywords", 10)
        }
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # ValueError explicitly if validation must survive optimized runs.
        assert params.get("text") is not None
        return params

    def process(self, message: schemas.Message) -> Dict:
        """Extract keywords for the message text.

        :param message: incoming message with text and optional parameters
        :returns: dict of the form {"keywords": [(keyword, score), ...]}
            (wrapped into a YakeKeywordsResponse downstream by schemas.parse_message)
        """
        keywords = self.run_yake(**self.get_params(message))
        return keywords
2 changes: 1 addition & 1 deletion lib/queue/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def send_callbacks(self) -> List[schemas.Message]:
messages_with_queues = self.receive_messages(self.batch_size)
if messages_with_queues:
logger.info(f"About to respond to: ({messages_with_queues})")
bodies = [schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]
bodies = [schemas.parse_message(json.loads(message.body)) for message, queue in messages_with_queues]
for body in bodies:
self.send_callback(body)
self.delete_messages(messages_with_queues)
Expand Down
2 changes: 1 addition & 1 deletion lib/queue/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def extract_messages(messages_with_queues: List[Tuple], model: Model) -> List[sc
Returns:
- List[schemas.Message]: A list of Message objects ready for processing.
"""
return [schemas.Message(**{**json.loads(message.body), **{"model_name": model.model_name}})
return [schemas.parse_message({**json.loads(message.body), **{"model_name": model.model_name}})
for message, queue in messages_with_queues]

@staticmethod
Expand Down
56 changes: 34 additions & 22 deletions lib/schemas.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
from typing import Any, List, Optional, Union
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, ValidationError
from typing import Any, Dict, List, Optional, Union

# Output hash values can be of different types.
class GenericItem(BaseModel):
id: str
callback_url: Optional[str] = None
url: Optional[str] = None
text: Optional[str] = None
raw: Optional[dict] = {}

class MediaItem(GenericItem):
class MediaResponse(BaseModel):
hash_value: Optional[Any] = None

class VideoItem(MediaItem):
class VideoResponse(MediaResponse):
folder: Optional[str] = None
filepath: Optional[str] = None

class YakeKeywordsResponse(BaseModel):
keywords: Optional[List[List[Union[str, float]]]] = None

class GenericItem(BaseModel):
id: Union[str, int, float]
callback_url: Optional[str] = None
url: Optional[str] = None
text: Optional[str] = None
raw: Optional[Dict] = {}
parameters: Optional[Dict] = {}
result: Optional[Union[MediaResponse, VideoResponse, YakeKeywordsResponse]] = None

class Message(BaseModel):
body: Union[MediaItem, VideoItem]
body: GenericItem
model_name: str
@root_validator(pre=True)
def set_body(cls, values):
body = values.get("body")
model_name = values.get("model_name")
if model_name == "video__Model":
values["body"] = VideoItem(**values["body"]).dict()
if model_name in ["audio__Model", "image__Model", "fptg__Model", "indian_sbert__Model", "mean_tokens__Model", "fasttext__Model"]:
values["body"] = MediaItem(**values["body"]).dict()
return values

def parse_message(message_data: Dict) -> Message:
    """Build a Message from a raw dict, selecting the result schema by model name.

    Expects ``{"body": {...}, "model_name": "..."}``; the optional
    ``body["result"]`` sub-dict is parsed into the response model matching the
    model name (yake / video / generic media).

    :param message_data: raw message payload (e.g. deserialized from queue JSON)
    :returns: a validated Message instance
    """
    model_name = message_data['model_name']
    body_data = message_data['body']
    # `or {}` also guards against an explicit "result": None in the payload,
    # which would otherwise crash the model constructor with Model(**None).
    result_data = body_data.get('result') or {}
    if 'yake_keywords' in model_name:
        result_instance = YakeKeywordsResponse(**result_data)
    elif 'video' in model_name:
        result_instance = VideoResponse(**result_data)
    else:
        result_instance = MediaResponse(**result_data)
    # Copy without 'result' instead of deleting in place: callers keep a
    # reference to this dict (http.process_item echoes the original body back
    # to the client), so mutating it here would be an observable side effect.
    body_fields = {key: value for key, value in body_data.items() if key != 'result'}
    body_instance = GenericItem(**body_fields)
    body_instance.result = result_instance
    message_instance = Message(body=body_instance, model_name=model_name)
    return message_instance
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ boto3==1.18.64
pyacoustid==1.2.2
sentence-transformers==2.2.2
tmkpy==0.1.1
torch==1.9.0
torch==1.13.1
transformers>=4.6.0
fastapi==0.68.1
fastapi==0.109.1
uvicorn[standard]==0.19.0
httpx==0.23.1
huggingface-hub==0.11.0
fasttext==0.9.2
langcodes==3.3.0
requests==2.31.0
pytest==7.4.0
sentry-sdk==1.30.0
sentry-sdk==1.30.0
yake==0.4.8
8 changes: 4 additions & 4 deletions test/lib/model/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def test_process_audio_success(self, mock_fingerprint_file, mock_request, mock_u

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, model_name="audio__Model")
audio = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, "model_name": "audio__Model"})
result = self.audio_model.process(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
self.assertEqual(list, type(result))
self.assertEqual(dict, type(result))

@patch('urllib.request.urlopen')
@patch('urllib.request.Request')
Expand All @@ -45,11 +45,11 @@ def test_process_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_f

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, model_name="audio__Model")
audio = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, "model_name": "audio__Model"})
result = self.audio_model.process(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
self.assertEqual([], result)
self.assertEqual({'hash_value': []}, result)

if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions test/lib/model/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download):
mock_fasttext_load_model.return_value = self.mock_model
self.mock_model.predict.return_value = (['__label__eng_Latn'], np.array([0.9]))
model = FasttextModel() # Now it uses mocked functions
query = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="fasttext__Model")]
query = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fasttext__Model"})]
response = model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].body.hash_value, {'language': 'en', 'script': None, 'score': 0.9})
self.assertEqual(response[0].body.result, {'language': 'en', 'script': None, 'score': 0.9})

if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions test/lib/model/test_fptg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="fptg__Model"), schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, model_name="fptg__Model")]
texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
vectors = self.model.vectorize(texts)["hash_value"]
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, model_name="fptg__Model")
query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"})
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].body.hash_value, [1, 2, 3])
self.assertEqual(response[0].body.result, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions test/lib/model/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="fptg__Model"), schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, model_name="fptg__Model")]
texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
vectors = self.model.vectorize(texts)["hash_value"]
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, model_name="fptg__Model")
query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"})
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].body.hash_value, [1, 2, 3])
self.assertEqual(response[0].body.result, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 200722c

Please sign in to comment.