[WIP] Add KTOLoss #1864

Open

wants to merge 8 commits into base: main
Changes from 3 commits
1 change: 1 addition & 0 deletions docs/source/api_ref_rlhf.rst
@@ -17,3 +17,4 @@ Components and losses for RLHF algorithms like PPO and DPO.
loss.DPOLoss
loss.RSOLoss
loss.SimPOLoss
loss.KTOLoss
2 changes: 2 additions & 0 deletions recipes/lora_dpo_single_device.py
@@ -57,6 +57,8 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
- :class:`~torchtune.rlhf.loss.KTOLoss`: Kahneman-Tversky Optimization (KTO).


Assumptions:
- Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done
68 changes: 67 additions & 1 deletion tests/torchtune/rlhf/loss/test_dpo_loss.py
@@ -6,7 +6,7 @@

import pytest
import torch
from torchtune.rlhf.loss import DPOLoss, RSOLoss, SimPOLoss
from torchtune.rlhf.loss import DPOLoss, KTOLoss, RSOLoss, SimPOLoss


@pytest.fixture(autouse=True)
@@ -36,6 +36,14 @@ def simpo_loss(self):
label_smoothing=0.0,
)

@pytest.fixture
def kto_loss(self):
return KTOLoss(
beta=0.1,
undesirable_weight=1.0,
desirable_weight=1.0,
)

@pytest.fixture
def loss_inputs(self):
"""
@@ -123,3 +131,61 @@ def test_simpo_loss(self, simpo_loss, loss_inputs):
losses, *_ = simpo_loss(policy_chosen_logprobs, policy_rejected_logprobs)

torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5)

def test_kto_loss(self, kto_loss, loss_inputs):
"""
beta = 0.1
policy_chosen_logprobs = torch.tensor([-0.5, -10.0, -1.0])
policy_rejected_logprobs = torch.tensor([-0.1, -30.0, -21.0])

ref_chosen_logprobs = torch.tensor([-0.5, -10.1, -0.1])
ref_rejected_logprobs = torch.tensor([-0.1, -20.1, -0.1])

kl = torch.tensor([0])

chosen_logratios = policy_chosen_logprobs - ref_chosen_logprobs
chosen_logratios = torch.tensor([0., 0.1, -0.9])

chosen_losses = 1 - F.sigmoid(0.1 * (torch.tensor([0., 0.1, -0.9]) - torch.tensor([0])))
chosen_losses = torch.tensor([0.5000, 0.4975, 0.5225])

rejected_logratios = policy_rejected_logprobs - ref_rejected_logprobs
rejected_logratios = torch.tensor([0.0, -9.9, -20.9])

rejected_losses = 1 - F.sigmoid(0.1 * (torch.tensor([0]) - torch.tensor([0.0, -9.9, -20.9])))
rejected_losses = torch.tensor([0.5, 0.2709, 0.1101])


desirable_weight = undesirable_weight = 1.0
Therefore:

losses = torch.cat(
(1 * chosen_losses, 1 * rejected_losses),
0
)

losses = torch.tensor([0.5000, 0.4975, 0.5225, 0.5000, 0.2709, 0.1101])

chosen_rewards = 0.1 * (chosen_logratios).detach()
rejected_rewards = 0.1 * (rejected_logratios).detach()

chosen_rewards = torch.tensor([0.0000, 0.0100, -0.0900])
rejected_rewards = torch.tensor([ 0.0000, -0.9900, -2.0900])
"""

(
policy_chosen_logprobs,
policy_rejected_logprobs,
ref_chosen_logprobs,
ref_rejected_logprobs,
) = loss_inputs

losses, *_ = kto_loss(
policy_chosen_logprobs,
policy_rejected_logprobs,
ref_chosen_logprobs,
ref_rejected_logprobs,
)

expected_losses = torch.tensor([0.5000, 0.4975, 0.5225, 0.5000, 0.2709, 0.1101])
torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5)
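
For reference, the hand computation in the docstring above can be reproduced with a short standalone snippet (inputs copied from the values listed in the docstring; the zero KL term matches the current implementation):

import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen = torch.tensor([-0.5, -10.0, -1.0])
policy_rejected = torch.tensor([-0.1, -30.0, -21.0])
ref_chosen = torch.tensor([-0.5, -10.1, -0.1])
ref_rejected = torch.tensor([-0.1, -20.1, -0.1])

kl = torch.zeros(1)  # the current implementation fixes the KL estimate at zero

chosen_losses = 1 - F.sigmoid(beta * ((policy_chosen - ref_chosen) - kl))
rejected_losses = 1 - F.sigmoid(beta * (kl - (policy_rejected - ref_rejected)))

print(torch.cat((chosen_losses, rejected_losses), 0))
# tensor([0.5000, 0.4975, 0.5225, 0.5000, 0.2709, 0.1101])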
4 changes: 2 additions & 2 deletions torchtune/rlhf/loss/__init__.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from .dpo import DPOLoss, RSOLoss, SimPOLoss
from .dpo import DPOLoss, KTOLoss, RSOLoss, SimPOLoss
from .ppo import PPOLoss

__all__ = ["DPOLoss", "RSOLoss", "SimPOLoss", "PPOLoss"]
__all__ = ["DPOLoss", "RSOLoss", "SimPOLoss", "PPOLoss", "KTOLoss"]
88 changes: 88 additions & 0 deletions torchtune/rlhf/loss/dpo.py
@@ -231,3 +231,91 @@ def forward(
rejected_rewards = self.beta * (policy_rejected_logps).detach()

return losses, chosen_rewards, rejected_rewards


class KTOLoss(nn.Module):
"""
KTO: Kahneman-Tversky Optimization: https://arxiv.org/abs/2402.01306
Intuition from the paper:

The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.

Comment on lines +241 to +244

Collaborator

It looks like this has been copied from the SimPO loss?

Contributor Author

Oh, I used the SimPO docs as a template and forgot to add my own docs in this part. The other comments are about KTO.

Contributor Author

Fixed

Based on the TRL implementation:
https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/kto_trainer.py

KTO is a simple but effective alternative to DPO. According to the KTO paper, KTO performs at least as well as DPO (KTO >= DPO), and it works better than DPO when the data quality is insufficient. It still uses a reference model, but the HALO (human-aware loss) it proposes is significantly more effective.

Args:
beta (float): Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. Default is 0.1.
desirable_weight (float): Desirable losses are weighted by this factor to counter
an unequal number of desirable and undesirable pairs. Default is 1.0.
undesirable_weight (float): Undesirable losses are weighted by this factor to counter
an unequal number of desirable and undesirable pairs. Default is 1.0.
"""

def __init__(
self,
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
):
super().__init__()
self.beta = beta
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight

def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the KTO loss for a batch of policy and reference model log probabilities.

Args:
policy_chosen_logps (torch.Tensor): Log probabilities of the policy model
for the chosen responses. Shape: (batch_size)
policy_rejected_logps (torch.Tensor): Log probabilities of the policy model
for the rejected responses. Shape: (batch_size)
reference_chosen_logps (torch.Tensor): Log probabilities of the reference model
for the chosen responses. Shape: (batch_size)
reference_rejected_logps (torch.Tensor): Log probabilities of the reference model
for the rejected responses. Shape: (batch_size)

Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors:
- losses: The KTO loss for each example in the batch.
- chosen_rewards: Rewards for the chosen responses.
- rejected_rewards: Rewards for the rejected responses.

"""

kl = torch.zeros(1).to(policy_chosen_logps.device)

chosen_logratios = policy_chosen_logps - reference_chosen_logps

chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))

rejected_logratios = policy_rejected_logps - reference_rejected_logps

rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))

losses = torch.cat(
(
self.desirable_weight * chosen_losses,
self.undesirable_weight * rejected_losses,
),
0,
)

chosen_rewards = self.beta * (chosen_logratios).detach()
rejected_rewards = self.beta * (rejected_logratios).detach()
Collaborator

@SalmanMohammadi SalmanMohammadi Oct 18, 2024

I'm not convinced that the KL calculation is correct. In TRL, the KL calculation is done differently (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L1100) - you can see that self.calculate_KL is True for the KTO loss (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L545).
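
For illustration, a rough sketch of the missing piece, following my reading of the TRL approach: the KL term is estimated from log-probs computed on mismatched prompt/completion pairs rather than fixed at zero. The kl_policy_logps / kl_reference_logps inputs below are hypothetical (producing them would require changes to concatenated_forward), and the names are mine, not TRL's:

import torch

def estimate_kl(kl_policy_logps: torch.Tensor, kl_reference_logps: torch.Tensor) -> torch.Tensor:
    # Batch-level estimate of KL(policy || reference) over completions paired with
    # mismatched prompts; detached so it only shifts the reference point and never
    # receives gradients, and clamped so the estimate cannot go negative.
    kl = (kl_policy_logps - kl_reference_logps).mean().detach()
    return kl.clamp(min=0)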


Contributor Author

Yes, I saved this mostly as a draft (the version without KL); I'll bring that in.

Contributor Author

Hmm, to actually get kl_logprobs we need to modify concatenated_forward

Contributor Author

I don't want to pass cfg in concatenated_forward

Contributor Author

@krammnic krammnic Oct 18, 2024

We can actually ignore the KL divergence for now and use the APO variant of the loss:

rejected_losses = F.sigmoid(self.beta * rejected_logratios)
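
For context, a sketch of what forward could look like under that variant (this mirrors my understanding of TRL's apo_zero loss type, so treat it as an assumption rather than a drop-in replacement):

# Inside KTOLoss.forward, with no KL term:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

# Push chosen completions to be more likely, and rejected completions to be
# less likely, than under the reference model.
chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
rejected_losses = F.sigmoid(self.beta * rejected_logratios)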

Contributor Author

@krammnic krammnic Oct 18, 2024

And we would actually need to change the structure of the batch. It would also require some preprocessing of the dataset. Do you approve of such changes?


return losses, chosen_rewards, rejected_rewards
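
For quick reference, a minimal usage sketch of the module as proposed in this diff (tensor values are illustrative only):

import torch
from torchtune.rlhf.loss import KTOLoss

loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)

# Per-example log-probabilities, shape (batch_size,)
policy_chosen = torch.tensor([-0.5, -10.0, -1.0])
policy_rejected = torch.tensor([-0.1, -30.0, -21.0])
ref_chosen = torch.tensor([-0.5, -10.1, -0.1])
ref_rejected = torch.tensor([-0.1, -20.1, -0.1])

losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen, policy_rejected, ref_chosen, ref_rejected
)
# losses has shape (2 * batch_size,): chosen losses followed by rejected losses.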