
Commit

[fix][1760] Added a fix for the missing `context` key issue in dolly!
pytholic committed Oct 2, 2024
1 parent a8aa4ba commit e3b75c8
Showing 3 changed files with 143 additions and 32 deletions.
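For context, a minimal sketch of the failure this commit addresses. The sample record and the repeated call are assumptions for illustration: the old `_transform` in `litgpt/data/dolly.py` popped keys out of the record, so applying it to the same item again (or to an item that lacks `context`) raised a `KeyError`.

# Hypothetical illustration of the pre-fix behavior; the record contents are made up.
record = {"instruction": "Summarize the passage.", "context": "Some passage.", "response": "A summary."}

def old_transform(item: dict) -> dict:
    item["input"] = item.pop("context")    # removes "context" from the record in place
    item["output"] = item.pop("response")  # removes "response" as well
    return item

old_transform(record)  # first application works
old_transform(record)  # raises KeyError: 'context' -- the keys were already popped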
9 changes: 6 additions & 3 deletions litgpt/data/dolly.py
@@ -11,7 +11,9 @@
from litgpt.prompts import PromptStyle
from litgpt.data import Alpaca, SFTDataset

_URL: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
_URL: str = (
"https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
)


@dataclass
@@ -70,7 +72,8 @@ def setup(self, stage: str = "") -> None:
)


# TODO: break test with old behavior
def _transform(item: dict) -> dict:
item["input"] = item.pop("context")
item["output"] = item.pop("response")
item["input"] = item.get("context", "")
item["output"] = item.get("response", "")
return item
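A short sketch of the post-fix behavior under the same assumptions: `dict.get` reads the keys without removing them and falls back to an empty string when `context` or `response` is absent, so re-applying the transform or feeding a record without `context` no longer raises.

# Sketch of the fixed transform's behavior; the sample records are assumptions.
def _transform(item: dict) -> dict:
    item["input"] = item.get("context", "")    # missing key falls back to ""
    item["output"] = item.get("response", "")  # original keys are left in place
    return item

_transform({"instruction": "Say hi.", "response": "Hi!"})  # no "context" key: input becomes ""
record = {"instruction": "Summarize.", "context": "Text.", "response": "Summary."}
_transform(record)
_transform(record)  # safe to re-apply; "context" and "response" are still present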
107 changes: 79 additions & 28 deletions litgpt/finetune/lora.py
@@ -47,7 +47,9 @@ def setup(
checkpoint_dir: Path,
out_dir: Path = Path("out/finetune/lora"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
quantize: Optional[
Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]
] = None,
devices: Union[int, str] = 1,
num_nodes: int = 1,
lora_r: int = 8,
@@ -123,7 +125,9 @@ def setup(
)

precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
logger = choose_logger(
logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval
)

plugins = None
if quantize is not None and quantize.startswith("bnb."):
@@ -134,7 +138,9 @@
"LitGPT only supports bitsandbytes v0.42.0. "
"This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[
precision
]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None

@@ -166,7 +172,9 @@ def setup(
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)

fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
fabric.launch(
main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer
)


def main(
@@ -199,7 +207,9 @@ def main(
mark_only_lora_as_trainable(model)

fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
fabric.print(
f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}"
)

model = fabric.setup_module(model)

@@ -209,7 +219,9 @@
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
scheduler = get_lr_scheduler(
optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps
)

# strict=False because missing keys due to LoRA weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)
@@ -235,10 +247,14 @@ def main(

# Final evaluation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = validate(
fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))
)
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics)
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
fabric.print(
f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}"
)

# Save the final LoRA checkpoint at the end of training
save_path = out_dir / "final" / "lit_model.pth.lora"
@@ -267,26 +283,32 @@ def fit(
data: DataModule,
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]))
longest_seq_length, longest_seq_ix = get_longest_seq_length(
ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)

if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = validate(
fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))
)
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
validate(
fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False
) # sanity check
val_loss = "n/a"

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
fabric.device
)
running_loss = RunningMean(
window=train.gradient_accumulation_iters(devices), sync_on_compute=False
).to(fabric.device)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
@@ -320,7 +342,10 @@ def fit(
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
time=t1 - total_t0,
batches=iter_num,
samples=iter_num * train.micro_batch_size,
lengths=total_lengths,
)
throughput.compute_and_log(step=iter_num)
metrics = {
@@ -330,7 +355,9 @@
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": iter_num * train.micro_batch_size * model.config.block_size,
"total_tokens": (iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size),
"total_tokens": (
iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size
),
"learning_rate": scheduler.get_last_lr()[0],
}
if isinstance(val_loss, torch.Tensor):
@@ -349,12 +376,18 @@ def fit(
val_loss = validate(fabric, model, val_dataloader, eval)
generate_example(fabric, model, tokenizer, eval, data)
t1 = time.perf_counter() - t0
fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
fabric.print(
f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms\n"
)
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=iter_num)
fabric.barrier()

if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
if (
train.save_interval is not None
and not is_accumulating
and step_count % train.save_interval == 0
):
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, checkpoint_file)
@@ -366,7 +399,9 @@ def fit(

# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True) -> torch.Tensor:
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
) -> torch.Tensor:
if verbose:
fabric.print("Validating ...")
model.eval()
@@ -385,9 +420,11 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva


@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
def generate_example(
fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
):
fabric.print("Generating sample ...")
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
fabric.print(instruction)
prompt = data.prompt_style.apply(instruction)
encoded = tokenizer.encode(prompt, device=fabric.device)
model.eval()
@@ -399,12 +436,16 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(
model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
model,
encoded,
max_returned_tokens=max_returned_tokens,
temperature=0.8,
eos_id=tokenizer.eos_id,
)
model.clear_kv_cache()
model.train()
output = tokenizer.decode(output)
fabric.print(output)
fabric.print(f"{output}\n")
else:
print(
f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
@@ -416,14 +457,20 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=(max_steps - warmup_steps)
)
return torch.optim.lr_scheduler.SequentialLR(
optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]
)


def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
data.connect(
tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length
)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
@@ -452,13 +499,17 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
issues.append(
f"{__file__} doesn't support the {name!r} argument. This is set in {args}"
)
required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
issues.append(
f"{__file__} requires either epochs or max_steps to be set. This is set in {train}"
)
if issues:
raise ValueError("\n".join(issues))
59 changes: 58 additions & 1 deletion tests/data/test_dolly.py
@@ -5,7 +5,12 @@


def test_dolly(mock_tokenizer, dolly_path):
dolly = Dolly(val_split_fraction=0.5, download_dir=dolly_path.parent, file_name=dolly_path.name, num_workers=0)
dolly = Dolly(
val_split_fraction=0.5,
download_dir=dolly_path.parent,
file_name=dolly_path.name,
num_workers=0,
)
assert isinstance(dolly.prompt_style, AlpacaPromptStyle)
dolly.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
dolly.prepare_data()
@@ -29,3 +34,55 @@ def test_dolly(mock_tokenizer, dolly_path):

# has attributes from super class `LightningDataModule`
assert dolly.prepare_data_per_node


def test_dolly_missing_keys(mock_tokenizer, dolly_path):
"""
Notes
-----
- Added only for the dolly dataset.
References
----------
- Reference issue: https://github.com/Lightning-AI/litgpt/issues/1760
Methodology
-----------
- Simulate the original behavior by popping `context` key.
- Run dataloader which will apply `transform`.
- Previously it would have thrown missing `context` key error because we `popped` the key.
- Now we are using `get` method to not remove they key(s).
"""

dolly = Dolly(
val_split_fraction=0.5,
download_dir=dolly_path.parent,
file_name=dolly_path.name,
num_workers=0,
)
dolly.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
dolly.prepare_data()
dolly.setup()

# check if the dataset was created without errors
assert dolly.train_dataset is not None
assert dolly.test_dataset is not None

# Verify that the transform function handled missing keys correctly
for dataset in [dolly.train_dataset, dolly.test_dataset]:
for item in dataset.data:
assert "context" in item
assert "response" in item
assert isinstance(item["context"], str)
assert isinstance(item["response"], str)
# Drop `context` and `response` keys
# This simulates the behavior of the original issue with `item.pop`
item.pop("context")
item.pop("response")

# Check if we can iterate through the dataloader without errors
# The previous approach would throw a KeyError here since we already popped the keys
train_dataloader = dolly.train_dataloader()
train_batch = next(iter(train_dataloader))
assert "input_ids" in train_batch
assert "labels" in train_batch
