Reward model question #926
I'm still learning — could you explain how this reinforcement learning model works?
After scoring with my script, the scores came out as zero. What could be causing this? Below is my scoring code:
```python
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained(
    "/root/autodl-tmp/xtuner/work_dirs/internlm2_chat_1_8b_reward_qlora_varlenattn_ultrafeedback_copy/iter_15230_hf",
    device_map="cuda",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "/root/autodl-tmp/xtuner/work_dirs/internlm2_chat_1_8b_reward_qlora_varlenattn_ultrafeedback_copy/iter_15230_hf",
    trust_remote_code=True,
)

chat_1 = [
    {"role": "user", "content": "Hello! What's your name?"},
    {"role": "assistant", "content": "My name is InternLM2! A helpful AI assistant. What can I do for you?"}
]
chat_2 = [
    {"role": "user", "content": "Hello! What's your name?"},
    {"role": "assistant", "content": "I have no idea."}
]

# get the reward score for a single chat
score1 = model.get_score(tokenizer, chat_1)
score2 = model.get_score(tokenizer, chat_2)
print("score1: ", score1)
print("score2: ", score2)
# >>> score1:  0.767578125
# >>> score2:  -2.22265625

# batch inference: get multiple scores at once
scores = model.get_scores(tokenizer, [chat_1, chat_2])
print("scores: ", scores)
# >>> scores:  [0.767578125, -2.22265625]

# compare whether chat_1 is better than chat_2
compare_res = model.compare(tokenizer, chat_1, chat_2)
print("compare_res: ", compare_res)
# >>> compare_res:  True

# rank multiple chats; returns the ranking index of each chat
# (the chat with the highest score gets ranking index 0)
rank_res = model.rank(tokenizer, [chat_1, chat_2])
print("rank_res: ", rank_res)  # lower index means higher score
# >>> rank_res:  [0, 1]
```
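For what it's worth, the `compare` and `rank` results above are derivable from the raw scores alone, which can help sanity-check the model's outputs. A minimal pure-Python sketch (`rank_from_scores` is a hypothetical helper, not part of the model's API):

```python
def rank_from_scores(scores):
    """Return the ranking index of each score (0 = highest score)."""
    # Indices sorted by descending score, then invert the permutation
    # so that position i holds the rank of scores[i].
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for rank, idx in enumerate(order):
        ranks[idx] = rank
    return ranks

scores = [0.767578125, -2.22265625]
print(rank_from_scores(scores))      # → [0, 1], matching model.rank
print(scores[0] > scores[1])         # → True, matching model.compare
```

If `get_score` genuinely returns 0.0 for every input, that suggests a loading problem (e.g. an untrained or mismatched value head) rather than a ranking bug, since any nonzero spread of scores would still produce a meaningful ordering.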