diff --git a/.env_file b/.env_file
index 469bb0d..4258271 100644
--- a/.env_file
+++ b/.env_file
@@ -1,9 +1,11 @@
 QUEUE_PREFIX=dev_
 PRESTO_PORT=8000
 DEPLOY_ENV=local
-# MODEL_NAME=mean_tokens.Model
-MODEL_NAME=audio.Model
+MODEL_NAME=mean_tokens.Model
+# MODEL_NAME=audio.Model
 S3_ENDPOINT=http://minio:9000
 AWS_DEFAULT_REGION=us-east-1
 AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE
 AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
+NUM_WORKERS=1
+#QUEUE_SUFFIX=.fifo
diff --git a/.env_file.test b/.env_file.test
index 469bb0d..3ff8b81 100644
--- a/.env_file.test
+++ b/.env_file.test
@@ -1,8 +1,8 @@
 QUEUE_PREFIX=dev_
 PRESTO_PORT=8000
 DEPLOY_ENV=local
-# MODEL_NAME=mean_tokens.Model
-MODEL_NAME=audio.Model
+MODEL_NAME=mean_tokens.Model
+#MODEL_NAME=audio.Model
 S3_ENDPOINT=http://minio:9000
 AWS_DEFAULT_REGION=us-east-1
 AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE
diff --git a/.github/workflows/ci-deploy.yml b/.github/workflows/ci-deploy.yml
index dbaa26e..20a5577 100644
--- a/.github/workflows/ci-deploy.yml
+++ b/.github/workflows/ci-deploy.yml
@@ -4,6 +4,8 @@ on:
   push:
     branches:
       - 'master'
+    tags:
+      - 'v*'
 
 permissions:
   id-token: write
@@ -79,6 +81,7 @@ jobs:
 
       - name: Kick off Terraform deploy in sysops/
         id: sysops-deploy
+        if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads/master')
         run: |
           curl \
@@ -86,66 +89,19 @@ jobs:
             -X POST \
             -H "Accept: application/vnd.github+json" \
             -H "Authorization: Bearer ${{ secrets.SYSOPS_RW_GITHUB_TOKEN }}" \
             -H "X-GitHub-Api-Version: 2022-11-28" \
             https://api.github.com/repos/meedan/sysops/actions/workflows/deploy_${{ github.event.repository.name }}.yml/dispatches \
-            -d '{"ref": "master", "inputs": {"git_sha": "${{ github.sha }}"}}'
+            -d '{"ref": "master", "inputs": {"git_sha": "${{ github.sha }}", "type": "push"}}'
 
-      - name: Send GitHub Action trigger data to Slack workflow on success
-        id: slack-api-notify-success
-        if: ${{ success() }}
-        uses: slackapi/slack-github-action@v1.23.0
-        with:
-          payload: |
-            {
-              "attachments": [
-                {
-                  "color": "#00FF00",
-                  "blocks": [
-                    {
-                      "type": "section",
-                      "text": {
-                        "type": "mrkdwn",
-                        "text": "Kicked off by: ${{ github.triggering_actor }}\nWorkflow: https://github.com/meedan/presto/actions/runs/${{ github.run_id }}"
-                      }
-                    },
-                    {
-                      "type": "section",
-                      "text": {
-                        "type": "mrkdwn",
-                        "text": "Presto Deploy:\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
-                      }
-                    }
-                  ]
-                }
-              ]
-            }
-        env:
-          SLACK_WEBHOOK_URL: ${{ secrets.CHECK_DEV_BOTS_SLACK_WEBHOOK_URL }}
-          SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK
-
-      - name: Send GitHub Action trigger data to Slack workflow on failure
-        id: slack-api-notify-failure
-        if: ${{ failure() }}
-        uses: slackapi/slack-github-action@v1.23.0
-        with:
-          payload: |
-            {
-              "attachments": [
-                {
-                  "color": "#FF0000",
-                  "blocks": [
-                    {
-                      "type": "section",
-                      "text": {
-                        "type": "mrkdwn",
-                        "text": "Presto Deploy failed\nWorkflow: https://github.com/meedan/presto/actions/runs/${{ github.run_id }}"
-                      }
-                    }
-                  ]
-                }
-              ]
-            }
-        env:
-          SLACK_WEBHOOK_URL: ${{ secrets.ITS_BOTS_SLACK_WEBHOOK_URL }}
-          SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK
+      - name: Kick off Terraform deploy in sysops/ (tag)
+        id: sysops-deploy-live
+        if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
+        run: |
+          curl \
+            -X POST \
+            -H "Accept: application/vnd.github+json" \
+            -H "Authorization: Bearer ${{ secrets.SYSOPS_RW_GITHUB_TOKEN }}" \
+            -H "X-GitHub-Api-Version: 2022-11-28" \
+            https://api.github.com/repos/meedan/sysops/actions/workflows/deploy_${{ github.event.repository.name }}.yml/dispatches \
+            -d '{"ref": "master", "inputs": {"git_sha": "${{ github.sha }}", "type": "tag"}}'
 
       - name: Reset cache
         id: reset-cache
diff --git a/docker-compose.yml b/docker-compose.yml
index c41a2d8..dab0773 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -11,6 +11,7 @@ services:
       - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
       - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
       - S3_ENDPOINT=${S3_ENDPOINT}
+      - QUEUE_SUFFIX=${QUEUE_SUFFIX}
     env_file:
       - ./.env_file
     depends_on:
@@ -23,4 +24,4 @@ services:
     image: softwaremill/elasticmq
     hostname: presto-elasticmq
     ports:
-      - "9324:9324"
\ No newline at end of file
+      - "9324:9324"
diff --git a/lib/helpers.py b/lib/helpers.py
index bc1d350..da9367a 100644
--- a/lib/helpers.py
+++ b/lib/helpers.py
@@ -5,7 +5,7 @@ def get_environment_setting(os_key: str) -> str:
     """
     Get environment variable helper. Could be augmented with credential store if/when necessary.
    """
-    return os.environ.get(os_key)
+    return os.environ.get(os_key, "")
 
 def get_setting(current_value: Any, default_os_key: str) -> Any:
     """
@@ -22,3 +22,4 @@ def get_class(prefix: str, class_name: str) -> Any:
     module = prefix+str.join(".", class_name.split('.')[:-1])
     module_obj = importlib.import_module(module)
     return getattr(module_obj, class_name.split('.')[-1])
+
diff --git a/lib/http.py b/lib/http.py
index bfb0818..01757b3 100644
--- a/lib/http.py
+++ b/lib/http.py
@@ -34,8 +34,9 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:
 def process_item(process_name: str, message: Dict[str, Any]):
     logger.info(message)
     queue_prefix = Queue.get_queue_prefix()
+    queue_suffix = Queue.get_queue_suffix()
     queue = QueueWorker.create(process_name)
-    queue.push_message(f"{queue_prefix}{process_name}", schemas.Message(body=message, model_name=process_name))
+    queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.Message(body=message, model_name=process_name))
     return {"message": "Message pushed successfully", "queue": process_name, "body": message}
 
 @app.post("/trigger_callback")
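Note on the `process_item` change above: the push target is now composed from the queue prefix, the model's process name, and the new optional suffix. A minimal sketch of the resulting name, assuming the hypothetical local values `QUEUE_PREFIX=dev_` and `QUEUE_SUFFIX=.fifo` from `.env_file`:

```python
# Mirrors the f-string in process_item; the values below are assumed examples.
queue_prefix = "dev_"             # from QUEUE_PREFIX ("." already rewritten to "__")
queue_suffix = ".fifo"            # from QUEUE_SUFFIX; empty string when unset
process_name = "mean_tokens__Model"

print(f"{queue_prefix}{process_name}{queue_suffix}")
# -> dev_mean_tokens__Model.fifo
```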
{"git_sha": "${{ github.sha }}", "type": "tag"}}' - name: Reset cache id: reset-cache diff --git a/docker-compose.yml b/docker-compose.yml index c41a2d8..dab0773 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,6 +11,7 @@ services: - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION} - S3_ENDPOINT=${S3_ENDPOINT} + - QUEUE_SUFFIX=${QUEUE_SUFFIX} env_file: - ./.env_file depends_on: @@ -23,4 +24,4 @@ services: image: softwaremill/elasticmq hostname: presto-elasticmq ports: - - "9324:9324" \ No newline at end of file + - "9324:9324" diff --git a/lib/helpers.py b/lib/helpers.py index bc1d350..da9367a 100644 --- a/lib/helpers.py +++ b/lib/helpers.py @@ -5,7 +5,7 @@ def get_environment_setting(os_key: str) -> str: """ Get environment variable helper. Could be augmented with credential store if/when necessary. """ - return os.environ.get(os_key) + return os.environ.get(os_key, "") or "" def get_setting(current_value: Any, default_os_key: str) -> Any: """ @@ -22,3 +22,4 @@ def get_class(prefix: str, class_name: str) -> Any: module = prefix+str.join(".", class_name.split('.')[:-1]) module_obj = importlib.import_module(module) return getattr(module_obj, class_name.split('.')[-1]) + diff --git a/lib/http.py b/lib/http.py index bfb0818..01757b3 100644 --- a/lib/http.py +++ b/lib/http.py @@ -34,8 +34,9 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]: def process_item(process_name: str, message: Dict[str, Any]): logger.info(message) queue_prefix = Queue.get_queue_prefix() + queue_suffix = Queue.get_queue_suffix() queue = QueueWorker.create(process_name) - queue.push_message(f"{queue_prefix}{process_name}", schemas.Message(body=message, model_name=process_name)) + queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.Message(body=message, model_name=process_name)) return {"message": "Message pushed successfully", "queue": process_name, "body": message} @app.post("/trigger_callback") diff --git a/lib/logger.py b/lib/logger.py index 3fe4bc0..2525941 100644 --- a/lib/logger.py +++ b/lib/logger.py @@ -35,4 +35,4 @@ logging.config.dictConfig(LOGGING_CONFIG) # This provides an easily accessible logger for other modules -logger = logging.getLogger(__name__) \ No newline at end of file +logger = logging.getLogger(__name__) diff --git a/lib/model/audio.py b/lib/model/audio.py index c3cdbd2..613dedc 100644 --- a/lib/model/audio.py +++ b/lib/model/audio.py @@ -26,4 +26,4 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]: hash_value = self.audio_hasher(temp_file_name) finally: os.remove(temp_file_name) - return hash_value \ No newline at end of file + return hash_value diff --git a/lib/model/fasttext.py b/lib/model/fasttext.py index 71e076f..5081c5a 100644 --- a/lib/model/fasttext.py +++ b/lib/model/fasttext.py @@ -43,4 +43,4 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s for doc, detected_lang in zip(docs, detected_langs): doc.body.hash_value = detected_lang - return docs \ No newline at end of file + return docs diff --git a/lib/model/fptg.py b/lib/model/fptg.py index a66b70d..2c0731f 100644 --- a/lib/model/fptg.py +++ b/lib/model/fptg.py @@ -6,4 +6,4 @@ def __init__(self): """ Init FPTG model. Fairly standard for all vectorizers. 
""" - super().__init__(MODEL_NAME) \ No newline at end of file + super().__init__(MODEL_NAME) diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index 40ad56c..94b831b 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -33,4 +33,4 @@ def vectorize(self, texts: List[str]) -> List[List[float]]: """ Vectorize the text! Run as batch. """ - return self.model.encode(texts).tolist() \ No newline at end of file + return self.model.encode(texts).tolist() diff --git a/lib/model/image.py b/lib/model/image.py index 80d0647..346e320 100644 --- a/lib/model/image.py +++ b/lib/model/image.py @@ -34,4 +34,4 @@ def process(self, image: schemas.Message) -> schemas.GenericItem: """ Generic function for returning the actual response. """ - return self.compute_pdq(self.get_iobytes_for_image(image)) \ No newline at end of file + return self.compute_pdq(self.get_iobytes_for_image(image)) diff --git a/lib/model/indian_sbert.py b/lib/model/indian_sbert.py index a20d681..db529ba 100644 --- a/lib/model/indian_sbert.py +++ b/lib/model/indian_sbert.py @@ -6,4 +6,4 @@ def __init__(self): """ Init IndianSbert model. Fairly standard for all vectorizers. """ - super().__init__(MODEL_NAME) \ No newline at end of file + super().__init__(MODEL_NAME) diff --git a/lib/model/mean_tokens.py b/lib/model/mean_tokens.py index f8094a9..a3f77e0 100644 --- a/lib/model/mean_tokens.py +++ b/lib/model/mean_tokens.py @@ -6,4 +6,4 @@ def __init__(self): """ Init MeanTokens model. Fairly standard for all vectorizers. """ - super().__init__(MODEL_NAME) \ No newline at end of file + super().__init__(MODEL_NAME) diff --git a/lib/model/model.py b/lib/model/model.py index 312aab6..eb99df7 100644 --- a/lib/model/model.py +++ b/lib/model/model.py @@ -55,4 +55,4 @@ def create(cls): abstraction for loading model based on os environment-specified model. """ model = get_class('lib.model.', os.environ.get('MODEL_NAME')) - return model() \ No newline at end of file + return model() diff --git a/lib/model/video.py b/lib/model/video.py index 517702f..2039c69 100644 --- a/lib/model/video.py +++ b/lib/model/video.py @@ -63,4 +63,4 @@ def process(self, video: schemas.Message) -> schemas.GenericItem: for file_path in [self.tmk_file_path(video_filename), temp_file_name]: if os.path.exists(file_path): os.remove(file_path) - return {"folder": self.tmk_bucket(), "filepath": self.tmk_file_path(video_filename), "hash_value": hash_value} \ No newline at end of file + return {"folder": self.tmk_bucket(), "filepath": self.tmk_file_path(video_filename), "hash_value": hash_value} diff --git a/lib/queue/processor.py b/lib/queue/processor.py index 3e3ba44..f655a88 100644 --- a/lib/queue/processor.py +++ b/lib/queue/processor.py @@ -5,16 +5,15 @@ from lib import schemas from lib.logger import logger -from lib.helpers import get_setting from lib.queue.queue import Queue class QueueProcessor(Queue): @classmethod - def create(cls, input_queue_name: str = None, batch_size: int = 10): + def create(cls, model_name: str = None, batch_size: int = 10): """ Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size. Pulls settings and then inits instance. 
""" - input_queue_name = Queue.get_queue_name(input_queue_name) + input_queue_name = Queue.get_output_queue_name(model_name) logger.info(f"Starting queue with: ('{input_queue_name}', {batch_size})") return QueueProcessor(input_queue_name, batch_size) @@ -24,7 +23,7 @@ def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_s """ super().__init__() self.input_queue_name = input_queue_name - self.input_queues = self.restrict_queues_to_suffix(self.get_or_create_queues(input_queue_name+"_output"), "_output") + self.input_queues = self.restrict_queues_to_suffix(self.get_or_create_queues(input_queue_name), Queue.get_queue_suffix()) self.all_queues = self.store_queue_map(self.input_queues) logger.info(f"Processor listening to queues of {self.all_queues}") self.batch_size = batch_size @@ -53,4 +52,4 @@ def send_callback(self, message): callback_url = message.body.callback_url requests.post(callback_url, json=message.dict()) except Exception as e: - logger.error(f"Callback fail! Failed with {e} on {callback_url} with message of {message}") \ No newline at end of file + logger.error(f"Callback fail! Failed with {e} on {callback_url} with message of {message}") diff --git a/lib/queue/queue.py b/lib/queue/queue.py index 382a8e8..6409562 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -5,7 +5,7 @@ import boto3 import botocore -from lib.helpers import get_setting, get_environment_setting +from lib.helpers import get_environment_setting from lib.logger import logger from lib import schemas SQS_MAX_BATCH_SIZE = 10 @@ -21,8 +21,18 @@ def get_queue_prefix(): return (get_environment_setting("QUEUE_PREFIX") or "").replace(".", "__") @staticmethod - def get_queue_name(input_queue_name): - return Queue.get_queue_prefix()+get_setting(input_queue_name, "MODEL_NAME").replace(".", "__") + def get_queue_suffix(): + return (get_environment_setting("QUEUE_SUFFIX") or "") + + @staticmethod + def get_input_queue_name(model_name=None): + name = model_name or get_environment_setting("MODEL_NAME").replace(".", "__") + return Queue.get_queue_prefix()+name+Queue.get_queue_suffix() + + @staticmethod + def get_output_queue_name(model_name=None): + name = model_name or get_environment_setting("MODEL_NAME").replace(".", "__") + return Queue.get_queue_prefix()+name+"_output"+Queue.get_queue_suffix() def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]) -> Dict[str, boto3.resources.base.ServiceResource]: """ @@ -43,7 +53,7 @@ def restrict_queues_to_suffix(self, queues: List[boto3.resources.base.ServiceRes """ When plucking input queues, we want to omit any queues that are our paired suffix queues.. 
""" - return [queue for queue in queues if self.queue_name(queue).endswith(suffix)] + return [queue for queue in queues if not suffix or suffix and self.queue_name(queue).endswith(suffix)] def restrict_queues_by_suffix(self, queues: List[boto3.resources.base.ServiceResource], suffix: str) -> List[boto3.resources.base.ServiceResource]: """ @@ -56,7 +66,16 @@ def create_queue(self, queue_name: str) -> boto3.resources.base.ServiceResource: Create queue by name - may not work in production owing to permissions - mostly a local convenience function """ logger.info(f"Queue {queue_name} doesn't exist - creating") - return self.sqs.create_queue(QueueName=queue_name) + attributes = {} + if queue_name.endswith('.fifo'): + attributes['FifoQueue'] = 'true' + # Optionally enable content-based deduplication for FIFO queues + attributes['ContentBasedDeduplication'] = 'true' + # Include other FIFO-specific attributes as needed + return self.sqs.create_queue( + QueueName=queue_name, + Attributes=attributes + ) def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]: """ @@ -65,6 +84,8 @@ def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.Ser try: found_queues = [q for q in self.sqs.queues.filter(QueueNamePrefix=queue_name)] exact_match_queues = [queue for queue in found_queues if queue.attributes['QueueArn'].split(':')[-1] == queue_name] + logger.info(f"found queues are {found_queues}") + logger.info(f"exact queues are {exact_match_queues}") if exact_match_queues: return exact_match_queues else: @@ -91,14 +112,6 @@ def get_sqs(self) -> boto3.resources.base.ServiceResource: logger.info(f"Using SQS Interface") return boto3.resource('sqs', region_name=get_environment_setting("AWS_DEFAULT_REGION")) - def get_output_queue_name(self, input_queue_name: str, output_queue_name: str = None) -> str: - """ - If output_queue_name was empty or None, set name for queue. - """ - if not output_queue_name: - output_queue_name = f'{input_queue_name}_output' - return output_queue_name - def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]: """ Group deletions so that we can run through a simplified set of batches rather than delete each item independently @@ -162,5 +175,8 @@ def push_message(self, queue_name: str, message: schemas.Message) -> schemas.Mes """ Actual SQS logic for pushing a message to a queue """ - self.find_queue_by_name(queue_name).send_message(MessageBody=json.dumps(message.dict())) - return message \ No newline at end of file + message_data = {"MessageBody": json.dumps(message.dict())} + if queue_name.endswith('.fifo'): + message_data["MessageGroupId"] = message.body.id + self.find_queue_by_name(queue_name).send_message(**message_data) + return message diff --git a/lib/queue/worker.py b/lib/queue/worker.py index 5fb3adc..c67d626 100644 --- a/lib/queue/worker.py +++ b/lib/queue/worker.py @@ -6,17 +6,16 @@ from lib.logger import logger from lib.queue.queue import Queue from lib.model.model import Model -from lib.helpers import get_setting TIMEOUT_SECONDS = int(os.getenv("WORK_TIMEOUT_SECONDS", "60")) class QueueWorker(Queue): @classmethod - def create(cls, input_queue_name: str = None): + def create(cls, model_name: str = None): """ Instantiate a queue worker. Must pass input_queue_name. Pulls settings and then inits instance. 
""" - input_queue_name = Queue.get_queue_name(input_queue_name) - output_queue_name = f"{input_queue_name}_output" + input_queue_name = Queue.get_input_queue_name(model_name) + output_queue_name = Queue.get_output_queue_name(model_name) logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}')") return QueueWorker(input_queue_name, output_queue_name) @@ -26,9 +25,10 @@ def __init__(self, input_queue_name: str, output_queue_name: str = None): """ super().__init__() self.input_queue_name = input_queue_name - self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output") + q_suffix = f"_output" + Queue.get_queue_suffix() + self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix) if output_queue_name: - self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name) + self.output_queue_name = Queue.get_output_queue_name() self.output_queues = self.get_or_create_queues(output_queue_name) self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues] for item in row]) logger.info(f"Worker listening to queues of {self.all_queues}") @@ -116,4 +116,4 @@ def delete_processed_messages(self, messages_with_queues: List[Tuple]): Parameters: - messages_with_queues (List[Tuple]): A list of tuples, each containing a message and its corresponding queue, to be deleted. """ - self.delete_messages(messages_with_queues) \ No newline at end of file + self.delete_messages(messages_with_queues) diff --git a/lib/s3.py b/lib/s3.py index 3fa83ae..f2a59f4 100644 --- a/lib/s3.py +++ b/lib/s3.py @@ -38,4 +38,4 @@ def upload_file_to_s3(bucket: str, filename: str): s3_client.upload_file(filename, bucket, file_name) logger.info(f'Successfully uploaded file {file_name} to S3 bucket.') except Exception as e: - logger.error(f'Failed to upload file {file_name} to S3 bucket: {e}') \ No newline at end of file + logger.error(f'Failed to upload file {file_name} to S3 bucket: {e}') diff --git a/lib/schemas.py b/lib/schemas.py index 2077c21..0969b5c 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -27,4 +27,4 @@ def set_body(cls, values): values["body"] = VideoItem(**values["body"]).dict() if model_name in ["audio__Model", "image__Model", "fptg__Model", "indian_sbert__Model", "mean_tokens__Model", "fasttext__Model"]: values["body"] = MediaItem(**values["body"]).dict() - return values \ No newline at end of file + return values diff --git a/lib/sentry.py b/lib/sentry.py index 6dd3058..8804a1e 100644 --- a/lib/sentry.py +++ b/lib/sentry.py @@ -6,4 +6,4 @@ dsn=get_environment_setting('sentry_sdk_dsn'), environment=get_environment_setting("DEPLOY_ENV"), traces_sample_rate=1.0, -) \ No newline at end of file +) diff --git a/main.py b/main.py index 68a18a9..5bd7699 100644 --- a/main.py +++ b/main.py @@ -1 +1 @@ -from lib.http import app \ No newline at end of file +from lib.http import app diff --git a/run_worker.py b/run_worker.py index 549239e..fc7353e 100644 --- a/run_worker.py +++ b/run_worker.py @@ -11,4 +11,4 @@ logger.info("Beginning work loop...") while True: - queue.process(model) \ No newline at end of file + queue.process(model) diff --git a/start_all.sh b/start_all.sh index dcbe097..5d20a58 100755 --- a/start_all.sh +++ b/start_all.sh @@ -18,4 +18,4 @@ do ) &. 
diff --git a/test/lib/model/test_audio.py b/test/lib/model/test_audio.py
index 8d32594..48c7bb5 100644
--- a/test/lib/model/test_audio.py
+++ b/test/lib/model/test_audio.py
@@ -52,4 +52,4 @@ def test_process_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_f
         self.assertEqual([], result)
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/test/lib/model/test_fasttext.py b/test/lib/model/test_fasttext.py
index d14b56a..07770da 100644
--- a/test/lib/model/test_fasttext.py
+++ b/test/lib/model/test_fasttext.py
@@ -21,4 +21,4 @@ def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download):
         self.assertEqual(response[0].body.hash_value, {'language': 'en', 'script': None, 'score': 0.9})
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/test/lib/model/test_image.py b/test/lib/model/test_image.py
index e376dc6..acd9201 100644
--- a/test/lib/model/test_image.py
+++ b/test/lib/model/test_image.py
@@ -48,4 +48,4 @@ def test_process(self, mock_compute_pdq, mock_get_iobytes_for_image):
 
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/test/lib/queue/test_processor.py b/test/lib/queue/test_processor.py
index 74518a0..f258818 100644
--- a/test/lib/queue/test_processor.py
+++ b/test/lib/queue/test_processor.py
@@ -52,4 +52,4 @@ def test_send_callback_failure(self, mock_post):
         self.assertIn("Failed with Request Failed! on http://example.com with message of", cm.output[0])
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py
index 3bdfd89..6adcaf2 100644
--- a/test/lib/queue/test_queue.py
+++ b/test/lib/queue/test_queue.py
@@ -5,6 +5,7 @@
 import numpy as np
 
 from lib.model.generic_transformer import GenericTransformerModel
+from lib.queue.queue import Queue
 from lib.queue.worker import QueueWorker
 from lib import schemas
 from test.lib.queue.fake_sqs_message import FakeSQSMessage
@@ -17,17 +18,17 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue
         self.model = GenericTransformerModel(None)
         self.model.model_name = "generic"
         self.mock_model = MagicMock()
-        self.queue_name_input = 'mean_tokens__Model'
-        self.queue_name_output = 'mean_tokens__Model_output'
+        self.queue_name_input = Queue.get_input_queue_name()
+        self.queue_name_output = Queue.get_output_queue_name()
 
         # Mock the SQS resource and the queues
         self.mock_sqs_resource = MagicMock()
         self.mock_input_queue = MagicMock()
-        self.mock_input_queue.url = "http://queue/mean_tokens__Model"
-        self.mock_input_queue.attributes = {"QueueArn": "queue:mean_tokens__Model"}
+        self.mock_input_queue.url = f"http://queue/{self.queue_name_input}"
+        self.mock_input_queue.attributes = {"QueueArn": f"queue:{self.queue_name_input}"}
         self.mock_output_queue = MagicMock()
-        self.mock_output_queue.url = "http://queue/mean_tokens__Model_output"
-        self.mock_output_queue.attributes = {"QueueArn": "queue:mean_tokens__Model_output"}
+        self.mock_output_queue.url = f"http://queue/{self.queue_name_output}"
+        self.mock_output_queue.attributes = {"QueueArn": f"queue:{self.queue_name_output}"}
 
         self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue]
         mock_boto_resource.return_value = self.mock_sqs_resource
@@ -35,8 +36,7 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue
         self.queue = QueueWorker(self.queue_name_input, self.queue_name_output)
 
     def test_get_output_queue_name(self):
-        self.assertEqual(self.queue.get_output_queue_name('test'), 'test_output')
-        self.assertEqual(self.queue.get_output_queue_name('test', 'new-output'), 'new-output')
+        self.assertEqual(self.queue.get_output_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_output').replace(".fifo", ""))
 
     def test_process(self):
         self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), self.mock_input_queue)])
@@ -163,4 +163,4 @@ def test_delete_processed_messages(self, mock_delete_messages):
         mock_delete_messages.assert_called_once_with(messages_with_queues)
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
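Beyond the mocked tests above, the new naming helpers can be sanity-checked without any SQS at all. A minimal sketch, assuming the repo's dependencies are installed and reusing the same hypothetical env values as the examples earlier:

```python
import os
import unittest
from unittest.mock import patch

from lib.queue.queue import Queue

class TestQueueNaming(unittest.TestCase):
    @patch.dict(os.environ, {"QUEUE_PREFIX": "dev_",
                             "MODEL_NAME": "mean_tokens.Model",
                             "QUEUE_SUFFIX": ".fifo"})
    def test_fifo_queue_name_pair(self):
        # Input and output names differ only by the "_output" infix;
        # the FIFO suffix stays at the very end of both.
        self.assertEqual(Queue.get_input_queue_name(), "dev_mean_tokens__Model.fifo")
        self.assertEqual(Queue.get_output_queue_name(), "dev_mean_tokens__Model_output.fifo")

if __name__ == "__main__":
    unittest.main()
```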