Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CV2-4719 initial ideas around caching layer in presto #92

Merged
merged 10 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env_file.example
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ OTEL_SERVICE_NAME=my-service-name
OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf
OTEL_EXPORTER_OTLP_ENDPOINT="https://api.honeycomb.io"
OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=XXX"
HONEYCOMB_API_ENDPOINT="https://api.honeycomb.io"
HONEYCOMB_API_ENDPOINT="https://api.honeycomb.io"
REDIS_URL="redis://redis:6379/0"
CACHE_DEFAULT_TTL=86400
4 changes: 3 additions & 1 deletion .env_file.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ OTEL_SERVICE_NAME=my-service-name
OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf
OTEL_EXPORTER_OTLP_ENDPOINT="https://api.honeycomb.io"
OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=XXX"
HONEYCOMB_API_ENDPOINT="https://api.honeycomb.io"
HONEYCOMB_API_ENDPOINT="https://api.honeycomb.io"
REDIS_URL="redis://redis:6379/0"
CACHE_DEFAULT_TTL=86400
20 changes: 20 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ services:
- OTEL_EXPORTER_OTLP_HEADERS=${OTEL_EXPORTER_OTLP_HEADERS}
- HONEYCOMB_API_KEY=${HONEYCOMB_API_KEY}
- HONEYCOMB_API_ENDPOINT=${HONEYCOMB_API_ENDPOINT}
- REDIS_URL=${REDIS_URL}
- CACHE_DEFAULT_TTL=${CACHE_DEFAULT_TTL}
env_file:
- ./.env_file
depends_on:
elasticmq:
condition: service_healthy
redis:
condition: service_healthy
links:
- elasticmq
volumes:
Expand All @@ -41,6 +45,16 @@ services:
interval: 10s
timeout: 5s
retries: 10
redis:
image: redis:latest
hostname: presto-redis
ports:
- "6379:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
image:
build: .
platform: linux/amd64
Expand All @@ -54,6 +68,8 @@ services:
depends_on:
elasticmq:
condition: service_healthy
redis:
condition: service_healthy
audio:
build: .
platform: linux/amd64
Expand All @@ -67,6 +83,8 @@ services:
depends_on:
elasticmq:
condition: service_healthy
redis:
condition: service_healthy
yake:
build: .
platform: linux/amd64
Expand All @@ -80,3 +98,5 @@ services:
depends_on:
elasticmq:
condition: service_healthy
redis:
condition: service_healthy
54 changes: 54 additions & 0 deletions lib/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import redis
import json
from typing import Any, Optional
from lib.helpers import get_environment_setting

REDIS_URL = get_environment_setting("REDIS_URL")
# TTL falls back to 24 hours (86400s) when CACHE_DEFAULT_TTL is unset or empty.
DEFAULT_TTL = int(get_environment_setting("CACHE_DEFAULT_TTL") or 86400)
CACHE_PREFIX = "presto_media_cache:"


class Cache:
    """Thin static wrapper around Redis for caching model results as JSON."""

    @staticmethod
    def get_client() -> redis.Redis:
        """
        Build a Redis client from the configured REDIS_URL.

        Returns:
            redis.Redis: Redis client instance.
        """
        return redis.Redis.from_url(REDIS_URL)

    @staticmethod
    def get_cached_result(content_hash: str, reset_ttl: bool = True, ttl: int = DEFAULT_TTL) -> Optional[Any]:
        """
        Look up the cached value stored under content_hash.

        Args:
            content_hash (str): Key for the cached content; falsy keys are ignored.
            reset_ttl (bool): When True (default), refresh the key's TTL on a hit.
            ttl (int): TTL in seconds used when refreshing. Defaults to DEFAULT_TTL.

        Returns:
            Optional[Any]: The deserialized cached value, or None on a miss.
        """
        if not content_hash:
            return None
        client = Cache.get_client()
        key = CACHE_PREFIX + content_hash
        payload = client.get(key)
        if payload is None:
            return None
        if reset_ttl:
            client.expire(key, ttl)
        return json.loads(payload)

    @staticmethod
    def set_cached_result(content_hash: str, result: Any, ttl: int = DEFAULT_TTL) -> None:
        """
        Serialize result as JSON and store it under content_hash with a TTL.

        Args:
            content_hash (str): Key for the cached content; falsy keys are ignored.
            result (Any): JSON-serializable value to cache.
            ttl (int): TTL in seconds. Defaults to DEFAULT_TTL.
        """
        if not content_hash:
            return
        key = CACHE_PREFIX + content_hash
        Cache.get_client().setex(key, ttl, json.dumps(result))
13 changes: 12 additions & 1 deletion lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from lib.helpers import get_class
from lib import schemas
from lib.cache import Cache
class Model(ABC):
BATCH_SIZE = 1
def __init__(self):
Expand Down Expand Up @@ -39,14 +40,24 @@ def get_tempfile(self) -> Any:
def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
return []

def get_response(self, message: schemas.Message) -> schemas.GenericItem:
    """
    Return the model result for a message, consulting the cache first.

    Looks up the message's content hash in the cache; on a miss, runs the
    model's process() routine and stores the result before returning it.

    Args:
        message (schemas.Message): The message to resolve.

    Returns:
        The cached result, or the freshly computed result from process().
    """
    result = Cache.get_cached_result(message.body.content_hash)
    # Only an actual cache miss (None) triggers recomputation: a falsy but
    # valid result (e.g. an empty keyword list) would otherwise be
    # reprocessed and re-written to the cache on every call.
    if result is None:
        result = self.process(message)
        Cache.set_cached_result(message.body.content_hash, result)
    return result

