[WIP] Add KTOLoss #1864
Conversation
@SalmanMohammadi Should I make a full run with KTOLoss?
Hey @krammnic. Thanks so much for opening this. This is a great start. I've been thinking about how to improve the contributor experience and the contribution workflow for PRs which involve design decisions, i.e. ones that touch our existing codebase in a non-trivial way, require refactoring, or generally need some care to make sure we're exposing new features to users in the best way possible. So, before we dive into the specific implementation details, I'd like to take a step back. To prevent thrash, minimize back-and-forth, and avoid discovering design reconsiderations late in the PR, let's de-risk things by better understanding what we're doing here and agreeing on some of the necessary steps we need to see this landed. I'm going to shamelessly plug my SimPO PR (#1223) as an example of this. Would you be willing to answer some questions for me by updating the PR description?
Generally, what we're looking for is to make sure we don't find any gotchas down the line and that we have an overview of the expected changes and outcomes for this contribution. A couple of disclaimers:
Thanks for the comment! I think most of your questions are already answered in the KTOLoss docstring and in the general RLHF issue; I can duplicate them here if required. The recipe should not change significantly, mostly just swapping the loss to KTOLoss. I will check about the dataset, but it's generally pretty much the same across all the contrastive losses.
The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.
It looks like this has been copied from the SimPO loss?
Oh, I used the SimPO docs as a template and forgot to add my own docs in this part. The other comments are about KTO.
Fixed
We can infer the main idea of the method from its limitations: KTO >= DPO, and DPO should be used when you have high-quality data; in other cases KTO is a good fit. (Actually, there are two versions of this loss, which give more control over the situation with your dataset, but they differ only slightly, so they are usually included in one loss. I will do that as a follow-up.)
torchtune/rlhf/loss/dpo.py
kl = torch.zeros(1).to(policy_chosen_logps.device)  # KL term; zero placeholder in this draft

chosen_logratios = policy_chosen_logps - reference_chosen_logps
chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))

rejected_logratios = policy_rejected_logps - reference_rejected_logps
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))

# Weight desirable (chosen) and undesirable (rejected) losses separately, then stack.
losses = torch.cat(
    (
        self.desirable_weight * chosen_losses,
        self.undesirable_weight * rejected_losses,
    ),
    0,
)

chosen_rewards = self.beta * (chosen_logratios).detach()
rejected_rewards = self.beta * (rejected_logratios).detach()
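For context, a hypothetical sketch of how this loss might be constructed and used. The class name comes from the PR title and the hyperparameter names from the snippet above; the exact constructor and forward signature are assumptions, not the final API.

# Hypothetical usage sketch; the real signature may differ in the final PR.
loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)
losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen_logps,       # log-probs of chosen responses under the policy, shape (batch,)
    policy_rejected_logps,     # log-probs of rejected responses under the policy, shape (batch,)
    reference_chosen_logps,    # same quantities under the frozen reference model
    reference_rejected_logps,
)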
I'm not convinced that the KL calculation is correct. In TRL, the KL calculation is done differently (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L1100) - you can see that self.calculate_KL is True for the KTO loss (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L545).
This is the same in the official KTO repo https://github.com/ContextualAI/HALOs/blob/61b9ee6787c9f52c2f4578533866e1cb42f314ec/trainers.py#L836
Yes, I left this mostly as drafted (a version without KL); I will bring it in.
Hmm, to actually get kl_logprobs we need to modify concatenated_forward
I don't want to pass cfg into concatenated_forward
We could actually ignore the KL divergence for now and use the APO variant of the loss:
rejected_losses = F.sigmoid(self.beta * rejected_logratios)
We would also need to change the structure of the batch, and it would require some processing on the dataset side. Do you approve of such changes?
That would be very helpful to see! I'm also not convinced about the dataset format; I've found in the official KTO repo (https://github.com/ContextualAI/HALOs?tab=readme-ov-file) that:
So we need to think a bit about how to make sure this implementation in the DPO recipe retains parity here. Do we need to use some minor modification of a preference dataset here?
Yes, it will probably require some special handling for data loading.
Looks like it won't be difficult; we could actually create a new dataset type for this. Basically, to convert a pair dataset to the KTO format we need something like the sketch below:
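A minimal sketch of that conversion (illustrative only, not the final torchtune API; it assumes the usual prompt/chosen/rejected preference rows and mirrors the unpairing logic in TRL's data_utils):

# Turn each preference pair into two unpaired KTO rows:
# the chosen response with label=True (desirable) and the
# rejected response with label=False (undesirable).
def unpair_preference_rows(rows):
    kto_rows = []
    for row in rows:
        kto_rows.append(
            {"prompt": row["prompt"], "completion": row["chosen"], "label": True}
        )
        kto_rows.append(
            {"prompt": row["prompt"], "completion": row["rejected"], "label": False}
        )
    return kto_rows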
And obviously remove the now-unneeded columns. I'm not sure where to put such preprocessing: a new dataset type, or a new argument on the current dataset?
Would you be able to outline some of the necessary steps we need to take on the dataset side?
APO is a separate (but comparable) technique to KTO, right? Is this the right paper? It looks like it came out less than 2 months ago, and I see very few models on the Hub mentioning APO. As I mentioned in the RLHF tracker, I'm not sure we should be implementing techniques which are new and haven't been adopted by the OSS community, so I'd rather we stick with the original proposal of KTO, and I'd love to keep supporting you on this. I don't fully follow your comments on the necessary modifications needed to implement KTO; you mentioned needing to modify concatenated_forward.
and map it to the dataset.
@SalmanMohammadi Can we continue from here?
Hey @krammnic - thanks for summarizing here. It seems like we can make things simpler if we don't try to add APO too? I need a bit of time to look into your other questions, which I may not get round to until later this week. If you have an idea of how you'd like to proceed here, I'd say go for it and we can iterate if we need to. I'm interested in your thoughts on whether you think we'll need to make any changes to the recipe itself - for example, does concatenated_forward need to change? Out of curiosity, do you have a reference for this code so I can follow along?
Thanks for the answer! Yes, some changes will be required, as we need to get batch["reference_KL_logps"] and policy_KL_logps to calculate the KL divergence (I will write up what we need to do here soon). The dataset processing implementation is from TRL: https://github.com/huggingface/trl/blob/main/trl/data_utils.py
Added the dataset processing; now we need to add the correct KL divergence calculation.
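For reference, here is a sketch of the KL estimate as it appears in the TRL and HALOs implementations linked above. It is an illustration reusing names from the earlier snippet, not the final torchtune code, and it assumes the batch provides log-probs for mismatched prompt/completion pairs used only as a KL baseline:

# policy_KL_logps / reference_KL_logps: log-probs of completions paired with
# mismatched prompts, used only to estimate the KL baseline.
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
kl = kl.clamp(min=0)  # clamped at zero; in distributed training it is also averaged across ranks

chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))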
@SalmanMohammadi Are we supporting a different ref_model in the current DPO recipe? As far as I can see, we are not.
What do you mean by a different ref model? Because we're fine-tuning with LoRA, we can just disable the LoRA adapters to recover the original reference model without having to keep a separate copy of the model in memory. The reference model should always be the base model we're fine-tuning from.
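Concretely, this is the trick the LoRA DPO recipe relies on. A sketch under the assumption that the disable_adapter context manager from torchtune.modules.peft is used, with illustrative variable names:

import torch
from torchtune.modules.peft import disable_adapter

# With LoRA adapters active, the model acts as the policy.
policy_logits = model(concatenated_input_ids)

# With adapters disabled, the same weights act as the frozen reference model,
# so no second copy of the model needs to be kept in memory.
with torch.no_grad(), disable_adapter(model):
    reference_logits = model(concatenated_input_ids)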
Actually, in TRL there is the possibility to set something different!
I'm not sure why they give this possibility either, which is why I asked. Maybe there is some special use case?
Ah, good catch! I'm not 100% sure either. In their example scripts they always initialize with the same model (https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py).
Actually, it is probably pretty important for DPO; see, for example, https://arxiv.org/pdf/2407.13709
I'm pretty sure it will be relevant for classic DPO, for example.
I would consider this a follow-up, which we might need to add pretty soon.
About this PR: I added all the required KL stuff that doesn't affect the batch. The KTO test will fail right now.
Tasks:
Context
What is the purpose of this PR? Is it to
- add a new feature
- fix a bug
- update tests and/or documentation
- other
Please link to any issues this PR addresses.
#1850
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.