diff --git a/docs/source/api_ref_rlhf.rst b/docs/source/api_ref_rlhf.rst
index e08c699cb4..2e33f0d82c 100644
--- a/docs/source/api_ref_rlhf.rst
+++ b/docs/source/api_ref_rlhf.rst
@@ -17,3 +17,4 @@ Components and losses for RLHF algorithms like PPO and DPO.
     loss.DPOLoss
     loss.RSOLoss
     loss.SimPOLoss
+    loss.KTOLoss
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index c158d17875..59bc488dde 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -8,7 +8,7 @@
 import time
 
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
 import torch
@@ -30,7 +30,7 @@
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 
-from torchtune.rlhf.loss import SimPOLoss
+from torchtune.rlhf.loss import KTOLoss, SimPOLoss
 from tqdm import tqdm
 
 log = utils.get_logger("DEBUG")
@@ -57,6 +57,7 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
         - :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
         - :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
         - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
+        - :class:`~torchtune.rlhf.loss.KTOLoss`: Kahneman-Tversky Optimization (KTO).
 
     Assumptions:
         - Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done
@@ -332,11 +334,22 @@ def _setup_lr_scheduler(
         log.info("Learning rate scheduler is initialized.")
         return lr_scheduler
 
+    def _unpair_row(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        """Convert a batch of paired (prompt, chosen, rejected) rows into unpaired
+        (prompt, completion, label) rows, where ``label`` is True for chosen and
+        False for rejected completions."""
+        batch_size = len(examples["chosen"])
+        new_rows = {
+            "completion": examples["chosen"] + examples["rejected"],
+            "label": [True] * batch_size + [False] * batch_size,
+        }
+        if "prompt" in examples:
+            new_rows["prompt"] = examples["prompt"] + examples["prompt"]
+        return new_rows
+
     def _setup_data(
         self,
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
+        num_proc: Optional[int] = None,
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports
@@ -352,6 +365,14 @@
         else:
             ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
 
+        # KTO consumes unpaired (completion, label) examples. If the dataset comes in
+        # paired preference format (i.e. it has "chosen" and "rejected" columns),
+        # unpair it here. See the reference implementation in TRL:
+        # https://github.com/huggingface/trl/blob/main/trl/data_utils.py
+        if isinstance(self._loss_fn, KTOLoss):
+            column_names = ds.column_names
+            if "chosen" in column_names and "rejected" in column_names:
+                ds = ds.map(
+                    self._unpair_row,
+                    batched=True,
+                    remove_columns=["chosen", "rejected"],
+                    num_proc=num_proc,
+                )
+
         sampler = DistributedSampler(
             ds,
             num_replicas=1,
@@ -426,7 +447,7 @@ def save_checkpoint(self, epoch: int) -> None:
 
     def concatenated_forward(
         self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Union[
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+    ]:
         """
         Run forward pass of the model with chosen and rejected samples concatenated.
 
@@ -437,10 +458,22 @@
         Returns:
-            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
+            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
+            For :class:`~torchtune.rlhf.loss.KTOLoss`, the policy KL log probs are
+            appended as a fifth element.
""" + concatenated_input_ids, concatenated_labels = batch concatenated_input_ids = concatenated_input_ids.to(self._device) concatenated_labels = concatenated_labels.to(self._device) + if isinstance(self._loss_fn, KTOLoss): + with torch.no_grad(): + KL_logits = model( + concatenated_input_ids, + ).logits + + KL_logps = rlhf.get_batch_log_probs( + KL_logits, + concatenated_labels, # FIXME: Must be KL labels! + ) + # formed by concatenating an equal number of "chosen" and "rejected". len_chosen = concatenated_input_ids.shape[0] // 2 @@ -459,7 +492,10 @@ def concatenated_forward( chosen_logits = all_logits[:len_chosen] rejected_logits = all_logits[len_chosen:] - return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + if not isinstance(self._loss_fn, KTOLoss): + return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + else: + return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits, KL_logps) def train(self) -> None: """ @@ -490,13 +526,23 @@ def train(self) -> None: break # batch is input_ids, labels - num_tokens += batch[0].numel() - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - ) = self.concatenated_forward(self._model, batch) + if not isinstance(self._loss_fn, KTOLoss): + num_tokens += batch[0].numel() + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(self._model, batch) + else: + num_tokens += batch[0].numel() + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + policy_KL_logps, + ) = self.concatenated_forward(self._model, batch) policy_chosen_logits_mean = policy_chosen_logits.detach().mean() policy_rejected_logits_mean = policy_rejected_logits.detach().mean() @@ -508,6 +554,24 @@ def train(self) -> None: loss, chosen_rewards, rejected_rewards = self._loss_fn( policy_chosen_log_probs, policy_rejected_log_probs ) + elif isinstance(self._loss_fn, KTOLoss): + # In case of KTOLoss we require reference_KL_logps + with torch.no_grad(), disable_adapter(self._model): + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + reference_KL_logps, + ) = self.concatenated_forward(self._model, batch) + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_KL_logps, + reference_chosen_log_probs, + reference_rejected_log_probs, + reference_KL_logps + ) else: # reference based losses (e.g. DPO) explicitly regularize the objective fn based on # the reference model's output - reference-free losses (such as SimPO) don't require this. 
diff --git a/tests/torchtune/rlhf/loss/test_dpo_loss.py b/tests/torchtune/rlhf/loss/test_dpo_loss.py
index 6c3e2dd4e0..37017c75d8 100644
--- a/tests/torchtune/rlhf/loss/test_dpo_loss.py
+++ b/tests/torchtune/rlhf/loss/test_dpo_loss.py
@@ -6,7 +6,7 @@
 
 import pytest
 import torch
-from torchtune.rlhf.loss import DPOLoss, RSOLoss, SimPOLoss
+from torchtune.rlhf.loss import DPOLoss, KTOLoss, RSOLoss, SimPOLoss
 
 
 @pytest.fixture(autouse=True)
@@ -36,6 +36,14 @@ def simpo_loss(self):
             label_smoothing=0.0,
         )
 
+    @pytest.fixture
+    def kto_loss(self):
+        return KTOLoss(
+            beta=0.1,
+            undesirable_weight=1.0,
+            desirable_weight=1.0,
+        )
+
     @pytest.fixture
     def loss_inputs(self):
         """
@@ -123,3 +133,73 @@ def test_simpo_loss(self, simpo_loss, loss_inputs):
         losses, *_ = simpo_loss(policy_chosen_logprobs, policy_rejected_logprobs)
 
         torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5)
+
+    def test_kto_loss(self, kto_loss, loss_inputs):
+        """
+        beta = 0.1
+        policy_chosen_logprobs = torch.tensor([-0.5, -10.0, -1.0])
+        policy_rejected_logprobs = torch.tensor([-0.1, -30.0, -21.0])
+
+        ref_chosen_logprobs = torch.tensor([-0.5, -10.1, -0.1])
+        ref_rejected_logprobs = torch.tensor([-0.1, -20.1, -0.1])
+
+        policy_KL_logps = torch.tensor([-1.0, 0.25, 1.0])
+        reference_KL_logps = torch.tensor([0.4, 0.2, -0.2])
+
+        kl = (policy_KL_logps - reference_KL_logps).mean().detach()
+        kl = kl.clamp(min=0)
+        # mean([-1.4, 0.05, 1.2]) = -0.05, clamped to 0
+        kl = torch.tensor([0.0])
+
+        chosen_logratios = policy_chosen_logprobs - ref_chosen_logprobs
+        chosen_logratios = torch.tensor([0.0, 0.1, -0.9])
+
+        chosen_losses = 1 - F.sigmoid(0.1 * (chosen_logratios - kl))
+        chosen_losses = torch.tensor([0.5000, 0.4975, 0.5225])
+
+        rejected_logratios = policy_rejected_logprobs - ref_rejected_logprobs
+        rejected_logratios = torch.tensor([0.0, -9.9, -20.9])
+
+        rejected_losses = 1 - F.sigmoid(0.1 * (kl - rejected_logratios))
+        rejected_losses = torch.tensor([0.5, 0.2709, 0.1101])
+
+        With desirable_weight = undesirable_weight = 1.0:
+
+        losses = torch.cat(
+            (1 * chosen_losses, 1 * rejected_losses),
+            0,
+        )
+        losses = torch.tensor([0.5000, 0.4975, 0.5225, 0.5000, 0.2709, 0.1101])
+
+        chosen_rewards = 0.1 * chosen_logratios.detach()
+        rejected_rewards = 0.1 * rejected_logratios.detach()
+
+        chosen_rewards = torch.tensor([0.0000, 0.0100, -0.0900])
+        rejected_rewards = torch.tensor([0.0000, -0.9900, -2.0900])
+        """
+        (
+            policy_chosen_logprobs,
+            policy_rejected_logprobs,
+            ref_chosen_logprobs,
+            ref_rejected_logprobs,
+        ) = loss_inputs
+
+        # Not part of the shared fixture since only KTO needs KL log probs.
+        policy_KL_logps = torch.tensor([-1.0, 0.25, 1.0])
+        reference_KL_logps = torch.tensor([0.4, 0.2, -0.2])
+
+        losses, *_ = kto_loss(
+            policy_chosen_logprobs,
+            policy_rejected_logprobs,
+            policy_KL_logps,
+            ref_chosen_logprobs,
+            ref_rejected_logprobs,
+            reference_KL_logps,
+        )
+
+        expected_losses = torch.tensor(
+            [0.5000, 0.4975, 0.5225, 0.5000, 0.2709, 0.1101]
+        )
+        torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5)
diff --git a/torchtune/rlhf/loss/__init__.py b/torchtune/rlhf/loss/__init__.py
index 5c4b649587..6d2ccbfc6d 100644
--- a/torchtune/rlhf/loss/__init__.py
+++ b/torchtune/rlhf/loss/__init__.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from .dpo import DPOLoss, RSOLoss, SimPOLoss
+from .dpo import DPOLoss, KTOLoss, RSOLoss, SimPOLoss
 from .ppo import PPOLoss
 
-__all__ = ["DPOLoss", "RSOLoss", "SimPOLoss", "PPOLoss"]
+__all__ = ["DPOLoss", "KTOLoss", "RSOLoss", "SimPOLoss", "PPOLoss"]
diff --git a/torchtune/rlhf/loss/dpo.py b/torchtune/rlhf/loss/dpo.py
index 29f66a20c3..c98c416be8 100644
--- a/torchtune/rlhf/loss/dpo.py
+++ b/torchtune/rlhf/loss/dpo.py
@@ -231,3 +231,94 @@ def forward(
         rejected_rewards = self.beta * (policy_rejected_logps).detach()
 
         return losses, chosen_rewards, rejected_rewards
+
+
+class KTOLoss(nn.Module):
+    """
+    KTO: Kahneman-Tversky Optimization: https://arxiv.org/abs/2402.01306
+
+    Intuition from the paper:
+
+    Rather than maximizing the log-likelihood margin between paired preferred and dispreferred
+    completions as DPO does, KTO directly maximizes the utility of generations via a
+    Kahneman-Tversky value function from prospect theory. It therefore only needs a binary
+    label marking each completion as desirable or undesirable, rather than full preference pairs.
+
+    Based on the TRL implementation:
+    https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/kto_trainer.py
+
+    KTO is a simple but effective alternative to DPO. According to the paper, KTO matches or
+    exceeds DPO and is more robust when the preference data is of lower quality. It still uses
+    a reference model, but the proposed human-aware loss (HALO) makes more effective use of it.
+
+    Args:
+        beta (float): Parameter controlling the deviation from the reference model. Higher β means less
+            deviation from the reference model. Default is 0.1.
+        desirable_weight (float): Desirable losses are weighed by this factor to counter an
+            unequal number of desirable and undesirable examples. Default is 1.0.
+        undesirable_weight (float): Undesirable losses are weighed by this factor to counter an
+            unequal number of desirable and undesirable examples. Default is 1.0.
+    """
+
+    def __init__(
+        self,
+        beta: float = 0.1,
+        desirable_weight: float = 1.0,
+        undesirable_weight: float = 1.0,
+    ):
+        super().__init__()
+        self.beta = beta
+        self.desirable_weight = desirable_weight
+        self.undesirable_weight = undesirable_weight
+
+    def forward(
+        self,
+        policy_chosen_logps: torch.Tensor,
+        policy_rejected_logps: torch.Tensor,
+        policy_KL_logps: torch.Tensor,
+        reference_chosen_logps: torch.Tensor,
+        reference_rejected_logps: torch.Tensor,
+        reference_KL_logps: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Compute the KTO loss for a batch of policy and reference model log probabilities.
+
+        Args:
+            policy_chosen_logps (torch.Tensor): Log probabilities of the policy model
+                for the chosen responses. Shape: (batch_size)
+            policy_rejected_logps (torch.Tensor): Log probabilities of the policy model
+                for the rejected responses. Shape: (batch_size)
+            policy_KL_logps (torch.Tensor): Log probabilities of the policy model
+                for the KL completions. Shape: (batch_size)
+            reference_chosen_logps (torch.Tensor): Log probabilities of the reference model
+                for the chosen responses. Shape: (batch_size)
+            reference_rejected_logps (torch.Tensor): Log probabilities of the reference model
+                for the rejected responses. Shape: (batch_size)
+            reference_KL_logps (torch.Tensor): Log probabilities of the reference model
+                for the KL completions. Shape: (batch_size)
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors:
+                - losses: The KTO loss for each example in the batch.
+                - chosen_rewards: Rewards for the chosen responses.
+                - rejected_rewards: Rewards for the rejected responses.
+        """
+        # The KL term only serves as a reference point for the value function;
+        # it is detached so no gradients flow through it.
+        kl = (policy_KL_logps - reference_KL_logps).mean().detach()
+        kl = kl.clamp(min=0)
+
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
+
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+        rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
+
+        losses = torch.cat(
+            (
+                self.desirable_weight * chosen_losses,
+                self.undesirable_weight * rejected_losses,
+            ),
+            0,
+        )
+
+        chosen_rewards = self.beta * chosen_logratios.detach()
+        rejected_rewards = self.beta * rejected_logratios.detach()
+
+        return losses, chosen_rewards, rejected_rewards
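For reviewers who want to exercise the new loss outside of the recipe, here is a minimal usage sketch. It assumes this diff has been applied so that `KTOLoss` is importable from `torchtune.rlhf.loss`; the random tensors are stand-ins for the per-sequence log probs the recipe obtains from `concatenated_forward` / `rlhf.get_batch_log_probs`.

```python
# Minimal sketch, assuming the KTOLoss added in this diff is importable.
import torch

from torchtune.rlhf.loss import KTOLoss

loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)

batch_size = 4
# Per-sequence log probs from the policy and the (frozen) reference model;
# random values here stand in for real model outputs.
policy_chosen_logps = torch.randn(batch_size)
policy_rejected_logps = torch.randn(batch_size)
policy_KL_logps = torch.randn(batch_size)
reference_chosen_logps = torch.randn(batch_size)
reference_rejected_logps = torch.randn(batch_size)
reference_KL_logps = torch.randn(batch_size)

losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen_logps,
    policy_rejected_logps,
    policy_KL_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    reference_KL_logps,
)

# losses stacks the chosen losses first, then the rejected ones.
assert losses.shape == (2 * batch_size,)
assert chosen_rewards.shape == rejected_rewards.shape == (batch_size,)
```

In a real recipe run the KL log probs should come from mismatched prompt/completion pairs (see the FIXME in `concatenated_forward`), not from the target labels.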