def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
    """
    Coerce the input into a list of messages, then attach each message's
    (possibly cached) model result to its body in place.
    """
    batch = messages if isinstance(messages, list) else [messages]
    for msg in batch:
        msg.body.result = self.get_response(msg)
    return batch

@classmethod
Expand Down
1 change: 1 addition & 0 deletions lib/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class YakeKeywordsResponse(BaseModel):

class GenericItem(BaseModel):
id: Union[str, int, float]
content_hash: Optional[str] = None
callback_url: Optional[str] = None
url: Optional[str] = None
text: Optional[str] = None
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ sentry-sdk==1.30.0
yake==0.4.8
opentelemetry-api==1.24.0
opentelemetry-exporter-otlp-proto-http==1.24.0
opentelemetry-sdk==1.24.0
opentelemetry-sdk==1.24.0
redis==5.0.6
12 changes: 10 additions & 2 deletions test/lib/model/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,23 @@ def test_tmk_program_name(self):
result = self.video_model.tmk_program_name()
self.assertEqual(result, "PrestoVideoEncoder")

@patch('lib.cache.Cache.get_cached_result')
@patch('lib.cache.Cache.set_cached_result')
def test_respond_with_single_video(self, mock_cache_set, mock_cache_get):
    # patch decorators are applied bottom-up, so the innermost patch
    # (set_cached_result) arrives as the first mock argument.
    mock_cache_get.return_value = None
    mock_cache_set.return_value = True
    video = schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"})
    process_stub = MagicMock()
    self.video_model.process = process_stub
    self.assertEqual(self.video_model.respond(video), [video])
    process_stub.assert_called_once_with(video)

def test_respond_with_multiple_videos(self):
@patch('lib.cache.Cache.get_cached_result')
@patch('lib.cache.Cache.set_cached_result')
def test_respond_with_multiple_videos(self, mock_cache_set, mock_cache_get):
mock_cache_get.return_value = None
mock_cache_set.return_value = True
videos = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video2.mp4"}, "model_name": "video__Model"})]
mock_process = MagicMock()
self.video_model.process = mock_process
Expand Down
8 changes: 4 additions & 4 deletions test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,19 @@ def test_delete_messages_from_queue(self, mock_logger):
mock_logger.assert_called_with(f"Deleting message: {mock_messages[-1]}")

def test_push_message(self):
    """Pushing a message serializes it to JSON and returns it unchanged."""
    outbound = schemas.parse_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"})
    pushed = self.queue.push_message(self.queue_name_output, outbound)
    # The serialized body must include every GenericItem field, nulls included.
    expected_body = '{"body": {"id": 1, "content_hash": null, "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}, "parameters": {}, "result": {"hash_value": null}}, "model_name": "mean_tokens__Model", "retry_count": 0}'
    self.mock_output_queue.send_message.assert_called_once_with(MessageBody=expected_body)
    self.assertEqual(pushed, outbound)

def test_push_to_dead_letter_queue(self):
    """Dead-lettered messages are serialized and sent to the DLQ."""
    doomed = schemas.parse_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"})
    self.queue.push_to_dead_letter_queue(doomed)
    # Same serialization contract as push_message, but on the DLQ client.
    expected_body = '{"body": {"id": 1, "content_hash": null, "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}, "parameters": {}, "result": {"hash_value": null}}, "model_name": "mean_tokens__Model", "retry_count": 0}'
    self.mock_dlq_queue.send_message.assert_called_once_with(MessageBody=expected_body)

def test_increment_message_error_counts_exceed_max_retries(self):
message_body = {
Expand Down
52 changes: 52 additions & 0 deletions test/lib/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
from unittest.mock import patch, MagicMock
from lib.cache import Cache

# Fixture that replaces the Redis class used inside lib.cache with a mock.
@pytest.fixture
def mock_redis_client():
    """Patch lib.cache's redis.Redis and yield the patched class mock."""
    with patch('lib.cache.redis.Redis') as patched_redis:
        yield patched_redis

def test_set_cached_result(mock_redis_client):
    """setex must be invoked once with the prefixed key, TTL, and JSON body."""
    client_mock = mock_redis_client.from_url.return_value
    key = "test_hash"
    Cache.set_cached_result(key, {"data": "example"}, 3600)
    client_mock.setex.assert_called_once_with('presto_media_cache:' + key, 3600, '{"data": "example"}')

def test_get_cached_result_exists(mock_redis_client):
    """A cache hit returns the decoded JSON and refreshes the TTL."""
    client_mock = mock_redis_client.from_url.return_value
    client_mock.get.return_value = '{"data": "example"}'
    fetched = Cache.get_cached_result("test_hash", reset_ttl=True, ttl=3600)
    assert fetched == {"data": "example"}
    client_mock.expire.assert_called_once_with('presto_media_cache:' + "test_hash", 3600)

def test_get_cached_result_not_exists(mock_redis_client):
    """A cache miss returns None and must not touch the TTL."""
    client_mock = mock_redis_client.from_url.return_value
    client_mock.get.return_value = None
    assert Cache.get_cached_result("test_hash") is None
    client_mock.expire.assert_not_called()

def test_get_cached_result_no_ttl_reset(mock_redis_client):
    """With reset_ttl=False a hit is returned without refreshing the TTL."""
    client_mock = mock_redis_client.from_url.return_value
    client_mock.get.return_value = '{"data": "example"}'
    assert Cache.get_cached_result("test_hash", reset_ttl=False) == {"data": "example"}
    client_mock.expire.assert_not_called()
Loading