
Add online trainers #204

Merged: 74 commits into main on Sep 13, 2024
Conversation

@vwxyzjn (Collaborator) commented on Jul 24, 2024:

The online trainers are ready for review! The docs are available here if you want to try them out: https://github.com/allenai/open-instruct/blob/online-trainers/docs/algorithms/online_dpo.md

Check out this wandb report: https://wandb.ai/ai2-llm/open_instruct_internal/reports/PPO-vs-online-DPO--Vmlldzo5MzM3NDU0

Screen.Recording.2024-09-11.at.4.40.59.PM.mov

Implemented auto resume as well, but only tested it with small models. Larger models may take more time to save, and saving them is currently blocked by https://github.com/allenai/beaker/issues/5420


with torch.no_grad():
    queries = data["input_ids"].to(device)
    # repeat interleave [q1, q2, q3] -> [q1, q1, q1, q2, q2, q2, q3, q3, q3]
    queries = queries.repeat_interleave(args.num_generation_per_prompt, 0)
@vwxyzjn (Collaborator, Author) commented:

Here is how the online SFT logic works:

We first repeat-interleave the queries: [q1, q2, q3] -> [q1, q1, q1, q2, q2, q2, q3, q3, q3]. This means queries[0] == queries[1].
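For reference, a tiny standalone sketch (made-up token ids, with a small num_generation_per_prompt, not the trainer's actual config) of what that repeat_interleave call produces:

import torch

num_generation_per_prompt = 3  # assumed small value for illustration
queries = torch.tensor([[1, 2], [3, 4], [5, 6]])  # [q1, q2, q3]

# Each row is duplicated num_generation_per_prompt times, back to back.
repeated = queries.repeat_interleave(num_generation_per_prompt, 0)
# repeated is [q1, q1, q1, q2, q2, q2, q3, q3, q3], so repeated[0] == repeated[1]
assert torch.equal(repeated[0], repeated[1])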

Comment on lines 379 to 397
# online SFT logic:
# at this point, we have interleave repeated the queries for `args.num_generation_per_prompt` times
# say, we have 64 queries repeated 10 times, we have 640 responses
# we reshape the scores to (10, 64), so each query has 10 scores.
# we then find the index of each query's best response
scores_per_query = scores.reshape(args.num_generation_per_prompt, args.local_batch_size)
best_idxes = scores_per_query.argmax(0)
worst_idxes = scores_per_query.argmin(0)
best_idxes_offset = (
    best_idxes + torch.arange(args.local_batch_size, device=device) * args.num_generation_per_prompt
)
worst_idxes_offset = (
    worst_idxes + torch.arange(args.local_batch_size, device=device) * args.num_generation_per_prompt
)
best_query_responses = query_responses[best_idxes_offset]
# worst_query_responses = query_responses[worst_idxes_offset]  # TODO: maybe interesting to look at the worst responses.
best_scores = scores[best_idxes_offset]
worst_scores = scores[worst_idxes_offset]
scores_margin = best_scores - worst_scores
@vwxyzjn (Collaborator, Author) commented:

Then we find the indices corresponding to the best scores and select those best query responses for online SFT.
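As a sanity check with toy numbers (not the trainer's exact tensors or reshape; this groups scores per prompt via a (local_batch_size, num_generation_per_prompt) view), here is one way to see the per-prompt best/worst selection under the interleaved layout:

import torch

# Toy setup: 3 prompts, 2 generations each, interleaved as
# [q1_a, q1_b, q2_a, q2_b, q3_a, q3_b]; scores are made up.
local_batch_size = 3
num_generation_per_prompt = 2
scores = torch.tensor([0.1, 0.7, 0.5, 0.2, 0.9, 0.4])

# Group scores per prompt: row i holds prompt i's generations.
scores_per_query = scores.reshape(local_batch_size, num_generation_per_prompt)
best_idxes = scores_per_query.argmax(dim=1)   # best generation per prompt
worst_idxes = scores_per_query.argmin(dim=1)  # worst generation per prompt

# Convert per-prompt indices back to flat indices into the interleaved batch.
offsets = torch.arange(local_batch_size) * num_generation_per_prompt
best_flat = best_idxes + offsets    # -> [1, 2, 4]
worst_flat = worst_idxes + offsets  # -> [0, 3, 5]
print(scores[best_flat], scores[worst_flat])  # chosen vs. rejected scores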

@vwxyzjn marked this pull request as ready for review on September 10, 2024 at 21:50.
@ValentinaPy (Contributor) left a comment:

Code looks good to me! I'll test it out soon.

@nouhadziri (Contributor) left a comment:

Looks great, Costa. I have a nitpick about dpo_utils.py for the DPO data processing at L173, which I guess was contributed by Nathan and Jacob. I usually prefer processing the data (in this case, concatenating the chosen and rejected ids) before passing it to the DataLoader, to avoid looping over the batch. But all good, since the tensors are not placed on the GPUs yet.
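For illustration, a minimal sketch of what this could look like (hypothetical field names and collate function, not the dpo_utils.py API): concatenate chosen and rejected ids once in the collate step, so the training loop never iterates over examples inside a batch.

import torch
from torch.nn.utils.rnn import pad_sequence

def concat_collate(batch, pad_token_id=0):
    # batch: list of dicts with "chosen_input_ids" and "rejected_input_ids"
    # (1-D LongTensors). Stack all chosen first, then all rejected, padded together.
    sequences = [ex["chosen_input_ids"] for ex in batch] + [
        ex["rejected_input_ids"] for ex in batch
    ]
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
    attention_mask = (input_ids != pad_token_id).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Usage: DataLoader(dataset, batch_size=8, collate_fn=concat_collate); the first
# half of each batch holds the chosen sequences, the second half the rejected ones.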

@vwxyzjn (Collaborator, Author) commented on Sep 13, 2024:

@nouhadziri, that's a good point. We can look into refactoring dpo_utils.py after the current projects :)

Merging as-is for now. Thanks @nouhadziri and @ValentinaPy for the review.

@vwxyzjn merged commit 649c9e3 into main on Sep 13, 2024
3 checks passed