diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index 2e8bb56..5fb88b6 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -1,3 +1,4 @@ +import os from typing import Union, Dict, List from sentence_transformers import SentenceTransformer @@ -11,7 +12,7 @@ def __init__(self, model_name: str): """ self.model = None if model_name: - self.model = SentenceTransformer(model_name) + self.model = SentenceTransformer(model_name, cache_folder=os.getenv("MODEL_DIR", "./models")) def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.TextOutput]: """