sotopia-pi prompt template fix (#66)
Jasonqi146 authored Apr 29, 2024
1 parent 79cc507 commit 1f90e24
Showing 3 changed files with 10 additions and 7 deletions.
13 changes: 8 additions & 5 deletions sotopia_pi_generate.py → sotopia_generate.py
@@ -3,6 +3,7 @@
from typing import TypeVar
from functools import cache
import logging
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -22,9 +23,10 @@
PromptTemplate,
)
from langchain.schema import BaseOutputParser, OutputParserException
import spaces

from message_classes import ActionType, AgentAction
from utils import format_docstring

from langchain_callback_handler import LoggingCallbackHandler

HF_TOKEN_KEY_FILE="./hf_token.key"
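
The newly added `import spaces` pulls in the Hugging Face Spaces helper package, which on ZeroGPU hardware attaches a GPU only inside a decorated call. A minimal sketch of that usual pattern, assuming the generation entry point is decorated elsewhere in this file (the function below is illustrative, not part of this diff):

```python
import spaces
import torch

@spaces.GPU  # request a GPU for the duration of this call on a ZeroGPU Space
def run_inference(model, tokenizer, prompt: str) -> str:
    # Hypothetical helper, not from this repository: tokenize, generate, decode.
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```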
@@ -89,7 +91,7 @@ def prepare_model(model_name):
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1",
cache_dir="./.cache",
device_map='cuda'
# device_map='cuda'
)
model = PeftModel.from_pretrained(model, model_name).to("cuda")

@@ -98,7 +100,7 @@ def prepare_model(model_name):
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1",
cache_dir="./.cache",
device_map='cuda',
# device_map='cuda',
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
@@ -114,7 +116,7 @@ def prepare_model(model_name):
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1",
cache_dir="./.cache",
device_map='cuda'
# device_map='cuda'
)

else:
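
Across all three `prepare_model` branches, `device_map='cuda'` is commented out at load time. Presumably this is because, on a ZeroGPU Space, CUDA is only attached inside a `@spaces.GPU` call, so pinning weights to the GPU during `from_pretrained` is not viable; the model is instead moved explicitly afterwards, as the `.to("cuda")` in the PEFT branch already does. A sketch of the assumed intent:

```python
from transformers import AutoModelForCausalLM

# Load on CPU at startup; defer GPU placement until a device is available.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    cache_dir="./.cache",
    # device_map='cuda',  # would place weights on CUDA at load time
)
model = model.to("cuda")  # done later, once the GPU has been attached
```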
@@ -131,7 +133,7 @@ def obtain_chain_hf(
max_tokens: int = 2700
) -> LLMChain:
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(template=template, input_variables=input_variables)
prompt=PromptTemplate(template="[INST] " + template + " [/INST]", input_variables=input_variables)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
model, tokenizer = prepare_model(model_name)
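
This is the prompt-template fix named in the commit title: the raw template is now wrapped in Mistral's `[INST] ... [/INST]` instruction markers, so the Instruct-tuned model sees the turn format it was trained on. A minimal illustration of the wrapping (the template text and variables below are made-up placeholders, not the project's actual prompt):

```python
from langchain.prompts import PromptTemplate

template = "You are {agent}. Continue the conversation:\n{history}"
prompt = PromptTemplate(
    template="[INST] " + template + " [/INST]",
    input_variables=["agent", "history"],
)
print(prompt.format(agent="Alice", history="Bob: Hi there!"))
# [INST] You are Alice. Continue the conversation:
# Bob: Hi there! [/INST]
```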
@@ -148,6 +150,7 @@ def obtain_chain_hf(
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
return chain


def generate(
model_name: str,
template: str,
2 changes: 1 addition & 1 deletion sotopia_space/chat.py
@@ -6,7 +6,7 @@
import json
from collections import defaultdict
from utils import Environment, Agent, get_context_prompt, dialogue_history_prompt
from sotopia_pi_generate import prepare_model, generate_action
from sotopia_generate import prepare_model, generate_action
from sotopia_space.constants import MODEL_OPTIONS

DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
2 changes: 1 addition & 1 deletion sotopia_space/constants.py
@@ -5,7 +5,7 @@
"cmu-lti/sotopia-pi-mistral-7b-BC_SR",
"cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
"mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
# "mistralai/Mixtral-8x7B-Instruct-v0.1", # TODO: Add these model
# "togethercomputer/llama-2-7b-chat",
# "togethercomputer/llama-2-70b-chat",
# "togethercomputer/mpt-30b-chat",
