[WIP] Add KTOLoss #1864
Changes from 3 commits
@@ -231,3 +231,91 @@ def forward(
        rejected_rewards = self.beta * (policy_rejected_logps).detach()

        return losses, chosen_rewards, rejected_rewards


class KTOLoss(nn.Module):
    """
    KTO: Kahneman-Tversky Optimization: https://arxiv.org/abs/2402.01306
    Intuition from the paper:

    The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
    the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
    encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.

    Based on the TRL implementation:
    https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/kto_trainer.py

    KTO is a simple but effective alternative to DPO. According to the KTO paper, KTO performs at least as well as DPO,
    and works better than DPO when the data quality is insufficient. It still uses a reference model, but the proposed
    HALO is significantly more effective.

    Args:
        beta (float): Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model. Default is 0.1.
        desirable_weight (float): Desirable losses are weighted by this factor to counter
            an unequal number of desirable and undesirable pairs. Default is 1.0.
        undesirable_weight (float): Undesirable losses are weighted by this factor to counter
            an unequal number of desirable and undesirable pairs. Default is 1.0.
    """

    def __init__(
        self,
        beta: float = 0.1,
        desirable_weight: float = 1.0,
        undesirable_weight: float = 1.0,
    ):
        super().__init__()
        self.beta = beta
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the KTO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps (torch.Tensor): Log probabilities of the policy model
                for the chosen responses. Shape: (batch_size)
            policy_rejected_logps (torch.Tensor): Log probabilities of the policy model
                for the rejected responses. Shape: (batch_size)
            reference_chosen_logps (torch.Tensor): Log probabilities of the reference model
                for the chosen responses. Shape: (batch_size)
            reference_rejected_logps (torch.Tensor): Log probabilities of the reference model
                for the rejected responses. Shape: (batch_size)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors:
                - losses: The KTO loss for each example in the batch.
                - chosen_rewards: Rewards for the chosen responses.
                - rejected_rewards: Rewards for the rejected responses.
        """

        kl = torch.zeros(1).to(policy_chosen_logps.device)

        chosen_logratios = policy_chosen_logps - reference_chosen_logps

        chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))

        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))

        losses = torch.cat(
            (
                self.desirable_weight * chosen_losses,
                self.undesirable_weight * rejected_losses,
            ),
            0,
        )

        chosen_rewards = self.beta * (chosen_logratios).detach()
        rejected_rewards = self.beta * (rejected_logratios).detach()

Review comments on this hunk:

- I'm not convinced that the KL calculation is correct. In TRL, the KL calculation is done differently (https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py#L1100) - you can see that
- This is the same in the official KTO repo: https://github.com/ContextualAI/HALOs/blob/61b9ee6787c9f52c2f4578533866e1cb42f314ec/trainers.py#L836
- Yes, I saved this mostly as drafted (the version without KL); I will bring it in.
- Hmm, to actually get kl_logprobs we need to modify concatenated_forward.
- I don't want to pass cfg in concatenated_forward.
- We can actually ignore the KL divergence for now and use the APO variant of the loss.
- And we would need to change the structure of the batch. It would also require some processing of the dataset. Do you approve such changes?

        return losses, chosen_rewards, rejected_rewards
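
For context on the KL discussion in the thread above, here is a minimal sketch, assuming the trainer can also supply sequence-level log-probs for mismatched "KL" completions (the `policy_KL_logps` / `reference_KL_logps` names are illustrative, not part of this PR), of how the batch-level KL term could be estimated instead of being hard-coded to zero. It roughly mirrors the linked TRL/HALOs code, minus distributed gathering:

```python
import torch


def estimate_kl_term(
    policy_KL_logps: torch.Tensor,     # policy log-probs on mismatched completions, shape (batch_size,)
    reference_KL_logps: torch.Tensor,  # reference log-probs on the same completions, shape (batch_size,)
) -> torch.Tensor:
    # Monte Carlo estimate of KL(policy || reference) over the batch.
    # Detached so it only acts as a baseline and receives no gradient,
    # and clamped at zero because a negative KL estimate is just noise.
    kl = (policy_KL_logps - reference_KL_logps).mean().detach()
    return kl.clamp(min=0)
```

In `KTOLoss.forward` this value would replace `torch.zeros(1)`, so that `chosen_logratios - kl` and `kl - rejected_logratios` are measured against a batch-level baseline rather than zero. With `kl` fixed at zero, as in the current diff, the loss effectively reduces to the APO-zero-style objective mentioned in the thread.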
Review comments on the class docstring:

- It looks like this has been copied from the SimPO loss?
- Oh, I used the SimPO docs as a template and forgot to add my docs in this part. The other comments are about KTO.
- Fixed
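
For readers of this thread, a hypothetical usage sketch of `KTOLoss` as it appears in this diff (KL term fixed at zero), using dummy log-probability tensors; it assumes the class above is already in scope and is not part of the PR itself:

```python
import torch

# Assuming the KTOLoss class from the diff above is in scope.
loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)

batch_size = 4
# Sequence-level log-probs; in practice these come from summing per-token log-probs
# of the policy and the frozen reference model over each response.
policy_chosen_logps = torch.randn(batch_size)
policy_rejected_logps = torch.randn(batch_size)
reference_chosen_logps = torch.randn(batch_size)
reference_rejected_logps = torch.randn(batch_size)

losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
)

# losses has shape (2 * batch_size,): the weighted desirable and undesirable terms are
# concatenated, so a scalar training loss is typically taken as their mean.
loss = losses.mean()
```

If a dataset has many more undesirable than desirable examples (or vice versa), `desirable_weight` and `undesirable_weight` are the knobs meant to rebalance the two terms; the KTO paper recommends keeping the weighted contributions of the two classes roughly comparable.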