Commit 0fd660f

wip
andreea-popescu-reef committed Sep 5, 2024
1 parent 96bd34e commit 0fd660f
Showing 2 changed files with 10 additions and 7 deletions.

src/compute_horde_prompt_gen/download_model.py (8 additions, 1 deletion)

@@ -4,6 +4,13 @@
     AutoModelForCausalLM,
 )
 
+from model import LLAMA3, PHI3
+
+MODEL_PATHS = {
+    LLAMA3: "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    PHI3: "microsoft/Phi-3-mini-4k-instruct",
+}
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Save huggingface model")
     parser.add_argument(
@@ -29,7 +36,7 @@
     args = parser.parse_args()
 
     model = AutoModelForCausalLM.from_pretrained(
-        args.model_name,
+        MODEL_PATHS[args.model_name],
         # either give token directly or assume logged in with huggingface-cli
         token=args.huggingface_token or True,
     )
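
For context, a minimal sketch of the download path after this change, assuming the same MODEL_PATHS mapping as the hunks above. The hypothetical huggingface_token value stands in for the parsed CLI argument; a real run also needs network access and, for the Llama repo, an accepted licence.

    # Minimal sketch, not the full script: resolve a short alias to a repo id.
    from transformers import AutoModelForCausalLM

    MODEL_PATHS = {
        "llama3": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "phi3": "microsoft/Phi-3-mini-4k-instruct",
    }

    huggingface_token = None  # hypothetical: no explicit token passed

    model = AutoModelForCausalLM.from_pretrained(
        # the CLI now passes "llama3"/"phi3"; the dict maps it to the repo id
        MODEL_PATHS["llama3"],
        # token=True makes huggingface_hub fall back to the token stored by
        # `huggingface-cli login`; a non-empty string is used directly
        token=huggingface_token or True,
    )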

src/compute_horde_prompt_gen/model.py (2 additions, 6 deletions)

@@ -5,10 +5,6 @@

 LLAMA3 = "llama3"
 PHI3 = "phi3"
-MODELS = {
-    LLAMA3: "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    PHI3: "microsoft/Phi-3-mini-4k-instruct",
-}
 
 PROMPT_ENDING = " }}assistant"

@@ -58,7 +54,7 @@ def __init__(self, model_name: str, model_path: str, quantize: bool = False):
         if self.model_name == LLAMA3:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
-    def tokenize_llama3(self, prompts: str, role: str) -> str:
+    def tokenize_llama3(self, prompts: list[str], role: str) -> str:
         role_templates = {
             "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
             "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
@@ -81,7 +77,7 @@ def tokenize(prompt: str) -> str:
         inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
         return inputs
 
-    def tokenize_phi3(self, prompts: str, role: str) -> str:
+    def tokenize_phi3(self, prompts: list[str], role: str) -> str:
         inputs = [{"role": role, "content": prompt} for prompt in prompts]
         inputs = self.tokenizer.apply_chat_template(
             prompts, add_generation_prompt=True, return_tensors="pt"
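
For comparison, a hedged sketch of batched chat templating with a recent transformers release (assumptions: a Phi-3 tokenizer and one single-message conversation per prompt; note the hunk above builds inputs but then passes prompts to apply_chat_template):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    prompts = ["Write a haiku.", "Summarize mutexes."]
    # one conversation (a list of role/content dicts) per prompt
    conversations = [[{"role": "user", "content": p}] for p in prompts]
    batch = tokenizer.apply_chat_template(
        conversations,
        add_generation_prompt=True,  # append the assistant header for generation
        return_tensors="pt",
        padding=True,
    )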
