add dataset processing for KTO
krammnic committed Oct 23, 2024
1 parent dbb35fb commit 3f3fd67
26 changes: 25 additions & 1 deletion recipes/lora_dpo_single_device.py
@@ -8,7 +8,7 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, List
from warnings import warn

import torch
@@ -33,6 +33,8 @@
from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm

from torchtune.rlhf.loss import KTOLoss

log = utils.get_logger("DEBUG")


@@ -334,11 +336,22 @@ def _setup_lr_scheduler(
    log.info("Learning rate scheduler is initialized.")
    return lr_scheduler

def _unpair_row(
    self, examples: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
    # Convert a batched, paired (chosen/rejected) mapping into unpaired
    # completion/label rows: chosen rows get label=True, rejected rows get
    # label=False, and the prompt column is duplicated for both halves.
    batch_size = len(examples["chosen"])
    new_rows = {
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }
    if "prompt" in examples:
        new_rows["prompt"] = examples["prompt"] + examples["prompt"]
    return new_rows

def _setup_data(
    self,
    cfg_dataset: DictConfig,
    shuffle: bool,
    batch_size: int,
    num_proc: Optional[int] = None,
) -> Tuple[DistributedSampler, DataLoader]:
"""
All data related setup happens here. Currently this recipe only supports
@@ -354,6 +367,15 @@ def _setup_data(
    else:
        ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

    # We need extra processing for KTO.
    if isinstance(self._loss_fn, KTOLoss):
        column_names = ds.column_names
        # We assume the dataset is either already in unpaired KTO format or is a
        # paired (chosen/rejected) dataset, as used by other contrastive procedures.
        # See the reference implementation in:
        # https://github.com/huggingface/trl/blob/main/trl/data_utils.py
        if "chosen" in column_names and "rejected" in column_names:
            ds = ds.map(
                self._unpair_row,
                batched=True,
                remove_columns=["chosen", "rejected"],
                num_proc=num_proc,
            )

    sampler = DistributedSampler(
        ds,
        num_replicas=1,
@@ -439,6 +461,8 @@ def concatenated_forward(
    Returns:
        Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
    """


    concatenated_input_ids, concatenated_labels = batch
    concatenated_input_ids = concatenated_input_ids.to(self._device)
    concatenated_labels = concatenated_labels.to(self._device)

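A minimal, self-contained sketch of what this preprocessing does, assuming the Hugging Face datasets package; the standalone helper, the toy prompts, and the completions below are illustrative only and are not part of the commit:

from typing import Any, Dict, List

from datasets import Dataset


def unpair_row(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    # Standalone mirror of the recipe's _unpair_row: chosen completions get
    # label=True, rejected completions get label=False, prompts are duplicated.
    batch_size = len(examples["chosen"])
    new_rows = {
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }
    if "prompt" in examples:
        new_rows["prompt"] = examples["prompt"] + examples["prompt"]
    return new_rows


paired = Dataset.from_dict(
    {
        "prompt": ["What does KTO stand for?", "Name a color."],
        "chosen": ["Kahneman-Tversky Optimization.", "Blue."],
        "rejected": ["A keyboard shortcut.", "Seven."],
    }
)
unpaired = paired.map(
    unpair_row, batched=True, remove_columns=["chosen", "rejected"]
)
print(sorted(unpaired.column_names))  # ['completion', 'label', 'prompt']
print(unpaired["label"])              # [True, True, False, False]

The recipe only applies this transformation when the configured loss is KTOLoss, so existing paired DPO/SimPO-style datasets keep working unchanged.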