diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py index bf985aa..83485fc 100644 --- a/src/compute_horde_prompt_gen/model.py +++ b/src/compute_horde_prompt_gen/model.py @@ -81,7 +81,7 @@ def tokenize_phi3(self, prompts: list[str], role: str) -> str: inputs = [{"role": role, "content": prompt} for prompt in prompts] inputs = self.tokenizer.apply_chat_template( prompts, add_generation_prompt=True, return_tensors="pt" - ).to("cuda") + ) return inputs def generate(