
[WIP] Add KTOLoss #1864

Open

wants to merge 8 commits into main
Conversation

@krammnic (Contributor) commented Oct 18, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#1850

Changelog

What are the changes made in this PR?

  • Adding KTO loss to torchtune

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot (bot) commented Oct 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1864

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 18, 2024
@krammnic (Contributor Author) commented Oct 18, 2024

  • Run some recipe with this loss.
  • Add the APO loss variant (it is fairly minor, but with the current design it's not trivial).
  • KL divergence?

@krammnic (Contributor Author)

@SalmanMohammadi Should I make a full run with KTOLoss?

@SalmanMohammadi (Collaborator) commented Oct 18, 2024

Hey @krammnic. Thanks so much for opening this. This is a great start.

I've been thinking about how to improve the contributor experience, and the contribution workflow for PRs which may involve design decisions, i.e. PRs which touch our existing codebase in a non-trivial way, require refactoring, or generally need some care to make sure we're exposing new features to users in the best way possible.

So, before we dive into the specific implementation details, I'd like to take a step back. To prevent thrash, minimize back-and-forth, and avoid discovering design problems late in the PR, let's de-risk things by better understanding what we're doing here and agreeing on the necessary steps to see this landed. I'm going to shamelessly plug my SimPO PR (#1223) as an example of this.

Would you be willing to answer some questions for me by updating the PR description?

  • Could you provide a very very high-level overview of what this method is? What are the main ways it differs from e.g. DPO? You could provide some information from the paper here.
  • Have you used a reference implementation? If so, which one? Code pointers here are very useful. I like to use TRL (https://github.com/huggingface/trl) which I believe has a KTOTrainer, but others also use official repos from paper authors.
  • What would integrating this into the DPO recipes look like? Does it fit neatly into the existing DPO logic, or do we need some specific logic in the recipe for KTO?
    • From skimming the KTOTrainer docpage, it looks like TRL specifies a dataset format which differs from our preference dataset format in a meaningful way. As above with outlining the main differences between KTO and DPO, how would this different dataset format be used?

Generally what we're looking for is to make sure we don't find any gotchas down the line and that we can have an overview of the expected changes and outcomes for this contribution.

A couple disclaimers:

  • If this sounds like overkill - let me know! If you feel like there's a different way you'd like to go about it, please feel free! Similarly, I'm more than happy to help investigate on some of these points. You can also drop me a line on Discord : )
  • Also, if this seems like a lot of effort, we can come back to this PR and you could try one of the RLHF techniques that you mentioned which may be easier to implement into our codebase. That's totally fine.

@krammnic (Contributor Author)

Thanks for the comment! I think most of your questions are already answered in the KTOLoss docstring and in the general RLHF issue. I can duplicate them here if required.

The recipe should not need significant changes: mainly switching the loss to KTOLoss. I will check on the dataset side, but the contrastive losses generally use pretty much the same format.

Comment on lines +241 to +244
The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.

Collaborator

It looks like this has been copied from the SimPO loss?

Contributor Author

Oh, I used the SimPO docs as a template and forgot to add my own docs in this part. The other comments are about KTO.

Contributor Author

Fixed

@krammnic (Contributor Author)

We can infer the main idea of the method from its limitations: KTO >= DPO, and DPO should be used when you have high-quality data, while KTO is well suited to the other cases. (Actually, there are two versions of this loss, which give more control over the situation with your dataset, but they differ only slightly, so they are usually combined into one loss; I will do that as a follow-up.)
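For reference, here is my reading of the per-example objective from the KTO paper (lambda_D and lambda_U correspond to desirable_weight and undesirable_weight in the code below; z_0 is the reference-point KL term estimated over a batch):

r_\theta(x, y) = \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}, \qquad
z_0 = \mathbb{E}_{x' \sim D}\left[ \mathrm{KL}\big(\pi_\theta(y' \mid x') \,\|\, \pi_{\mathrm{ref}}(y' \mid x')\big) \right]

v(x, y) =
\begin{cases}
\lambda_D \, \sigma\big(\beta (r_\theta(x, y) - z_0)\big) & \text{if } y \text{ is desirable} \\
\lambda_U \, \sigma\big(\beta (z_0 - r_\theta(x, y))\big) & \text{if } y \text{ is undesirable}
\end{cases}

\mathcal{L}_{\mathrm{KTO}}(\pi_\theta, \pi_{\mathrm{ref}}) = \mathbb{E}_{x, y \sim D}\left[ \lambda_y - v(x, y) \right]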

Comment on lines 300 to 319
kl = torch.zeros(1).to(policy_chosen_logps.device)

chosen_logratios = policy_chosen_logps - reference_chosen_logps

chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))

rejected_logratios = policy_rejected_logps - reference_rejected_logps

rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))

losses = torch.cat(
    (
        self.desirable_weight * chosen_losses,
        self.undesirable_weight * rejected_losses,
    ),
    0,
)

chosen_rewards = self.beta * (chosen_logratios).detach()
rejected_rewards = self.beta * (rejected_logratios).detach()
@SalmanMohammadi (Collaborator) commented Oct 18, 2024

I'm not convinced that the KL calculation is correct. In TRL, the KL calculation is done differently (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L1100) - you can see that self.calculate_KL is True for the KTO loss (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L545).
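For illustration, a minimal sketch of that style of KL estimate, assuming policy_KL_logps and reference_KL_logps are log-probs of completions paired with mismatched prompts (the helper name here is just for this sketch, not TRL's API):

import torch

def estimate_kl(policy_KL_logps: torch.Tensor, reference_KL_logps: torch.Tensor) -> torch.Tensor:
    # Batch-mean log-ratio on mismatched prompt/completion pairs; detached so it
    # only shifts the reference point in the sigmoid, and clamped at zero.
    kl = (policy_KL_logps - reference_KL_logps).mean().detach()
    return torch.clamp(kl, min=0)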

Contributor Author

Yes, I kept this mostly as a draft (the version without KL); I will bring the KL term in.

Contributor Author

Hmm, to actually get kl_logprobs we need to modify concatenated_forward

Contributor Author

I don't want to pass cfg in concatenated_forward

@krammnic (Contributor Author) commented Oct 18, 2024

We can actually ignore the KL divergence right now and use the APO variant of the loss:

rejected_losses = F.sigmoid(self.beta * rejected_logratios)
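For context, if I'm reading TRL's apo_zero variant right, both sides would look roughly like this (a sketch only, wrapped as a standalone helper so the names are explicit):

import torch
import torch.nn.functional as F

def apo_zero_losses(chosen_logratios: torch.Tensor, rejected_logratios: torch.Tensor, beta: float):
    # No KL reference point: directly reward chosen log-ratios and
    # penalize rejected log-ratios.
    chosen_losses = 1 - F.sigmoid(beta * chosen_logratios)
    rejected_losses = F.sigmoid(beta * rejected_logratios)
    return chosen_losses, rejected_losses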

@krammnic (Contributor Author) commented Oct 18, 2024

And we would actually need to change the structure of the batch, which would also require some preprocessing of the dataset. Do you approve of such changes?

@SalmanMohammadi (Collaborator) commented Oct 18, 2024

I might duplicate them here if required.

That would be very helpful to see! I'm also not convinced about the dataset format, as I've found in the official KTO repo (https://github.com/ContextualAI/HALOs?tab=readme-ov-file) that:

KTO doesn't use preference pairs, just knowledge of whether outputs are desirable or undesirable. This means we use dataloader.UnpairedPreferenceDataLoader. However, that dataloader assumes that you're working with datasets that originally contain preference pairs, like Anthropic HH or SHP

So we need to think a bit about how to make sure this implementation in the DPO recipe retains parity here. Do we need to use some minor modification of a preference dataset here?
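To make the format question concrete, here is a hypothetical example of a preference pair next to the unpaired, KTO-style rows it would become (field names are illustrative, not a fixed schema):

# One paired preference example (roughly our current format):
paired = {
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "France does not have a capital.",
}

# The same data as two unpaired KTO-style examples:
unpaired = [
    {"prompt": paired["prompt"], "completion": paired["chosen"], "label": True},
    {"prompt": paired["prompt"], "completion": paired["rejected"], "label": False},
]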

@krammnic (Contributor Author)

Yes, this will probably require some special handling on the data loading side.

@krammnic (Contributor Author)

Looks like it won't be difficult; we can actually create a new dataset type for this. So basically, to convert a paired dataset to the KTO format we need something like:

batch_size = len(examples["chosen"])
new_rows = {
    "completion": examples["chosen"] + examples["rejected"],
    "label": [True] * batch_size + [False] * batch_size,
}
if "prompt" in examples:
    new_rows["prompt"] = examples["prompt"] + examples["prompt"]

And obviously drop the original columns afterwards. I'm not sure where to put such preprocessing: a new dataset type, or a new argument on the current dataset?
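For example, with a Hugging Face datasets.Dataset the conversion could be applied as a batched map (a sketch with toy data; the wrapper function and toy rows are just illustrative):

from datasets import Dataset

def to_unpaired(examples):
    # Duplicate each pair into one desirable and one undesirable row.
    batch_size = len(examples["chosen"])
    new_rows = {
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }
    if "prompt" in examples:
        new_rows["prompt"] = examples["prompt"] + examples["prompt"]
    return new_rows

pairs = Dataset.from_dict({
    "prompt": ["What is 2 + 2?"],
    "chosen": ["2 + 2 = 4."],
    "rejected": ["2 + 2 = 5."],
})
# remove_columns drops the original pair columns so only prompt/completion/label remain.
kto_rows = pairs.map(to_unpaired, batched=True, remove_columns=["chosen", "rejected"])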

@SalmanMohammadi (Collaborator) commented Oct 21, 2024

And we need to actually change structure of the batch. Also would require some processing to dataset. Do you approve such changes?

Would you be able to outline some of the necessary steps we need to make on the dataset side?

We can actually ignore KL divergence right now and use APO variant of loss:
rejected_losses = F.sigmoid(self.beta * rejected_logratios)

APO is a separate (but comparable) technique to KTO, right? Is this the right paper? It looks like it came out < 2 months ago, and I see very few models on the hub mentioning APO. As I mentioned in the RLHF tracker, I'm not sure we should be implementing techniques which are new and haven't been adopted by the OS community, so I'd rather we stick with the original proposal of KTO, and I'd love to keep supporting you on this.

I don't fully follow your comments on the modifications needed to implement KTO; you mentioned needing to modify concatenated_forward? It'd be really helpful if you could also outline why we need to make these changes - that way we can figure out the best way to implement them.

@krammnic (Contributor Author) commented Oct 21, 2024

  1. About APO: it looks like it is just KTO without the KL term, but OK, let's skip it.
  2. Therefore, we don't need to pass anything extra in forward!
  3. We need to apply this preprocessing to the dataset:

batch_size = len(examples["chosen"])
new_rows = {
    "completion": examples["chosen"] + examples["rejected"],
    "label": [True] * batch_size + [False] * batch_size,
}
if "prompt" in examples:
    new_rows["prompt"] = examples["prompt"] + examples["prompt"]

and map it over the dataset.

  4. But we need to calculate policy_KL_logps and reference_KL_logps.
  5. Then we need batch["reference_KL_logps"] and policy_KL_logps (rough sketch below).
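Roughly, I'd expect the loss computation to end up consuming those names like this (a sketch only, reusing the tensor names above; beta and the weights would stay as module attributes in the real implementation):

import torch
import torch.nn.functional as F

def kto_losses(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    policy_KL_logps: torch.Tensor,
    reference_KL_logps: torch.Tensor,
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
):
    # Reference-point KL estimate from mismatched prompt/completion pairs,
    # detached and clamped at zero.
    kl = (policy_KL_logps - reference_KL_logps).mean().detach().clamp(min=0)

    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    chosen_losses = 1 - F.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - F.sigmoid(beta * (kl - rejected_logratios))

    losses = torch.cat(
        (desirable_weight * chosen_losses, undesirable_weight * rejected_losses), 0
    )
    return losses, kl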

@krammnic (Contributor Author)

@SalmanMohammadi Can we continue from here?

@SalmanMohammadi (Collaborator) commented Oct 21, 2024

Hey @krammnic - thanks for summarizing here; it seems like we can make things simpler if we don't try to add APO too?

I need a bit of time to look into your other questions, which I may not get round to until later this week. If you have an idea of how you'd like to proceed here, I'd say go for it and we can iterate if we need to. I'm interested in your thoughts on whether you think we'll need to make any changes to the recipe itself - for example, does concatenated_forward need to be changed? Or can we make modifications just at the dataset level?

Out of curiosity, do you have a reference for this code so I can follow along?

batch_size = len(examples["chosen"])
new_rows = {
    "completion": examples["chosen"] + examples["rejected"],
    "label": [True] * batch_size + [False] * batch_size,
}
if "prompt" in examples:
    new_rows["prompt"] = examples["prompt"] + examples["prompt"]

@krammnic (Contributor Author)

Thanks for the answer!

Yes, some changes will be required, as we need to get batch["reference_KL_logps"] and policy_KL_logps to calculate the KL divergence (I will write up what we need to do here soon).

The dataset processing implementation is from TRL:

https://github.com/huggingface/trl/blob/main/trl/data_utils.py

@krammnic (Contributor Author)

Added the dataset processing; now we need to add the correct KL divergence calculation.

@krammnic (Contributor Author)

@SalmanMohammadi Do we support a different ref_model in the current DPO recipe? As far as I can see, we do not:

with torch.no_grad(), disable_adapter(self._model):
    (
        reference_chosen_log_probs,
        reference_rejected_log_probs,
        _,
        _,
    ) = self.concatenated_forward(self._model, batch)

@SalmanMohammadi (Collaborator)

What do you mean by a different ref model? Because we're fine-tuning with LoRA, we can just disable the LoRA adapters to recover the original reference model without having to keep a separate copy of the model in memory. The reference model should always be the base model we're finetuning from.

@krammnic (Contributor Author) commented Oct 23, 2024

Actually, in TRL there is the possibility to set something different:

ref_model (`PreTrainedModelWrapper`):
    Hugging Face transformer model with a causal language modelling head. Used for implicit reward
    computation and loss. If no reference model is provided, the trainer will create a reference
    model with the same architecture as the model to be optimized.

I'm not sure why they provide this option either, which is why I asked. Maybe there is some special use case?

@SalmanMohammadi (Collaborator)

Ah good catch! I'm not 100% sure either. In their example scripts they always initialize with the same model (https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py).

@krammnic (Contributor Author) commented Oct 23, 2024

Actually, it is probably pretty critical for DPO; for example, according to https://arxiv.org/pdf/2407.13709:

Our experiments reveal an interesting finding: stronger reference models can indeed offer more benefits than the SFT model, but only when they are compatible with the model being fine-tuned

I'm pretty sure this also applies to classic DPO, for example.

@krammnic (Contributor Author)

I would consider this a follow-up, which we might need to add pretty soon.

@krammnic (Contributor Author)

About this PR: I added all the required KL stuff that does not affect the batch. The KTO test will fail right now.

@krammnic (Contributor Author) commented Oct 23, 2024

Tasks:

  • Fix the test
  • Modify the batch
  • Make the same changes to the distributed recipe
