diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index 44370b1..6d0070d 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -14,6 +14,7 @@ class TestQueueWorker(unittest.TestCase): @patch('lib.helpers.get_environment_setting', return_value='us-west-1') def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queues_by_suffix): self.model = GenericTransformerModel(None) + self.model.model_name = "generic" self.mock_model = MagicMock() self.queue_name_input = 'mean_tokens__Model' self.queue_name_output = 'mean_tokens__Model_output' @@ -113,7 +114,7 @@ def test_extract_messages(self): (FakeSQSMessage(receipt_handle="blah", body=json.dumps({"text": "Test message 1", "model_name": "TestModel"})), self.mock_input_queue), (FakeSQSMessage(receipt_handle="blah", body=json.dumps({"text": "Test message 2", "model_name": "TestModel"})), self.mock_input_queue) ] - extracted_messages = QueueWorker.extract_messages(messages_with_queues) + extracted_messages = QueueWorker.extract_messages(messages_with_queues, self.model) self.assertEqual(len(extracted_messages), 2) self.assertEqual(extracted_messages[0].text, "Test message 1") self.assertEqual(extracted_messages[1].text, "Test message 2")