Add online trainers #204
Conversation
open_instruct/online_sft_trainer.py (outdated)
with torch.no_grad():
    queries = data["input_ids"].to(device)
    # repeat interleave [q1, q2, q3] -> [q1, q1, q1, q2, q2, q2, q3, q3, q3]
    queries = queries.repeat_interleave(args.num_generation_per_prompt, 0)
Here is how the online SFT logic works: we first repeat-interleave the queries, i.e. [q1, q2, q3] -> [q1, q1, q1, q2, q2, q2, q3, q3, q3]. This means queries[0] == queries[1].
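For illustration, a tiny standalone snippet (toy tensors, not from the trainer) showing the ordering that repeat_interleave produces:

import torch

# toy illustration: 3 queries, 3 generations per prompt
queries = torch.tensor([[1], [2], [3]])          # stand-ins for q1, q2, q3
repeated = queries.repeat_interleave(3, dim=0)   # -> q1, q1, q1, q2, q2, q2, q3, q3, q3
print(repeated.squeeze(1).tolist())              # [1, 1, 1, 2, 2, 2, 3, 3, 3]
assert torch.equal(repeated[0], repeated[1])     # queries[0] == queries[1] after interleaving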
open_instruct/online_sft_trainer.py (outdated)
# online SFT logic:
# at this point, we have interleave-repeated the queries `args.num_generation_per_prompt` times
# say, we have 64 queries repeated 10 times, so we have 640 responses
# we reshape the scores to (64, 10), so each row holds one query's 10 scores
# we then find the index of each query's best response
scores_per_query = scores.reshape(args.local_batch_size, args.num_generation_per_prompt)
best_idxes = scores_per_query.argmax(1)
worst_idxes = scores_per_query.argmin(1)
best_idxes_offset = (
    best_idxes + torch.arange(args.local_batch_size, device=device) * args.num_generation_per_prompt
)
worst_idxes_offset = (
    worst_idxes + torch.arange(args.local_batch_size, device=device) * args.num_generation_per_prompt
)
best_query_responses = query_responses[best_idxes_offset]
# worst_query_responses = query_responses[worst_idxes_offset]  # TODO: maybe interesting to see the worst responses
best_scores = scores[best_idxes_offset]
worst_scores = scores[worst_idxes_offset]
scores_margin = best_scores - worst_scores
Then we find the indices corresponding to the best scores and take those best query responses for online SFT.
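A quick toy check of this indexing (hypothetical numbers: 2 prompts, 3 generations each, assuming the interleaved layout described above):

import torch

# scores arrive in interleaved order: [p1_g1, p1_g2, p1_g3, p2_g1, p2_g2, p2_g3]
num_generation_per_prompt, local_batch_size = 3, 2
scores = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.2, 0.3])

# one row per prompt, one column per generation
scores_per_query = scores.reshape(local_batch_size, num_generation_per_prompt)
best_idxes = scores_per_query.argmax(1)  # tensor([1, 0])
# map each per-prompt winner back to a flat index into the interleaved batch
best_idxes_offset = best_idxes + torch.arange(local_batch_size) * num_generation_per_prompt
print(best_idxes_offset.tolist())        # [1, 3] -> p1's 2nd generation, p2's 1st

With args.local_batch_size = 64 and args.num_generation_per_prompt = 10, this is exactly the (64, 10) reshape in the snippet above.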
Code looks good to me! I'll test it out soon
732e3e4 to 34d5bbc
Looks great Costa! I have a small nitpick about dpo_utils.py for the DPO data processing at L173, which was contributed by Nathan and Jacob, I guess. I usually prefer processing the data (in this case, concatenating the chosen and rejected ids) before passing it to the DataLoader, to avoid looping over the batch. But all good since the tensors are not placed on the GPUs yet.
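For reference, a minimal sketch of the kind of refactor being suggested (hypothetical helper and field names, not the PR's code): concatenate the chosen and rejected ids once in a collate_fn, so the training loop never iterates over the batch.

import torch
from torch.utils.data import DataLoader

# hypothetical collate_fn: concatenate chosen and rejected ids once per batch,
# before the tensors ever reach the training loop / GPU
def concat_chosen_rejected_collate(batch, pad_token_id=0):
    # batch is a list of dicts with "chosen_input_ids" and "rejected_input_ids" (1-D tensors)
    sequences = [ex["chosen_input_ids"] for ex in batch] + [ex["rejected_input_ids"] for ex in batch]
    # pad to the longest sequence; chosen examples occupy the first half, rejected the second
    concatenated = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
    return {"concatenated_input_ids": concatenated}

# usage sketch with toy data
toy_dataset = [
    {"chosen_input_ids": torch.tensor([1, 2, 3]), "rejected_input_ids": torch.tensor([1, 2])},
    {"chosen_input_ids": torch.tensor([4, 5]), "rejected_input_ids": torch.tensor([4, 5, 6, 7])},
]
loader = DataLoader(toy_dataset, batch_size=2, collate_fn=concat_chosen_rejected_collate)
print(next(iter(loader))["concatenated_input_ids"].shape)  # torch.Size([4, 4])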
@nouhadziri, that's a good point. We can look into refactoring. Merging as is for now. Thanks @nouhadziri and @ValentinaPy for the review.
The online trainers are ready for review! The docs are available here if you want to try them out: https://github.com/allenai/open-instruct/blob/online-trainers/docs/algorithms/online_dpo.md
Check out this wandb report: https://wandb.ai/ai2-llm/open_instruct_internal/reports/PPO-vs-online-DPO--Vmlldzo5MzM3NDU0
Screen.Recording.2024-09-11.at.4.40.59.PM.mov
Implemented auto resume as well, but only tested it with small models. Larger models may take more time to save, and testing them is blocked by https://github.com/allenai/beaker/issues/5420.