diff --git a/finetune/adapter.py b/finetune/adapter.py index 1114e7d0e5..b5f81298c3 100644 --- a/finetune/adapter.py +++ b/finetune/adapter.py @@ -8,6 +8,7 @@ import lightning as L import torch +from torch.utils.data import DataLoader from lightning.fabric.loggers import CSVLogger from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy @@ -20,6 +21,7 @@ from generate.base import generate from lit_gpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.data import Alpaca, LitDataModule, apply_prompt_template from lit_gpt.tokenizer import Tokenizer from lit_gpt.utils import ( CLI, @@ -28,8 +30,8 @@ get_default_supported_precision, load_checkpoint, num_parameters, + CycleIterator, ) -from scripts.prepare_alpaca import generate_prompt def setup( @@ -37,11 +39,10 @@ def setup( quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, devices: int = 1, seed: int = 1337, + data: Optional[LitDataModule] = None, io: IOArgs = IOArgs( - train_data_dir=Path("data/alpaca"), - val_data_dir=Path("data/alpaca"), checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - out_dir=Path("out/adapter/alpaca"), + out_dir=Path("out/adapter"), ), train: TrainArgs = TrainArgs( save_interval=1000, @@ -50,13 +51,16 @@ def setup( micro_batch_size=4, lr_warmup_steps=100, epochs=5, - epoch_size=50000, learning_rate=1e-3, max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), ) -> None: + print(locals()) + if data is None: + data = Alpaca() + precision = precision or get_default_supported_precision(training=True) plugins = None @@ -85,25 +89,24 @@ def setup( logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval) fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) - fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), io, train, eval) + fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), data, io, train, eval) -def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: +def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: validate_args(io, train, eval) - steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) - lr_max_steps = train.epochs * steps_per_epoch - check_valid_checkpoint_dir(io.checkpoint_dir) + tokenizer = Tokenizer(io.checkpoint_dir) + train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) + lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) + fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(io.out_dir, exist_ok=True) - train_data = torch.load(io.train_data_dir / "train.pt") - val_data = torch.load(io.val_data_dir / "test.pt") - checkpoint_path = io.checkpoint_dir / "lit_model.pth" fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") with fabric.init_module(empty_init=(devices > 1)): @@ -131,10 +134,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, 
# strict=False because missing keys due to Adapter weights not contained in state dict load_checkpoint(fabric, model, checkpoint_path, strict=False) - fabric.seed_everything(1337 + fabric.global_rank) - train_time = time.perf_counter() - fit(fabric, model, optimizer, scheduler, train_data, val_data, devices, io, train, eval) + fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, io, train, eval) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -149,34 +150,37 @@ def fit( model: GPT, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, - train_data: List[Dict], - val_data: List[Dict], + train_dataloader: DataLoader, + val_dataloader: DataLoader, devices: int, io: IOArgs, train: TrainArgs, eval: EvalArgs, ) -> None: tokenizer = Tokenizer(io.checkpoint_dir) - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) - validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check + validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2)) # sanity check + train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) + max_steps = train.max_steps or float("inf") step_count = 0 + iter_num = 0 total_lengths = 0 total_t0 = time.perf_counter() - for iter_num in range(1, train.max_iters(devices) + 1): + while step_count < max_steps and train_iterator.epoch < train.epochs: + iter_num += 1 iter_t0 = time.perf_counter() - input_ids, targets = get_batch( - fabric, train_data, train.micro_batch_size, train.max_seq_length, longest_seq_ix if iter_num == 1 else None - ) + batch = next(train_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): @@ -207,7 +211,7 @@ def fit( if not is_accumulating and step_count % eval.interval == 0: t0 = time.perf_counter() - val_loss = validate(fabric, model, val_data, tokenizer, eval, train) + val_loss = validate(fabric, model, val_dataloader, tokenizer, eval) t1 = time.perf_counter() - t0 fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms") fabric.barrier() @@ -219,13 +223,15 @@ def fit( # the adapter "kv cache" cannot be initialized under `inference_mode` @torch.no_grad() def validate( - fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs + fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, ) -> torch.Tensor: fabric.print("Validating ...") model.eval() losses = torch.zeros(eval.max_iters) + val_iterator = iter(val_dataloader) for k in range(eval.max_iters): - input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length) + batch = next(val_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = 
chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() @@ -234,7 +240,7 @@ def validate( instruction = "Recommend a movie for me to watch during the weekend and explain the reason." fabric.print(instruction) sample = {"instruction": instruction, "input": ""} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here @@ -250,44 +256,6 @@ def validate( return val_loss -def get_batch( - fabric: L.Fabric, - data: List[Dict], - micro_batch_size: int, - max_seq_length: Optional[int], - longest_seq_ix: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - ix = torch.randint(len(data), (micro_batch_size,)) - if longest_seq_ix is not None: - # force the longest sample at the beginning so potential OOMs happen right away - ix[0] = longest_seq_ix - - input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] - labels = [data[i]["labels"].type(torch.int64) for i in ix] - - # this could be `longest_seq_length` to have a fixed size for all batches - max_len = max(len(s) for s in input_ids) - - def pad_right(x, pad_id): - # pad right based on the longest sequence - n = max_len - len(x) - return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) - - x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) - y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) - - # Truncate if needed - if max_seq_length: - x = x[:, :max_seq_length] - y = y[:, :max_seq_length] - - if fabric.device.type == "cuda" and x.device.type == "cpu": - x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) - else: - x, y = fabric.to_device((x, y)) - return x, y - - def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) @@ -295,6 +263,17 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) +def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]: + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + with fabric.rank_zero_first(): + data.prepare_data() + data.setup() + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + return train_dataloader, val_dataloader + + def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) lengths = [len(d["input_ids"]) for d in data] @@ -316,14 +295,16 @@ def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [ - (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), - (train, ["epoch_size", "epochs"]), + (io, ["checkpoint_dir"]), + (train, ["epochs"]), (eval, ["max_new_tokens"]), ] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. 
This is set in {args}") + if not train.epochs and not train.max_steps: + issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py index 34e3607b28..2e081e74a7 100644 --- a/finetune/adapter_v2.py +++ b/finetune/adapter_v2.py @@ -8,6 +8,7 @@ import lightning as L import torch +from torch.utils.data import DataLoader from lightning.fabric.loggers import CSVLogger from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy @@ -20,6 +21,7 @@ from generate.base import generate from lit_gpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.data import Alpaca, LitDataModule, apply_prompt_template from lit_gpt.tokenizer import Tokenizer from lit_gpt.utils import ( CLI, @@ -28,8 +30,8 @@ get_default_supported_precision, load_checkpoint, num_parameters, + CycleIterator, ) -from scripts.prepare_alpaca import generate_prompt def setup( @@ -37,11 +39,10 @@ def setup( quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, devices: int = 1, seed: int = 1337, + data: Optional[LitDataModule] = None, io: IOArgs = IOArgs( - train_data_dir=Path("data/alpaca"), - val_data_dir=Path("data/alpaca"), checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - out_dir=Path("out/adapter_v2/alpaca"), + out_dir=Path("out/adapter_v2"), ), train: TrainArgs = TrainArgs( save_interval=1000, @@ -50,13 +51,16 @@ def setup( micro_batch_size=2, lr_warmup_steps=100, epochs=5, - epoch_size=50000, learning_rate=1e-3, max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), ) -> None: + print(locals()) + if data is None: + data = Alpaca() + precision = precision or get_default_supported_precision(training=True) plugins = None @@ -85,25 +89,24 @@ def setup( logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval) fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) - fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), io, train, eval) + fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), data, io, train, eval) -def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: +def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: validate_args(io, train, eval) - steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) - lr_max_steps = train.epochs * steps_per_epoch - check_valid_checkpoint_dir(io.checkpoint_dir) + tokenizer = Tokenizer(io.checkpoint_dir) + train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) + lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) + fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(io.out_dir, exist_ok=True) - train_data = torch.load(io.train_data_dir / "train.pt") - val_data = torch.load(io.val_data_dir / "test.pt") - checkpoint_path = io.checkpoint_dir / 
"lit_model.pth" fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") with fabric.init_module(empty_init=(devices > 1)): @@ -131,10 +134,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, # strict=False because missing keys due to Adapter weights not contained in state dict load_checkpoint(fabric, model, checkpoint_path, strict=False) - fabric.seed_everything(1337 + fabric.global_rank) - train_time = time.perf_counter() - fit(fabric, model, optimizer, scheduler, train_data, val_data, devices, io, train, eval) + fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, io, train, eval) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -149,34 +150,37 @@ def fit( model: GPT, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, - train_data: List[Dict], - val_data: List[Dict], + train_dataloader: DataLoader, + val_dataloader: DataLoader, devices: int, io: IOArgs, train: TrainArgs, eval: EvalArgs, ) -> None: tokenizer = Tokenizer(io.checkpoint_dir) - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) - validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check + validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2)) # sanity check + train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) + max_steps = train.max_steps or float("inf") step_count = 0 + iter_num = 0 total_lengths = 0 total_t0 = time.perf_counter() - for iter_num in range(1, train.max_iters(devices) + 1): + while step_count < max_steps and train_iterator.epoch < train.epochs: + iter_num += 1 iter_t0 = time.perf_counter() - input_ids, targets = get_batch( - fabric, train_data, train.micro_batch_size, train.max_seq_length, longest_seq_ix if iter_num == 1 else None - ) + batch = next(train_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): @@ -207,7 +211,7 @@ def fit( if not is_accumulating and step_count % eval.interval == 0: t0 = time.perf_counter() - val_loss = validate(fabric, model, val_data, tokenizer, eval, train) + val_loss = validate(fabric, model, val_dataloader, tokenizer, eval) t1 = time.perf_counter() - t0 fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms") fabric.barrier() @@ -219,13 +223,15 @@ def fit( # the adapter "kv cache" cannot be initialized under `inference_mode` @torch.no_grad() def validate( - fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs + fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, ) -> torch.Tensor: fabric.print("Validating ...") model.eval() losses = torch.zeros(eval.max_iters) + val_iterator = iter(val_dataloader) for k in 
range(eval.max_iters): - input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length) + batch = next(val_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() @@ -234,7 +240,7 @@ def validate( instruction = "Recommend a movie for me to watch during the weekend and explain the reason." fabric.print(instruction) sample = {"instruction": instruction, "input": ""} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here @@ -250,44 +256,6 @@ def validate( return val_loss -def get_batch( - fabric: L.Fabric, - data: List[Dict], - micro_batch_size: int, - max_seq_length: Optional[int], - longest_seq_ix: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - ix = torch.randint(len(data), (micro_batch_size,)) - if longest_seq_ix is not None: - # force the longest sample at the beginning so potential OOMs happen right away - ix[0] = longest_seq_ix - - input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] - labels = [data[i]["labels"].type(torch.int64) for i in ix] - - # this could be `longest_seq_length` to have a fixed size for all batches - max_len = max(len(s) for s in input_ids) - - def pad_right(x, pad_id): - # pad right based on the longest sequence - n = max_len - len(x) - return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) - - x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) - y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) - - # Truncate if needed - if max_seq_length: - x = x[:, :max_seq_length] - y = y[:, :max_seq_length] - - if fabric.device.type == "cuda" and x.device.type == "cpu": - x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) - else: - x, y = fabric.to_device((x, y)) - return x, y - - def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) @@ -295,6 +263,17 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) +def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]: + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + with fabric.rank_zero_first(): + data.prepare_data() + data.setup() + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + return train_dataloader, val_dataloader + + def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) lengths = [len(d["input_ids"]) for d in data] @@ -316,14 +295,16 @@ def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. 
This is set in {args}") required = [ - (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), - (train, ["epoch_size", "epochs"]), + (io, ["checkpoint_dir"]), + (train, ["epochs"]), (eval, ["max_new_tokens"]), ] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") + if not train.epochs and not train.max_steps: + issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) diff --git a/finetune/full.py b/finetune/full.py index 470fbb96db..43f8947aaf 100644 --- a/finetune/full.py +++ b/finetune/full.py @@ -7,8 +7,10 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -import lightning as L import torch +from torch.utils.data import DataLoader + +import lightning as L from lightning.fabric.loggers import CSVLogger from lightning.fabric.strategies import FSDPStrategy from torchmetrics import RunningMean @@ -21,6 +23,7 @@ from lit_gpt.args import EvalArgs, IOArgs, TrainArgs from lit_gpt.model import GPT, Block, Config from lit_gpt.tokenizer import Tokenizer +from lit_gpt.data import Alpaca, LitDataModule, apply_prompt_template from lit_gpt.utils import ( CLI, check_valid_checkpoint_dir, @@ -28,8 +31,8 @@ get_default_supported_precision, load_checkpoint, num_parameters, + CycleIterator, ) -from scripts.prepare_alpaca import generate_prompt def setup( @@ -37,11 +40,10 @@ def setup( devices: int = 1, resume: Union[bool, Path] = False, seed: int = 1337, + data: Optional[LitDataModule] = None, io: IOArgs = IOArgs( - train_data_dir=Path("data/alpaca"), - val_data_dir=Path("data/alpaca"), checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - out_dir=Path("out/full/alpaca"), + out_dir=Path("out/full"), ), train: TrainArgs = TrainArgs( save_interval=1000, @@ -50,13 +52,17 @@ def setup( micro_batch_size=1, lr_warmup_steps=100, epochs=5, - epoch_size=50000, learning_rate=3e-3, max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), ) -> None: + print(locals()) + + if data is None: + data = Alpaca() + precision = precision or get_default_supported_precision(training=True) if devices > 1: @@ -72,7 +78,7 @@ def setup( logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval) fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) - fabric.launch(main, devices, resume, seed, Config.from_name(name=io.checkpoint_dir.name), io, train, eval) + fabric.launch(main, devices, resume, seed, Config.from_name(name=io.checkpoint_dir.name), data, io, train, eval) def main( @@ -81,25 +87,24 @@ def main( resume: Union[bool, Path], seed: int, config: Config, + data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs, ) -> None: validate_args(io, train, eval) - - steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) - lr_max_steps = train.epochs * steps_per_epoch - check_valid_checkpoint_dir(io.checkpoint_dir) + tokenizer = Tokenizer(io.checkpoint_dir) + train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) + lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) + fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: 
os.makedirs(io.out_dir, exist_ok=True) - train_data = torch.load(io.train_data_dir / "train.pt") - val_data = torch.load(io.val_data_dir / "test.pt") - checkpoint_path = io.checkpoint_dir / "lit_model.pth" fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") with fabric.init_module(empty_init=(devices > 1)): @@ -123,10 +128,8 @@ def main( else: load_checkpoint(fabric, state["model"], checkpoint_path) - fabric.seed_everything(1337 + fabric.global_rank) - train_time = time.perf_counter() - fit(fabric, state, train_data, val_data, devices, resume, io, train, eval) + fit(fabric, state, train_dataloader, val_dataloader, devices, resume, io, train, eval) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -138,8 +141,8 @@ def main( def fit( fabric: L.Fabric, state: Dict, - train_data: List[Dict], - val_data: List[Dict], + train_dataloader: DataLoader, + val_dataloader: DataLoader, devices: int, resume: Union[bool, Path], io: IOArgs, @@ -150,21 +153,23 @@ def fit( optimizer = state["optimizer"] scheduler = state["scheduler"] tokenizer = Tokenizer(io.checkpoint_dir) - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) - validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check + validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2)) # sanity check initial_iter = state["iter_num"] + max_steps = train.max_steps or float("inf") + train_iterator = CycleIterator(train_dataloader) # resume data loader state by fast-forwarding through all seen batches if resume: resume_t0 = time.perf_counter() for resume_iter in range(initial_iter): - get_batch(fabric, train_data, None) + next(train_iterator) if resume_iter % 1000 == 0: fabric.print(f"Resuming dataset: {resume_iter} / {initial_iter}") fabric.barrier() @@ -178,16 +183,11 @@ def fit( ) fabric.barrier() - for state["iter_num"] in range(state["iter_num"] + 1, train.max_iters(devices) + 1): + while state["step_count"] < max_steps and train_iterator.epoch < train.epochs: + state["iter_num"] += 1 iter_t0 = time.perf_counter() - - input_ids, targets = get_batch( - fabric, - train_data, - train.micro_batch_size, - train.max_seq_length, - longest_seq_ix if state["iter_num"] == 1 else None, - ) + batch = next(train_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): @@ -226,7 +226,7 @@ def fit( if not is_accumulating and state["step_count"] % eval.interval == 0: t0 = time.perf_counter() - val_loss = validate(fabric, model, val_data, tokenizer, eval, train) + val_loss = validate(fabric, model, val_dataloader, tokenizer, eval) t1 = time.perf_counter() - t0 fabric.print(f"iter {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms") metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} @@ -241,13 +241,15 @@ def fit( # FSDP has issues 
with `inference_mode` @torch.no_grad() def validate( - fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs + fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs ) -> torch.Tensor: fabric.print("Validating ...") model.eval() losses = torch.zeros(eval.max_iters) + val_iterator = iter(val_dataloader) for k in range(eval.max_iters): - input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length) + batch = next(val_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() @@ -256,7 +258,7 @@ def validate( instruction = "Recommend a movie for me to watch during the weekend and explain the reason." fabric.print(instruction) sample = {"instruction": instruction, "input": ""} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here @@ -272,44 +274,6 @@ def validate( return val_loss -def get_batch( - fabric: L.Fabric, - data: List[Dict], - micro_batch_size: int, - max_seq_length: Optional[int], - longest_seq_ix: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - ix = torch.randint(len(data), (micro_batch_size,)) - if longest_seq_ix is not None: - # force the longest sample at the beginning so potential OOMs happen right away - ix[0] = longest_seq_ix - - input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] - labels = [data[i]["labels"].type(torch.int64) for i in ix] - - # this could be `longest_seq_length` to have a fixed size for all batches - max_len = max(len(s) for s in input_ids) - - def pad_right(x, pad_id): - # pad right based on the longest sequence - n = max_len - len(x) - return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) - - x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) - y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) - - # Truncate if needed - if max_seq_length: - x = x[:, :max_seq_length] - y = y[:, :max_seq_length] - - if fabric.device.type == "cuda" and x.device.type == "cpu": - x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) - else: - x, y = fabric.to_device((x, y)) - return x, y - - def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) @@ -317,6 +281,17 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) +def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]: + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + with fabric.rank_zero_first(): + data.prepare_data() + data.setup() + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + return train_dataloader, val_dataloader + + def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length 
required during fine-tuning (saves memory!) lengths = [len(d["input_ids"]) for d in data] @@ -333,14 +308,16 @@ def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [ - (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), - (train, ["epoch_size", "epochs"]), + (io, ["checkpoint_dir"]), + (train, ["epochs"]), (eval, ["max_new_tokens"]), ] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") + if not train.epochs and not train.max_steps: + issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) diff --git a/finetune/lora.py b/finetune/lora.py index a05e38d431..fc92b8e3cc 100644 --- a/finetune/lora.py +++ b/finetune/lora.py @@ -8,6 +8,7 @@ import lightning as L import torch +from torch.utils.data import DataLoader from lightning.fabric.loggers import CSVLogger from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy @@ -19,6 +20,7 @@ from generate.base import generate from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.data import LitDataModule, Alpaca, apply_prompt_template from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable from lit_gpt.tokenizer import Tokenizer from lit_gpt.utils import ( @@ -27,9 +29,8 @@ chunked_cross_entropy, get_default_supported_precision, load_checkpoint, - num_parameters, + num_parameters, CycleIterator, ) -from scripts.prepare_alpaca import generate_prompt def setup( @@ -46,11 +47,10 @@ def setup( lora_projection: bool = False, lora_mlp: bool = False, lora_head: bool = False, + data: Optional[LitDataModule] = None, io: IOArgs = IOArgs( - train_data_dir=Path("data/alpaca"), - val_data_dir=Path("data/alpaca"), checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - out_dir=Path("out/lora/alpaca"), + out_dir=Path("out/lora"), ), train: TrainArgs = TrainArgs( save_interval=1000, @@ -59,13 +59,16 @@ def setup( micro_batch_size=4, lr_warmup_steps=100, epochs=5, - epoch_size=50000, learning_rate=3e-4, max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), ) -> None: + print(locals()) + if data is None: + data = Alpaca() + precision = precision or get_default_supported_precision(training=True) plugins = None @@ -113,28 +116,28 @@ def setup( to_mlp=lora_mlp, to_head=lora_head, ), + data, io, train, eval, ) -def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: +def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: validate_args(io, train, eval) - steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) - lr_max_steps = train.epochs * steps_per_epoch - check_valid_checkpoint_dir(io.checkpoint_dir) + tokenizer = Tokenizer(io.checkpoint_dir) + train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) + lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) + fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if 
fabric.global_rank == 0: os.makedirs(io.out_dir, exist_ok=True) - train_data = torch.load(io.train_data_dir / "train.pt") - val_data = torch.load(io.val_data_dir / "test.pt") - checkpoint_path = io.checkpoint_dir / "lit_model.pth" fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") with fabric.init_module(empty_init=(devices > 1)): @@ -162,10 +165,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, # strict=False because missing keys due to LoRA weights not contained in state dict load_checkpoint(fabric, model, checkpoint_path, strict=False) - fabric.seed_everything(1337 + fabric.global_rank) - train_time = time.perf_counter() - fit(fabric, model, optimizer, scheduler, train_data, val_data, devices, io, train, eval) + fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, io, train, eval) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -180,34 +181,37 @@ def fit( model: GPT, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, - train_data: List[Dict], - val_data: List[Dict], + train_dataloader: DataLoader, + val_dataloader: DataLoader, devices: int, io: IOArgs, train: TrainArgs, eval: EvalArgs, ) -> None: tokenizer = Tokenizer(io.checkpoint_dir) - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) - validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check + validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2)) # sanity check + train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) + max_steps = train.max_steps or float("inf") step_count = 0 + iter_num = 0 total_lengths = 0 total_t0 = time.perf_counter() - for iter_num in range(1, train.max_iters(devices) + 1): + while step_count < max_steps and train_iterator.epoch < train.epochs: + iter_num += 1 iter_t0 = time.perf_counter() - input_ids, targets = get_batch( - fabric, train_data, train.micro_batch_size, train.max_seq_length, longest_seq_ix if iter_num == 1 else None - ) + batch = next(train_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): @@ -238,7 +242,7 @@ def fit( if not is_accumulating and step_count % eval.interval == 0: t0 = time.perf_counter() - val_loss = validate(fabric, model, val_data, tokenizer, eval, train) + val_loss = validate(fabric, model, val_dataloader, tokenizer, eval) t1 = time.perf_counter() - t0 fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms") fabric.barrier() @@ -250,13 +254,15 @@ def fit( # FSDP has issues with `inference_mode` @torch.no_grad() def validate( - fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs + fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, 
eval: EvalArgs, ) -> torch.Tensor: fabric.print("Validating ...") model.eval() losses = torch.zeros(eval.max_iters) + val_iterator = iter(val_dataloader) for k in range(eval.max_iters): - input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length) + batch = next(val_iterator) + input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() @@ -265,7 +271,7 @@ def validate( instruction = "Recommend a movie for me to watch during the weekend and explain the reason." fabric.print(instruction) sample = {"instruction": instruction, "input": ""} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here @@ -281,44 +287,6 @@ def validate( return val_loss -def get_batch( - fabric: L.Fabric, - data: List[Dict], - micro_batch_size: int, - max_seq_length: Optional[int], - longest_seq_ix: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - ix = torch.randint(len(data), (micro_batch_size,)) - if longest_seq_ix is not None: - # force the longest sample at the beginning so potential OOMs happen right away - ix[0] = longest_seq_ix - - input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] - labels = [data[i]["labels"].type(torch.int64) for i in ix] - - # this could be `longest_seq_length` to have a fixed size for all batches - max_len = max(len(s) for s in input_ids) - - def pad_right(x, pad_id): - # pad right based on the longest sequence - n = max_len - len(x) - return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) - - x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) - y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) - - # Truncate if needed - if max_seq_length: - x = x[:, :max_seq_length] - y = y[:, :max_seq_length] - - if fabric.device.type == "cuda" and x.device.type == "cpu": - x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) - else: - x, y = fabric.to_device((x, y)) - return x, y - - def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) @@ -326,6 +294,17 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) +def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]: + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + with fabric.rank_zero_first(): + data.prepare_data() + data.setup() + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + return train_dataloader, val_dataloader + + def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) 
lengths = [len(d["input_ids"]) for d in data] @@ -347,14 +326,16 @@ def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [ - (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), - (train, ["epoch_size", "epochs"]), + (io, ["checkpoint_dir"]), + (train, ["epochs"]), (eval, ["max_new_tokens"]), ] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") + if not train.epochs and not train.max_steps: + issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) diff --git a/generate/adapter.py b/generate/adapter.py index 15e5df514b..650a067fc6 100644 --- a/generate/adapter.py +++ b/generate/adapter.py @@ -17,7 +17,8 @@ from lit_gpt import Tokenizer from lit_gpt.adapter import GPT, Config from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load -from scripts.prepare_alpaca import generate_prompt +from lit_gpt.data import apply_prompt_template +from lit_gpt.data.alpaca import prompt_template def main( @@ -72,7 +73,7 @@ def main( tokenizer = Tokenizer(checkpoint_dir) sample = {"instruction": prompt, "input": input} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index c799a0eacf..3fb19d3a03 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -17,7 +17,8 @@ from lit_gpt import Tokenizer from lit_gpt.adapter_v2 import GPT, Config from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load -from scripts.prepare_alpaca import generate_prompt +from lit_gpt.data import apply_prompt_template +from lit_gpt.data.alpaca import prompt_template def main( @@ -72,7 +73,7 @@ def main( tokenizer = Tokenizer(checkpoint_dir) sample = {"instruction": prompt, "input": input} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens diff --git a/generate/full.py b/generate/full.py index ca1554e489..2a85377591 100644 --- a/generate/full.py +++ b/generate/full.py @@ -16,7 +16,8 @@ from generate.base import generate from lit_gpt import GPT, Config, Tokenizer from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint -from scripts.prepare_alpaca import generate_prompt +from lit_gpt.data import apply_prompt_template +from lit_gpt.data.alpaca import prompt_template def main( @@ -71,7 +72,7 @@ def main( tokenizer = Tokenizer(checkpoint_dir) sample = {"instruction": prompt, "input": input} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens diff --git a/generate/lora.py b/generate/lora.py index 006b75baa1..2a1962d359 100644 --- a/generate/lora.py +++ b/generate/lora.py @@ -17,7 +17,8 @@ from lit_gpt import Tokenizer from 
lit_gpt.lora import GPT, Config, merge_lora_weights from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load -from scripts.prepare_alpaca import generate_prompt +from lit_gpt.data import apply_prompt_template +from lit_gpt.data.alpaca import prompt_template def main( @@ -91,8 +92,9 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" tokenizer = Tokenizer(checkpoint_dir) + sample = {"instruction": prompt, "input": input} - prompt = generate_prompt(sample) + prompt = apply_prompt_template(prompt_template, sample) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens diff --git a/lit_gpt/args.py b/lit_gpt/args.py index 6221707662..bc0c003e86 100644 --- a/lit_gpt/args.py +++ b/lit_gpt/args.py @@ -19,11 +19,11 @@ class TrainArgs: """Number of iterations with learning rate warmup active""" epochs: Optional[int] = None """Number of epochs to run""" - epoch_size: Optional[int] = None - """Size of the epoch""" # TODO: pretrain/tinyllama is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs? max_tokens: Optional[int] = None """Total number of tokens to train on""" + max_steps: Optional[int] = None + """Limits the number of optimizer steps to run.""" max_seq_length: Optional[int] = None """Limits the length of samples. Off by default""" @@ -35,12 +35,6 @@ class TrainArgs: max_norm: Optional[float] = None min_lr: float = 6e-5 - def max_iters(self, devices: int) -> int: - """Number of iterations""" - max_iters = self.epochs * self.epoch_size // devices // self.micro_batch_size - assert max_iters > 0 - return max_iters - def gradient_accumulation_iters(self, devices: int) -> int: """Number of iterations between gradient synchronizations""" gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size @@ -70,12 +64,7 @@ class EvalArgs: class IOArgs: """Inputs and outputs related arguments""" - # Optional because pretrain/tinyllama hardcodes the path - train_data_dir: Optional[Path] = Path("data/alpaca") - """Where to read training data from""" - val_data_dir: Optional[Path] = None - """Where to read validation data from""" checkpoint_dir: Optional[Path] = None """Where to read weights and tokenizer data from""" - out_dir: Path = Path("out/adapter/alpaca") + out_dir: Path = Path("out") """Where to save artifacts""" diff --git a/lit_gpt/packed_dataset.py b/lit_gpt/packed_dataset.py deleted file mode 100644 index 2b5b3d6d1a..0000000000 --- a/lit_gpt/packed_dataset.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
-
-# Very loosely inspired by indexed_dataset in Fairseq, Megatron
-# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
-
-
-import os
-import random
-import struct
-
-import numpy as np
-import torch
-from torch.utils.data import IterableDataset, get_worker_info
-
-dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}
-
-
-def code(dtype):
-    for k in dtypes:
-        if dtypes[k] == dtype:
-            return k
-    raise ValueError(dtype)
-
-
-HDR_MAGIC = b"LITPKDS"
-HDR_SIZE = 24  # bytes
-
-
-class PackedDataset(IterableDataset):
-    def __init__(
-        self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
-    ):
-        self._filenames = filenames
-        self._n_chunks = n_chunks
-        self._block_size = block_size
-        self._seed = seed
-        self._shuffle = shuffle
-        self._wrap = wrap
-        self._num_processes = num_processes
-        self._process_rank = process_rank
-
-    def __iter__(self):
-        worker_info = get_worker_info()
-        num_workers = worker_info.num_workers if worker_info is not None else 1
-        worker_id = worker_info.id if worker_info is not None else 0
-        num_shards = num_workers * self._num_processes
-        shard_id = self._process_rank * num_workers + worker_id
-
-        max_num_files = len(self._filenames) // num_shards * num_shards
-        filenames = self._filenames[shard_id:max_num_files:num_shards]
-
-        return PackedDatasetIterator(
-            filenames=filenames,
-            n_chunks=self._n_chunks,
-            block_size=self._block_size,
-            seed=self._seed,
-            shuffle=self._shuffle,
-            wrap=self._wrap,
-        )
-
-
-class PackedDatasetBuilder(object):
-    def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
-        if dtype == "auto":
-            if vocab_size is None:
-                raise ValueError("vocab_size cannot be None when dtype='auto'")
-            if vocab_size is not None and vocab_size < 65500:
-                self._dtype = np.uint16
-            else:
-                self._dtype = np.int32
-        else:
-            self._dtype = dtype
-        self._counter = 0
-        self._chunk_size = chunk_size
-        self._outdir = outdir
-        self._prefix = prefix
-        self._sep_token = sep_token
-        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
-        self._arr.fill(self._sep_token)
-        self._idx = 0
-        self._version = 1
-        self._filenames = []
-
-    def _write_chunk(self):
-        filename = f"{self._prefix}_{self._counter:010d}.bin"
-        filename = os.path.join(self._outdir, filename)
-
-        with open(filename, "wb") as f:
-            f.write(HDR_MAGIC)
-            f.write(struct.pack("<Q", self._version))
-            f.write(struct.pack("<B", code(self._dtype)))
-            f.write(struct.pack("<Q", self._chunk_size))
-            f.write(self._arr.tobytes(order="C"))
-
-        self._filenames.append(filename)
-        self._counter += 1
-        self._arr.fill(self._sep_token)
-        self._idx = 0
-
-    @property
-    def dtype(self):
-        return self._dtype
-
-    @property
-    def filenames(self):
-        return self._filenames.copy()
-
-    def add_array(self, arr):
-        while self._idx + arr.shape[0] > self._chunk_size:
-            part_len = self._chunk_size - self._idx
-            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
-            self._write_chunk()
-            arr = arr[part_len:]
-
-        arr_len = arr.shape[0]
-        self._arr[self._idx : self._idx + arr_len] = arr
-        self._idx += arr_len
-
-    def write_reminder(self):
-        self._write_chunk()
-
-
-class PackedDatasetIterator:
-    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
-        self._seed = seed
-        self._shuffle = shuffle
-        self._rng = np.random.default_rng(seed) if shuffle else None
-        self._block_idxs = None
-
-        self._wrap = wrap
-
-        # TODO: instead of filenames, we could have a single text stream
-        # (or text file) with the sequence of all files to be
-        # fetched/loaded.
-        self._filenames = filenames
-        self._file_idx = 0
-
-        self._n_chunks = n_chunks
-
-        self._dtype = None
-        self._block_size = block_size
-        self._n_blocks = None
-
-        self._mmaps = []
-        self._buffers = []
-
-        self._block_idxs = []
-        self._curr_idx = 0
-
-        self._load_n_chunks()
-
-    def _read_header(self, path):
-        with open(path, "rb") as f:
-            magic = f.read(len(HDR_MAGIC))
-            assert magic == HDR_MAGIC, "File doesn't match expected format."
-            version = struct.unpack("<Q", f.read(8))
-            assert version == (1,)
-            (dtype,) = struct.unpack("<B", f.read(1))
-            dtype = dtypes[dtype]
-            (chunk_size,) = struct.unpack("<Q", f.read(8))
-        return dtype, chunk_size
-
-    def _close_mmaps(self):
-        for mmap in self._mmaps:
-            mmap._mmap.close()
-
-    def _load_n_chunks(self):
-        self._close_mmaps()
-        self._mmaps = []
-        self._buffers = []
-
-        if self._n_chunks > len(self._filenames[self._file_idx :]):
-            if not self._wrap:
-                raise StopIteration
-            self._file_idx = 0
-
-        for i in range(self._n_chunks):
-            filename = self._filenames[self._file_idx + i]
-            if self._dtype is None:
-                self._dtype, self._chunk_size = self._read_header(filename)
-                self._n_blocks = self._chunk_size // self._block_size
-            # TODO: check header matches with previous files
-            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
-            self._mmaps.append(mmap)
-            self._buffers.append(memoryview(mmap))
-
-        self._file_idx += self._n_chunks
-        n_all_blocks = self._n_chunks * self._n_blocks
-
-        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
-
-        self._curr_idx = 0
-
-    def __del__(self):
-        self._close_mmaps()
-        del self._mmaps
-        del self._buffers
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self._curr_idx >= len(self._block_idxs):
-            self._load_n_chunks()
-            # TODO: trigger fetching next next n_chunks if remote
-        block_idx = self._block_idxs[self._curr_idx]
-        chunk_id = block_idx // self._n_blocks
-        buffer = self._buffers[chunk_id]
-        elem_id = (block_idx % self._n_blocks) * self._block_size
-        offset = np.dtype(self._dtype).itemsize * elem_id
-        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
-        self._curr_idx += 1
-        return torch.from_numpy(arr.astype(np.int64))
-
-
-class CombinedDataset(IterableDataset):
-    def __init__(self, datasets, seed, weights=None):
-        self._seed = seed
-        self._datasets = datasets
-        self._weights = weights
-        n_datasets = len(datasets)
-        if weights is None:
-            self._weights = [1 / n_datasets] * n_datasets
-        else:
-            self._weights = [w / sum(weights) for w in weights]
-
-    def __iter__(self):
-        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
-
-
-class CombinedDatasetIterator:
-    def __init__(self, datasets, seed, weights):
-        self._datasets = [iter(el) for el in datasets]
-        self._weights = weights
-        self._rng = random.Random(seed)
-
-    def __next__(self):
-        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
-        return next(dataset)
diff --git a/pretrain/openwebtext.py b/pretrain/openwebtext.py
index 653b21fa66..734557895b 100644
--- a/pretrain/openwebtext.py
+++ b/pretrain/openwebtext.py
@@ -30,7 +30,7 @@ def setup(
     resume: Union[bool, Path] = False,
     seed: int = 1337,
     devices: int = 1,
-    io: IOArgs = IOArgs(train_data_dir=Path("data/openwebtext"), val_data_dir=None, out_dir=Path("out/openwebtext")),
+    io: IOArgs = IOArgs(out_dir=Path("out/openwebtext")),
     train: TrainArgs = TrainArgs(
         save_interval=1000,
         log_interval=1,
@@ -38,7 +38,6 @@ def setup(
         micro_batch_size=5,
         lr_warmup_steps=100,
         epochs=1,
-        epoch_size=600000,
         learning_rate=6e-4,
         weight_decay=1e-1,
         beta1=0.9,
diff --git a/pretrain/redpajama.py b/pretrain/redpajama.py
deleted file mode 100644
index 31e543190e..0000000000
--- a/pretrain/redpajama.py
+++ /dev/null
@@ -1,360 +0,0 @@
-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
- -import math -import sys -import time -from pathlib import Path -from typing import Optional, Tuple, Union - -import lightning as L -import torch -from lightning.fabric.loggers import CSVLogger -from lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.utilities import ThroughputMonitor, measure_flops -from torch.utils.data import DataLoader - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from lit_gpt import Config -from lit_gpt.args import EvalArgs, IOArgs, TrainArgs -from lit_gpt.model import GPT, Block -from lit_gpt.packed_dataset import CombinedDataset, PackedDataset -from lit_gpt.utils import CLI, chunked_cross_entropy, estimate_flops, get_default_supported_precision, num_parameters - -# Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1 -data_config = [ - ("arxiv", 2.5), - ("book", 4.5), - ("c4", 15.0), - ("cc", 67.0), - ("github", 4.5), - ("stackexchange", 2.0), - ("wikipedia", 4.5), -] - - -def setup( - model_name: str = "Llama-2-7b-hf", - val_data_dir: Optional[Path] = None, - precision: Optional[str] = None, - resume: Union[bool, Path] = False, - seed: int = 1337, - devices: int = 4, - io: IOArgs = IOArgs(train_data_dir=Path("data/redpajama_sample"), val_data_dir=None, out_dir=Path("out/redpajama")), - train: TrainArgs = TrainArgs( - save_interval=1000, - log_interval=1, - global_batch_size=125, - micro_batch_size=6, - lr_warmup_steps=100, - epochs=1, - epoch_size=600000, - learning_rate=6e-4, - weight_decay=1e-1, - beta1=0.9, - beta2=0.95, - max_norm=1.0, - min_lr=6e-5, - ), - eval: EvalArgs = EvalArgs(interval=1000, max_iters=100), -) -> None: - print(locals()) - precision = precision or get_default_supported_precision(training=True) - - if devices > 1: - strategy = FSDPStrategy( - auto_wrap_policy={Block}, - activation_checkpointing_policy={Block}, - state_dict_type="full", - limit_all_gathers=True, - cpu_offload=False, - ) - else: - strategy = "auto" - - logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval) - fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) - - fabric.launch(main, devices, resume, seed, Config.from_name(name=model_name), io, train, eval) - - -def main( - fabric: L.Fabric, - devices: int, - resume: Union[bool, Path], - seed: int, - config: Config, - io: IOArgs, - train: TrainArgs, - eval: EvalArgs, -) -> None: - validate_args(io, train, eval) - - if fabric.global_rank == 0: - io.out_dir.mkdir(parents=True, exist_ok=True) - - train_dataloader, val_dataloader = create_dataloaders( - batch_size=train.micro_batch_size, - block_size=config.block_size, - fabric=fabric, - train_data_dir=io.train_data_dir, - val_data_dir=io.val_data_dir, - seed=(seed + fabric.global_rank), - ) - if val_dataloader is None: - train_dataloader = fabric.setup_dataloaders(train_dataloader) - else: - train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) - - fabric.seed_everything(seed) # same seed for every process to init model (FSDP) - - fabric.print(f"Loading model with {config.__dict__}") - t0 = time.perf_counter() - with fabric.init_module(empty_init=(fabric.world_size > 1)): - model = GPT(config) - model.apply(model._init_weights) - - fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - fabric.print(f"Total parameters {num_parameters(model):,}") - - model = fabric.setup(model) - optimizer = torch.optim.AdamW( - 
model.parameters(), - lr=train.learning_rate, - weight_decay=train.weight_decay, - betas=(train.beta1, train.beta2), - foreach=False, - ) - optimizer = fabric.setup_optimizers(optimizer) - - state = {"model": model, "optimizer": optimizer, "iter_num": 0, "step_count": 0} - - if resume is True: - resume = max(io.out_dir.glob("*.pth"), key=lambda p: int(p.name.split("-")[1])) - if resume: - fabric.print(f"Resuming training from {resume}") - fabric.load(resume, state) - - train_time = time.perf_counter() - fit(fabric, devices, state, train_dataloader, val_dataloader, io, train, eval) - fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") - if fabric.device.type == "cuda": - fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -def fit( - fabric: L.Fabric, - devices: int, - state: dict, - train_dataloader: DataLoader, - val_dataloader: Optional[DataLoader], - io: IOArgs, - train: TrainArgs, - eval: EvalArgs, -) -> None: - model = state["model"] - optimizer = state["optimizer"] - - if val_dataloader is not None: - validate(fabric, model, val_dataloader, max_iters=2) # sanity check - - with torch.device("meta"): - meta_model = GPT(model.config) - # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. - # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, - # consider passing `flops_per_batch=estimated_flops` instead - estimated_flops = estimate_flops(meta_model, training=True) * train.micro_batch_size - fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") - x = torch.randint(0, 1, (train.micro_batch_size, model.max_seq_length)) - forward_fn = lambda: meta_model(x) - loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0) - measured_flops = measure_flops(meta_model, forward_fn, loss_fn) - fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") - del meta_model, x - - throughput = ThroughputMonitor(fabric, window_size=50) - total_t0 = time.perf_counter() - - lr_warmup_iters = train.lr_warmup_steps * train.gradient_accumulation_iters(devices) - for state["iter_num"], train_data in enumerate(train_dataloader, state["iter_num"]): - if state["iter_num"] >= train.max_iters(devices): - break - - # determine and set the learning rate for this iteration - lr = get_lr( - train.learning_rate, state["iter_num"], lr_warmup_iters, train.max_iters(devices), min_lr=train.min_lr - ) - for param_group in optimizer.param_groups: - param_group["lr"] = lr - - iter_num = state["iter_num"] + 1 - iter_t0 = time.perf_counter() - - input_ids = train_data[:, 0 : model.max_seq_length].contiguous() - targets = train_data[:, 1 : model.max_seq_length + 1].contiguous() - - is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0 - with fabric.no_backward_sync(model, enabled=is_accumulating): - logits = model(input_ids) - loss = chunked_cross_entropy(logits, targets, chunk_size=0) - fabric.backward(loss / train.gradient_accumulation_iters(devices)) - - if not is_accumulating: - fabric.clip_gradients(model, optimizer, max_norm=train.max_norm) - optimizer.step() - optimizer.zero_grad() - state["step_count"] += 1 - - if iter_num % train.log_interval == 0: - loss_item = loss.item() # expensive device-to-host synchronization - t1 = time.perf_counter() - throughput.update( - time=t1 - total_t0, - batches=iter_num, - samples=iter_num * train.micro_batch_size, - lengths=iter_num * train.micro_batch_size * model.max_seq_length, - 
flops=measured_flops * train.log_interval, - ) - throughput.compute_and_log(step=iter_num) - fabric.print( - f"iter {iter_num} step {state['step_count']}: loss {loss_item:.4f}, iter time:" - f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" - ) - - if val_dataloader is not None and not is_accumulating and state["step_count"] % eval.interval == 0: - t0 = time.perf_counter() - val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters) - t1 = time.perf_counter() - t0 - fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms") - fabric.barrier() - if not is_accumulating and state["step_count"] % train.save_interval == 0: - checkpoint_path = io.out_dir / f"iter-{iter_num:06d}-ckpt.pth" - fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") - fabric.save(checkpoint_path, state) - - -# FSDP has issues with `inference_mode` -@torch.no_grad() -def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor: - fabric.print("Validating ...") - model.eval() - - losses = torch.zeros(max_iters, device=fabric.device) - for k, val_data in enumerate(val_dataloader): - if k >= max_iters: - break - input_ids = val_data[:, 0 : model.max_seq_length].contiguous() - targets = val_data[:, 1 : model.max_seq_length + 1].contiguous() - logits = model(input_ids) - losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0) - out = losses.mean() - - model.train() - return out - - -def create_dataloader( - batch_size: int, block_size: int, data_dir: Path, fabric: L.Fabric, shuffle: bool = True, seed: int = 12345 -) -> DataLoader: - datasets = [] - for prefix, _ in data_config: - filenames = list(data_dir.glob(f"{prefix}*")) - if not filenames: - raise FileNotFoundError( - f"No files found at {str(data_dir)} with prefix {prefix}. Did you forget to run `prepare_redpajama.py`?" - ) - dataset = PackedDataset( - filenames, - n_chunks=4, - block_size=block_size, - shuffle=shuffle, - seed=seed, - num_processes=fabric.world_size, - process_rank=fabric.global_rank, - ) - datasets.append(dataset) - - if not datasets: - raise RuntimeError( - f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
- ) - - weights = [weight for _, weight in data_config] - sum_weights = sum(weights) - weights = [el / sum_weights for el in weights] - - combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) - - return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) - - -def create_dataloaders( - batch_size: int, - block_size: int, - fabric: L.Fabric, - train_data_dir: Path = Path("data/redpajama_sample"), - val_data_dir: Optional[Path] = None, - seed: int = 12345, -) -> Tuple[DataLoader, Optional[DataLoader]]: - # Increase by one because we need the next word as well - effective_block_size = block_size + 1 - train_dataloader = create_dataloader( - batch_size=batch_size, - block_size=effective_block_size, - fabric=fabric, - data_dir=train_data_dir, - shuffle=True, - seed=seed, - ) - val_dataloader = ( - create_dataloader( - batch_size=batch_size, - block_size=effective_block_size, - fabric=fabric, - data_dir=val_data_dir, - shuffle=False, - seed=seed, - ) - if val_data_dir - else None - ) - return train_dataloader, val_dataloader - - -# learning rate decay scheduler (cosine with linear warmup) -def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float: - # 1) linear warmup for warmup_iters steps - if it < warmup_iters: - return learning_rate * it / warmup_iters - # 2) if it > max_iters, return min learning rate - if it > max_iters: - return min_lr - # 3) in between, use cosine decay down to min learning rate - decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters) - assert 0 <= decay_ratio <= 1 - coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 - return min_lr + coeff * (learning_rate - min_lr) - - -def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: - issues = [] - unsupported = [(io, ["checkpoint_dir"]), (train, ["max_tokens"]), (eval, ["max_new_tokens"])] - for args, names in unsupported: - for name in names: - if getattr(args, name) is not None: - issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") - required = [(io, ["train_data_dir"]), (train, ["epoch_size", "epochs", "max_norm"])] - for args, names in required: - for name in names: - if getattr(args, name) is None: - issues.append(f"{__file__} requires the {name!r} argument. 
This is set in {args}") - if issues: - raise ValueError("\n".join(issues)) - - -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - - CLI(setup) diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py index 7da9dae918..45be182051 100644 --- a/pretrain/tinyllama.py +++ b/pretrain/tinyllama.py @@ -43,7 +43,7 @@ def setup( seed: int = 1337, data: LitDataModule = TinyLlama(), io: IOArgs = IOArgs( - out_dir=Path(os.getenv("LIGHTNING_ARTIFACTS_DIR", "out")) / "lit-tiny-llama-1.1b", train_data_dir=None + out_dir=Path(os.getenv("LIGHTNING_ARTIFACTS_DIR", "out")) / "lit-tiny-llama-1.1b", ), train: TrainArgs = TrainArgs( save_interval=1000, @@ -328,8 +328,8 @@ def choose_logger(out_dir: Path, logger_name: str, name: str, resume: Union[bool def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: issues = [] unsupported = [ - (io, ["train_data_dir", "val_data_dir", "checkpoint_dir"]), - (train, ["epoch_size", "epochs"]), + (io, ["checkpoint_dir"]), + (train, ["max_steps", "epochs"]), (eval, ["max_new_tokens"]), ] for args, names in unsupported: diff --git a/scripts/prepare_alpaca.py b/scripts/prepare_alpaca.py deleted file mode 100644 index 61ca7bf3b5..0000000000 --- a/scripts/prepare_alpaca.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - -"""Implementation derived from https://github.com/tloen/alpaca-lora""" - -import json -import sys -from pathlib import Path -from typing import Optional - -import torch -from lightning_utilities.core.imports import RequirementCache -from torch.utils.data import random_split -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from lit_gpt.tokenizer import Tokenizer -from lit_gpt.utils import CLI - - -def prepare( - destination_path: Path = Path("data/alpaca"), - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - test_split_fraction: float = 0.03865, # to get exactly 2000 test samples, - seed: int = 42, - mask_inputs: bool = False, # as in alpaca-lora - data_file_name: str = "alpaca_data_cleaned_archive.json", - data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", - ignore_index: int = -1, - max_seq_length: Optional[int] = None, -) -> None: - """Prepare the Alpaca dataset for instruction tuning. - - The output is a training and test dataset saved as `train.pt` and `test.pt`, - which stores the preprocessed and tokenized prompts and labels. 
- """ - if max_seq_length is None: - with open(checkpoint_dir / "lit_config.json", "r", encoding="utf-8") as file: - config = json.load(file) - max_seq_length = config["block_size"] - - destination_path.mkdir(parents=True, exist_ok=True) - data_file_path = destination_path / data_file_name - print("Loading data file...") - download_if_missing(data_file_path, data_file_url) - with open(data_file_path, "r", encoding="utf-8") as file: - data = json.load(file) - - print("Loading tokenizer...") - tokenizer = Tokenizer(checkpoint_dir) - - # Partition the dataset into train and test - train_set, test_set = random_split( - data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed) - ) - train_set, test_set = list(train_set), list(test_set) - - print(f"train has {len(train_set):,} samples") - print(f"test has {len(test_set):,} samples") - - print("Processing train split ...") - train_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(train_set) - ] - torch.save(train_set, destination_path / "train.pt") - - print("Processing test split ...") - test_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(test_set) - ] - torch.save(test_set, destination_path / "test.pt") - - -def download_if_missing(file_path: Path, file_url: str) -> None: - """Downloads the raw json data file and saves it in the given destination.""" - if file_path.exists() and file_path.stat().st_size > 0: - return - requests_available = RequirementCache("requests") - if not requests_available: - raise ModuleNotFoundError(str(requests_available)) - import requests - - with open(file_path, "w", encoding="utf-8") as f: - f.write(requests.get(file_url).text) - - -def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: - """Processes a single sample. - - Each sample in the dataset consists of: - - instruction: A string describing the task - - input: A string holding a special input value for the instruction. - This only applies to some samples, and in others this is empty. - - output: The response string - - This function processes this data to produce a prompt text and a label for - supervised training. The prompt text is formed as a single message including both - the instruction and the input. The label/target is the same message but with the - response attached. - - Finally, both the prompt and the label get tokenized. If desired, all tokens - in the label that correspond to the original input prompt get masked out (default). 
- """ - full_prompt = generate_prompt(example) - full_prompt_and_response = full_prompt + example["output"] - encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) - encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - - return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} - - -def generate_prompt(example: dict) -> str: - """Generates a standardized message to prompt the model with an instruction, optional input and a - 'response' field.""" - - if example["input"]: - return ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" - ) - return ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Response:" - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/scripts/prepare_csv.py b/scripts/prepare_csv.py deleted file mode 100644 index 89dd43f911..0000000000 --- a/scripts/prepare_csv.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - -import json -import logging -import sys -from pathlib import Path -from typing import Optional, Tuple - -import torch -from torch.utils.data import random_split -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -logger = logging.getLogger(__name__) -sys.path.append(str(wd)) - -from lit_gpt.tokenizer import Tokenizer -from lit_gpt.utils import CLI - - -def prepare( - csv_path: Path, - destination_path: Path = Path("data/csv"), - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - test_split_fraction: float = 0.1, - seed: int = 42, - mask_inputs: bool = False, - ignore_index: int = -1, - max_seq_length: Optional[int] = None, - columns: Tuple[str, ...] = ("instruction", "input", "output"), -) -> None: - """Prepare a CSV dataset for instruction tuning. - - The output is a training and test dataset saved as `train.pt` and `test.pt`, - which stores the preprocessed and tokenized prompts and labels. 
- """ - if max_seq_length is None: - with open(checkpoint_dir / "lit_config.json", "r") as file: - config = json.load(file) - max_seq_length = config["block_size"] - - destination_path.mkdir(parents=True, exist_ok=True) - logger.info("Loading data file ...") - import pandas as pd - - df = pd.read_csv(csv_path, dtype=str).fillna("") - if not (df.columns.values == columns).all(): - raise ValueError(f"CSV columns must be {columns}, found {df.columns.values}") - data = json.loads(df.to_json(orient="records", indent=4)) - - print("Loading tokenizer...") - tokenizer = Tokenizer(checkpoint_dir) - - # Partition the dataset into train and test - train_set, test_set = random_split( - data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed) - ) - train_set, test_set = list(train_set), list(test_set) - - print(f"train has {len(train_set):,} samples") - print(f"test has {len(test_set):,} samples") - - print("Processing train split ...") - train_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(train_set) - ] - torch.save(train_set, destination_path / "train.pt") - - print("Processing test split ...") - test_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(test_set) - ] - torch.save(test_set, destination_path / "test.pt") - - -def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: - """Processes a single sample. - - Each sample in the dataset consists of: - - instruction: A string describing the task - - input: A string holding a special input value for the instruction. - This only applies to some samples, and in others this is empty. - - output: The response string - - This function processes this data to produce a prompt text and a label for - supervised training. The prompt text is formed as a single message including both - the instruction and the input. The label/target is the same message but with the - response attached. - - Finally, both the prompt and the label get tokenized. If desired, all tokens - in the label that correspond to the original input prompt get masked out (default). - """ - full_prompt = generate_prompt(example) - full_prompt_and_response = full_prompt + example["output"] - encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) - encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - - return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} - - -def generate_prompt(example: dict) -> str: - """Generates a standardized message to prompt the model with an instruction, optional input and a - 'response' field.""" - - if example["input"]: - return ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" - ) - return ( - "Below is an instruction that describes a task. 
" - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Response:" - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/scripts/prepare_dolly.py b/scripts/prepare_dolly.py deleted file mode 100644 index 56da37ce5a..0000000000 --- a/scripts/prepare_dolly.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - -"""Implementation derived from https://github.com/tloen/alpaca-lora""" - -import json -import sys -from pathlib import Path -from typing import Optional - -import torch -from torch.utils.data import random_split -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from lit_gpt.tokenizer import Tokenizer -from lit_gpt.utils import CLI -from scripts.prepare_alpaca import download_if_missing - - -def prepare( - destination_path: Path = Path("data/dolly"), - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - test_split_fraction: float = 0.1, - seed: int = 42, - mask_inputs: bool = False, - data_file_name: str = "dolly_data_cleaned.json", - data_file_url: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl", - ignore_index: int = -1, - max_seq_length: Optional[int] = None, -) -> None: - """Prepare the Dolly 15k dataset for instruction tuning. - - The output is a training and test dataset saved as `train.pt` and `test.pt`, - which stores the preprocessed and tokenized prompts and labels. - """ - - if max_seq_length is None: - with open(checkpoint_dir / "lit_config.json", "r", encoding="utf-8") as file: - config = json.load(file) - max_seq_length = config["block_size"] - - destination_path.mkdir(parents=True, exist_ok=True) - data_file_path = destination_path / data_file_name - print("Loading data file...") - download_if_missing(data_file_path, data_file_url) - - with open(data_file_path, "r", encoding="utf-8") as file: - data = file.readlines() - data = [json.loads(line) for line in data] - for item in data: - item["input"] = item.pop("context") - item["output"] = item.pop("response") - - print("Loading tokenizer...") - tokenizer = Tokenizer(checkpoint_dir) - - # Partition the dataset into train and test - train_set, test_set = random_split( - data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed) - ) - train_set, test_set = list(train_set), list(test_set) - - print(f"train has {len(train_set):,} samples") - print(f"test has {len(test_set):,} samples") - - print("Processing train split ...") - train_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(train_set) - ] - torch.save(train_set, destination_path / "train.pt") - - print("Processing test split ...") - test_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(test_set) - ] - torch.save(test_set, destination_path / "test.pt") - - -def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: - """Processes a single sample. 
- - Each sample in the dataset consists of: - - instruction: A string describing the task - - input: A string holding a special input value for the instruction. - This only applies to some samples, and in others this is empty. - - output: The response string - - This function processes this data to produce a prompt text and a label for - supervised training. The prompt text is formed as a single message including both - the instruction and the input. The label/target is the same message but with the - response attached. - - Finally, both the prompt and the label get tokenized. If desired, all tokens - in the label that correspond to the original input prompt get masked out (default). - """ - full_prompt = generate_prompt(example) - full_prompt_and_response = full_prompt + example["output"] - encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) - encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - - return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} - - -def generate_prompt(example: dict) -> str: - """Generates a standardized message to prompt the model with an instruction, optional input and a - 'response' field.""" - - if example["input"]: - return ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" - ) - return ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Response:" - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/scripts/prepare_flan.py b/scripts/prepare_flan.py deleted file mode 100644 index 59d3a7fae0..0000000000 --- a/scripts/prepare_flan.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - -"""Implementation derived from https://github.com/tloen/alpaca-lora""" -import json -import sys -from pathlib import Path -from typing import Optional - -import torch -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from lit_gpt.tokenizer import Tokenizer -from lit_gpt.utils import CLI -from scripts.prepare_alpaca import download_if_missing - - -def load_jsonl(filename): - data = [] - with open(filename, "r", encoding="utf-8") as f: - for line in f: - data.append(json.loads(line)) - return data - - -def prepare( - destination_path: Path = Path("data/flan"), - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - mask_inputs: bool = False, # as in alpaca-lora - subsets: Optional[str] = None, - ignore_index: int = -1, - max_seq_length: Optional[int] = None, -) -> None: - """Prepare the FLAN-collection datasets for instruction tuning. - - The output is a training and test dataset saved as `train.pt` and `test.pt`, - which stores the preprocessed and tokenized prompts and labels. - - Since the original test set does not have responses, the validation set - is used as the test set. 
- """ - - supported_subsets = { - "aeslc_10templates", - "ag_news_subset_10templates", - "anli_r1_10templates", - "anli_r2_10templates", - "anli_r3_10templates", - "arc_challenge_10templates", - "arc_easy_10templates", - "bool_q_10templates", - "cb_10templates", - "cnn_dailymail_10templates", - "cola_10templates", - "common_gen_10templates", - "copa_10templates", - "coqa_10templates", - "cosmos_qa_10templates", - "dart_10templates", - "definite_pronoun_resolution_10templates", - "drop_10templates", - "e2e_nlg_10templates", - "fix_punct_10templates", - "gigaword_10templates", - "glue_mrpc_10templates", - "glue_qqp_10templates", - "hellaswag_10templates", - "imdb_reviews_10templates", - "math_dataset_10templates", - "mnli_matched_10templates", - "mnli_mismatched_10templates", - "multi_news_10templates", - "multirc_10templates", - "natural_questions_10templates", - "openbookqa_10templates", - "opinion_abstracts_idebate_10templates", - "opinion_abstracts_rotten_tomatoes_10templates", - "para_crawl_enes_10templates", - "paws_wiki_10templates", - "piqa_10templates", - "qnli_10templates", - "quac_10templates", - "record_10templates", - "rte_10templates", - "samsum_10templates", - "sentiment140_10templates", - "snli_10templates", - "squad_v1_10templates", - "squad_v2_10templates", - "sst2_10templates", - "story_cloze_10templates", - "stsb_10templates", - "trec_10templates", - "trivia_qa_10templates", - "true_case_10templates", - "web_nlg_en_10templates", - "wic_10templates", - "wiki_lingua_english_en_10templates", - "wmt14_enfr_10templates", - "wmt16_translate_csen_10templates", - "wmt16_translate_deen_10templates", - "wmt16_translate_fien_10templates", - "wmt16_translate_roen_10templates", - "wmt16_translate_ruen_10templates", - "wmt16_translate_tren_10templates", - "wnli_10templates", - "word_segment_10templates", - "wsc_10templates", - "yelp_polarity_reviews_10templates", - } - - if subsets is not None: - subsets = subsets.split(",") - for sub in subsets: - if sub not in supported_subsets: - raise ValueError(f"{sub} not in {supported_subsets}") - else: - subsets = list(supported_subsets) - - if max_seq_length is None: - with open(checkpoint_dir / "lit_config.json", "r", encoding="utf-8") as file: - config = json.load(file) - max_seq_length = config["block_size"] - - destination_path.mkdir(parents=True, exist_ok=True) - print("Loading data file...") - - base_url = "https://huggingface.co/datasets/Muennighoff/flan/resolve/main/" - - train_set, test_set = [], [] - for sub in subsets: - train_sub = sub + "_train" - data_file_name = train_sub + ".jsonl" - data_file_path = destination_path / data_file_name - data_file_url = base_url + "train/" + data_file_name - - print(f"Loading training data file {sub}...") - download_if_missing(data_file_path, data_file_url) - sub_train_set = load_jsonl(data_file_path) - train_set.extend(sub_train_set) - - test_sub = sub + "_test" - data_file_name = test_sub + ".jsonl" - data_file_path = destination_path / data_file_name - data_file_url = base_url + "test/" + data_file_name - - print(f"Loading test data file {sub}...") - download_if_missing(data_file_path, data_file_url) - sub_test_set = load_jsonl(data_file_path) - test_set.extend(sub_test_set) - - print("Loading tokenizer...") - tokenizer = Tokenizer(checkpoint_dir) - - train_set, test_set = list(train_set), list(test_set) - - print(f"train has {len(train_set):,} samples") - print(f"test has {len(test_set):,} samples") - - print("Processing train split ...") - train_set = [ - prepare_sample( - example=sample, - 
tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(train_set) - ] - torch.save(train_set, destination_path / "train.pt") - - print("Processing test split ...") - test_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(test_set) - ] - torch.save(test_set, destination_path / "test.pt") - - -def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int): - """Processes a single sample. - - Each sample in the dataset consists of: - - instruction: A string describing the task - - input: A string holding a special input value for the instruction. - This only applies to some samples, and in others this is empty. - - output: The response string - - This function processes this data to produce a prompt text and a label for - supervised training. The prompt text is formed as a single message including both - the instruction and the input. The label/target is the same message but with the - response attached. - - Finally, both the prompt and the label get tokenized. If desired, all tokens - in the label that correspond to the original input prompt get masked out (default). - """ - full_prompt = generate_prompt(example) - full_prompt_and_response = full_prompt + example["targets"] - encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) - encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - - return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} - - -def generate_prompt(example): - """Generates a standardized message to prompt the model with an instruction, optional input and a - 'response' field.""" - - return ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['inputs']}\n\n### Response:" - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/scripts/prepare_lima.py b/scripts/prepare_lima.py deleted file mode 100644 index ca35e62be6..0000000000 --- a/scripts/prepare_lima.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
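All of the deleted prepare_*.py scripts share one prepare_sample / generate_prompt pattern: render an Alpaca-style prompt, tokenize the prompt alone and the prompt plus response, and optionally mask the prompt tokens out of the labels so the loss covers only the response. A condensed, self-contained sketch of that pattern (MiniTokenizer is a stand-in for the repository's Tokenizer and is purely illustrative):

    import torch

    class MiniTokenizer:
        """Stand-in for the real tokenizer, purely illustrative."""
        def encode(self, text, eos=False, max_length=None):
            ids = [ord(c) % 100 + 1 for c in text]  # fake token ids, never 0
            if eos:
                ids.append(0)                       # 0 plays the role of <eos>
            return torch.tensor(ids if max_length is None else ids[:max_length])

    def generate_prompt(example):
        # single-turn Alpaca template (no-input variant), as in the deleted scripts
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{example['instruction']}\n\n### Response:"
        )

    def prepare_sample(example, tokenizer, max_length, mask_inputs, ignore_index=-1):
        full_prompt = generate_prompt(example)
        full_prompt_and_response = full_prompt + example["output"]
        encoded_prompt = tokenizer.encode(full_prompt, max_length=max_length)
        encoded_all = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)
        labels = encoded_all.clone()
        if mask_inputs:
            labels[: len(encoded_prompt)] = ignore_index  # loss is computed on the response only
        return {**example, "input_ids": encoded_all, "labels": labels}

    sample = {"instruction": "Name a colour.", "input": "", "output": "Blue."}
    out = prepare_sample(sample, MiniTokenizer(), max_length=256, mask_inputs=True)
    print(out["input_ids"].shape, int((out["labels"] == -1).sum()))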
- -"""Implementation derived from https://github.com/tloen/alpaca-lora""" - -import json -import os -import sys -from pathlib import Path -from typing import List, Optional - -import torch -from torch.utils.data import random_split -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from lit_gpt.tokenizer import Tokenizer -from lit_gpt.utils import CLI - - -def prepare( - destination_path: Path = Path("data/lima"), - test_split_fraction: float = 0.1, - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - mask_inputs: bool = False, # as in alpaca-lora - seed: int = 42, - include_multiturn_conversations: bool = False, - data_repo_id: str = "GAIR/lima", - ignore_index: int = -1, - access_token: Optional[str] = os.getenv("HF_TOKEN"), - max_seq_length: Optional[int] = None, -) -> None: - """Prepare the LIMA dataset for instruction tuning. - - The output is a training and test dataset saved as `train.pt` and `test.pt`, - which stores the preprocessed and tokenized prompts and labels. - """ - - if access_token is None: - raise ValueError( - "LIMA requires authentication, please set the `HF_TOKEN=your_token` environment" - " variable or pass --access_token=your_token. You can find your token by visiting" - " https://huggingface.co/settings/tokens" - ) - - if max_seq_length is None: - with open(checkpoint_dir / "lit_config.json", "r", encoding="utf-8") as file: - config = json.load(file) - max_seq_length = config["block_size"] - - destination_path.mkdir(parents=True, exist_ok=True) - print("Loading data file...") - - from datasets import load_dataset - - dataset = load_dataset(data_repo_id, token=access_token) - train_data = format_dataset(dataset["train"], include_multiturn_conversations) - - # test set is present but doesn't have any solutions, so we cannot use it here - # but have to create our own - # for consistency with prepare_alpaca.py and prepare_dolly.py - # test_set = format_dataset(dataset["test"], include_multiturn_conversations) - - print("Loading tokenizer...") - tokenizer = Tokenizer(checkpoint_dir) - - # Partition the dataset into train and test - train_set, test_set = random_split( - train_data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed) - ) - train_set, test_set = list(train_set), list(test_set) - - print(f"train has {len(train_set):,} samples") - print(f"test has {len(test_set):,} samples") - - print("Processing train split ...") - train_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(train_set) - ] - torch.save(train_set, destination_path / "train.pt") - - print("Processing test split ...") - test_set = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(test_set) - ] - torch.save(test_set, destination_path / "test.pt") - - -def format_dataset(dataset_partition: dict, include_multi_turn_conversations: bool) -> List[dict]: - formatted_ds = [] - - for entry in dataset_partition: - convo = entry["conversations"] - if include_multi_turn_conversations: - for i in range(0, len(convo) - 1, 2): - formatted_ds.append({"instruction": convo[i], "input": "", "output": convo[i + 1]}) - - else: - formatted_ds.append({"instruction": convo[0], "input": "", "output": 
convo[1]}) - - return formatted_ds - - -def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: - """Processes a single sample. - - Each sample in the dataset consists of: - - instruction: A string describing the task - - input: A string holding a special input value for the instruction. - This only applies to some samples, and in others this is empty. - - output: The response string - - This function processes this data to produce a prompt text and a label for - supervised training. The prompt text is formed as a single message including both - the instruction and the input. The label/target is the same message but with the - response attached. - - Finally, both the prompt and the label get tokenized. If desired, all tokens - in the label that correspond to the original input prompt get masked out (default). - """ - full_prompt = generate_prompt(example) - full_prompt_and_response = full_prompt + example["output"] - encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) - encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - - return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} - - -def generate_prompt(example: dict) -> str: - """Generates a standardized message to prompt the model with an instruction, optional input and a - 'response' field.""" - - if example["input"]: - return ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" - ) - return ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['instruction']}\n\n### Response:" - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/scripts/prepare_longform.py b/scripts/prepare_longform.py deleted file mode 100644 index 2a46e7dd51..0000000000 --- a/scripts/prepare_longform.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - -"""Implementation derived from https://github.com/tloen/alpaca-lora""" - -import json -import sys -from pathlib import Path -from typing import Optional - -import torch -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from lit_gpt.tokenizer import Tokenizer -from lit_gpt.utils import CLI -from scripts.prepare_alpaca import download_if_missing - - -def prepare( - destination_path: Path = Path("data/longform"), - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - mask_inputs: bool = False, # as in alpaca-lora - ignore_index: int = -1, - max_seq_length: Optional[int] = None, -) -> None: - """Prepare the Alpaca dataset for instruction tuning. - - The output is a training and test dataset saved as `train.pt` and `test.pt`, - which stores the preprocessed and tokenized prompts and labels. 
- """ - if max_seq_length is None: - with open(checkpoint_dir / "lit_config.json", "r", encoding="utf-8") as file: - config = json.load(file) - max_seq_length = config["block_size"] - - destination_path.mkdir(parents=True, exist_ok=True) - - train_file_name = "train.json" - # val_file_name = "val.json" - test_file_name = "test.json" - - train_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/train.json" - # val_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/val.json" - test_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/test.json" - - train_file_path = destination_path / train_file_name - print("Loading train data file...") - download_if_missing(train_file_path, train_file_url) - with open(train_file_path, "r", encoding="utf-8") as file: - train_data = json.load(file) - - test_file_path = destination_path / test_file_name - print("Loading test data file...") - download_if_missing(test_file_path, test_file_url) - with open(test_file_path, "r", encoding="utf-8") as file: - test_data = json.load(file) - - print("Loading tokenizer...") - tokenizer = Tokenizer(checkpoint_dir) - - print(f"train has {len(train_data):,} samples") - print(f"test has {len(test_data):,} samples") - - print("Processing train set ...") - train_data = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(train_data) - ] - torch.save(train_data, destination_path / "train.pt") - - print("Processing test set ...") - test_data = [ - prepare_sample( - example=sample, - tokenizer=tokenizer, - max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, - ) - for sample in tqdm(test_data) - ] - torch.save(test_data, destination_path / "test.pt") - - -def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: - """Processes a single sample. - - Each sample in the dataset consists of: - - instruction: A string describing the task - - input: A string holding a special input value for the instruction. - This only applies to some samples, and in others this is empty. - - output: The response string - - This function processes this data to produce a prompt text and a label for - supervised training. The prompt text is formed as a single message including both - the instruction and the input. The label/target is the same message but with the - response attached. - - Finally, both the prompt and the label get tokenized. If desired, all tokens - in the label that correspond to the original input prompt get masked out (default). 
- """ - full_prompt = generate_prompt(example) - full_prompt_and_response = full_prompt + example["output"] - encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) - encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - - return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} - - -def generate_prompt(example: dict) -> str: - """Generates a standardized message to prompt the model with an instruction and a - 'response' field.""" - - return ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - f"### Instruction:\n{example['input']}\n\n### Response:" - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/scripts/prepare_redpajama.py b/scripts/prepare_redpajama.py deleted file mode 100644 index f2c87a335c..0000000000 --- a/scripts/prepare_redpajama.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - -import glob -import json -import os -import sys -from pathlib import Path - -import numpy as np -from tqdm import tqdm - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -import lit_gpt.packed_dataset as packed_dataset -from lit_gpt import Config, Tokenizer -from lit_gpt.utils import CLI - -filenames_sample = [ - "arxiv_sample.jsonl", - "book_sample.jsonl", - "c4_sample.jsonl", - "cc_2019-30_sample.jsonl", - "cc_2020-05_sample.jsonl", - "cc_2021-04_sample.jsonl", - "cc_2022-05_sample.jsonl", - "cc_2023-06_sample.jsonl", - "github_sample.jsonl", - "stackexchange_sample.jsonl", - "wikipedia_sample.jsonl", -] - -filename_sets = { - "arxiv": "arxiv/arxiv*", - "book": "book/book*", - "c4": "c4/c4-train*", - "common_crawl": "common_crawl/*", - "github": "github/filtered*", - "stackexchange": "stackexchange/stackexchange*", - "wikipedia": "wikipedia/wiki*", -} - - -def prepare_sample( - source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" -) -> None: - """Prepare the "Red Pajama" dataset using the original tokenizer.""" - destination_path.mkdir(parents=True, exist_ok=True) - - tokenizer = Tokenizer(checkpoint_dir) - - for name in filenames_sample: - if match and match not in name: - continue - - filepath = source_path / name - - if not filepath.is_file(): - raise RuntimeError( - f"Input file not found at {filepath}. \nMake sure you download the data, e.g. 
wget -i" - " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" - " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" - " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" - ) - - prefix, _ = os.path.splitext(name) - - builder = packed_dataset.PackedDatasetBuilder( - outdir=destination_path, - prefix=prefix, - chunk_size=chunk_size, - sep_token=tokenizer.eos_id, - dtype="auto", - vocab_size=tokenizer.vocab_size, - ) - - print(f"Processing {name}") - - with open(filepath, encoding="utf-8") as f: - for row in tqdm(f): - text = json.loads(row)["text"] - text_ids = tokenizer.encode(text) - builder.add_array(np.array(text_ids, dtype=builder.dtype)) - - builder.write_reminder() - - -def prepare_full( - source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" -) -> None: - """Prepare the "Red Pajama" dataset using the original tokenizer.""" - import zstandard as zstd - - destination_path.mkdir(parents=True, exist_ok=True) - - tokenizer = Tokenizer(checkpoint_dir) - - for set_name, pattern in filename_sets.items(): - if match and match not in set_name: - continue - - is_cc = set_name == "common_crawl" - - filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) - - if not filenames: - raise RuntimeError( - f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. wget -i" - " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" - " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" - " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" - ) - - builder = packed_dataset.PackedDatasetBuilder( - outdir=destination_path, - prefix=set_name, - chunk_size=chunk_size, - sep_token=tokenizer.eos_id, - dtype="auto", - vocab_size=tokenizer.vocab_size, - ) - - for name in filenames: - filepath = source_path / name - - print(f"Processing {name}") - - if is_cc: - with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: - for row in tqdm(f): - text = json.loads(row)["text"] - text_ids = tokenizer.encode(text) - builder.add_array(np.array(text_ids, dtype=builder.dtype)) - else: - with open(filepath, encoding="utf-8") as f: - for row in tqdm(f): - text = json.loads(row)["text"] - text_ids = tokenizer.encode(text) - builder.add_array(np.array(text_ids, dtype=builder.dtype)) - - builder.write_reminder() - - -def prepare( - source_path: Path = Path("data/RedPajama-Data-1T-Sample"), - checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), - destination_path: Path = Path("data/redpajama_sample"), - sample: bool = True, - match: str = "", -) -> None: - """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained.""" - config = Config.from_checkpoint(checkpoint_dir) - - prepare_fn = prepare_sample if sample else prepare_full - prepare_fn( - source_path=source_path, - checkpoint_dir=checkpoint_dir, - destination_path=destination_path, - chunk_size=(config.block_size + 1) * 1024, # block size + 1 for causal, 1024 blocks - match=match, - ) - - -if __name__ == "__main__": - CLI(prepare) diff --git a/tests/conftest.py b/tests/conftest.py index 3414ce4c1a..7721fed50f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os +import shutil import sys from pathlib import Path from typing import List @@ -74,6 +75,30 @@ def mock_tockenizer(): return MockTokenizer() +@pytest.fixture() +def alpaca_path(tmp_path): + file = Path(__file__).parent / "data" / "fixtures" / "alpaca.json" + shutil.copyfile(file, tmp_path / "alpaca.json") + return tmp_path / "alpaca.json" + + +@pytest.fixture() +def dolly_path(tmp_path): + file = Path(__file__).parent / "data" / "fixtures" / "dolly.json" + shutil.copyfile(file, tmp_path / "dolly.json") + return tmp_path / "dolly.json" + + +@pytest.fixture() +def longform_path(tmp_path): + path = tmp_path / "longform" + path.mkdir() + for split in ("train", "val"): + file = Path(__file__).parent / "data" / "fixtures" / f"longform_{split}.json" + shutil.copyfile(file, path / f"{split}.json") + return path + + def RunIf(**kwargs): reasons, marker_kwargs = _runif_reasons(**kwargs) return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs) diff --git a/tests/data/conftest.py b/tests/data/conftest.py deleted file mode 100644 index f048f5dcb1..0000000000 --- a/tests/data/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -import shutil -from pathlib import Path - -import pytest - - -@pytest.fixture() -def alpaca_path(tmp_path): - file = Path(__file__).parent / "fixtures" / "alpaca.json" - shutil.copyfile(file, tmp_path / "alpaca.json") - return tmp_path / "alpaca.json" - - -@pytest.fixture() -def dolly_path(tmp_path): - file = Path(__file__).parent / "fixtures" / "dolly.json" - shutil.copyfile(file, tmp_path / "dolly.json") - return tmp_path / "dolly.json" - - -@pytest.fixture() -def longform_path(tmp_path): - path = tmp_path / "longform" - path.mkdir() - for split in ("train", "val"): - file = Path(__file__).parent / "fixtures" / f"longform_{split}.json" - shutil.copyfile(file, path / f"{split}.json") - return path diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 27e5163bff..323d879878 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,7 +1,9 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+import os from contextlib import redirect_stdout from dataclasses import asdict from io import StringIO +from unittest import mock from unittest.mock import Mock import pytest @@ -47,17 +49,11 @@ def test_adapter_filter(tmp_path): assert set(saved) == expected -def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch): +@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) +def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): import finetune.adapter as module + from lit_gpt.data import Alpaca from lit_gpt.args import EvalArgs, IOArgs, TrainArgs - - data = [ - {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([1, 2, 3])}, - {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([2, 3, 4])}, - ] - torch.save(data, tmp_path / "train.pt") - torch.save(data, tmp_path / "test.pt") - from lit_gpt.config import name_to_config model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0) @@ -67,17 +63,21 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch): tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock - tokenizer_mock.encode = lambda *_, **kwargs: torch.tensor([3, 2, 1], **kwargs) + tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) stdout = StringIO() with redirect_stdout(stdout): module.setup( - io=IOArgs( - train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path + data=Alpaca( + download_dir=alpaca_path.parent, + data_file_name=alpaca_path.name, + test_split_fraction=0.5, + num_workers=0 ), + io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path), precision="32-true", - train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, epoch_size=6, micro_batch_size=1), + train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), ) @@ -134,27 +134,27 @@ def test_adapter_compile(): @RunIf(min_cuda_gpus=1) -def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir): +def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path): from lit_gpt.args import IOArgs + from lit_gpt.config import name_to_config + from lit_gpt.data import Alpaca + import finetune.adapter as module if not _BITSANDBYTES_AVAILABLE: pytest.skip("BNB not available") from bitsandbytes.optim import PagedAdamW - import finetune.adapter as module - - data = [] - torch.save(data, tmp_path / "train.pt") - torch.save(data, tmp_path / "test.pt") - - from lit_gpt.config import name_to_config - model_config = dict( block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True ) monkeypatch.setitem(name_to_config, "tmp", model_config) + tokenizer_mock = Mock() + tokenizer_mock.return_value = tokenizer_mock + tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) + monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) + monkeypatch.setattr(module, "load_checkpoint", Mock()) train_mock = Mock() monkeypatch.setattr(module, "fit", train_mock) @@ -162,11 +162,15 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir): stdout = StringIO() with redirect_stdout(stdout): module.setup( + data=Alpaca( + download_dir=alpaca_path.parent, + data_file_name=alpaca_path.name, + test_split_fraction=0.5, + num_workers=0, + ), precision="16-true", quantize="bnb.nf4-dq", - io=IOArgs( 
-                train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path
-            ),
+            io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path),
         )

     args, kwargs = train_mock.call_args
diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py
index f0e0423bb9..b123e28ad0 100644
--- a/tests/test_adapter_v2.py
+++ b/tests/test_adapter_v2.py
@@ -1,9 +1,10 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
+import os
 import sys
 from contextlib import redirect_stdout
 from io import StringIO
 from pathlib import Path
+from unittest import mock
 from unittest.mock import Mock

 import pytest
@@ -71,17 +72,11 @@ def test_adapter_v2_filter(tmp_path):
     assert set(saved) == expected


-def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch):
+@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
+def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     import finetune.adapter_v2 as module
     from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
-
-    data = [
-        {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([1, 2, 3])},
-        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([2, 3, 4])},
-    ]
-    torch.save(data, tmp_path / "train.pt")
-    torch.save(data, tmp_path / "test.pt")
-
+    from lit_gpt.data import Alpaca
     from lit_gpt.config import name_to_config

     model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
@@ -91,17 +86,21 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     tokenizer_mock = Mock()
     tokenizer_mock.return_value = tokenizer_mock
-    tokenizer_mock.encode = lambda *_, **kwargs: torch.tensor([3, 2, 1], **kwargs)
+    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
     monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

     stdout = StringIO()
     with redirect_stdout(stdout):
         module.setup(
-            io=IOArgs(
-                train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path
+            data=Alpaca(
+                download_dir=alpaca_path.parent,
+                data_file_name=alpaca_path.name,
+                test_split_fraction=0.5,
+                num_workers=0
             ),
+            io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path),
             precision="32-true",
-            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, epoch_size=6, micro_batch_size=1),
+            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
             eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
         )
@@ -224,27 +223,27 @@ def test_against_hf_mixtral():

 @RunIf(min_cuda_gpus=1)
-def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir):
+def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
     from lit_gpt.args import IOArgs
+    from lit_gpt.config import name_to_config
+    from lit_gpt.data import Alpaca
+    import finetune.adapter_v2 as module

     if not _BITSANDBYTES_AVAILABLE:
         pytest.skip("BNB not available")

     from bitsandbytes.optim import PagedAdamW

-    import finetune.adapter_v2 as module
-
-    data = []
-    torch.save(data, tmp_path / "train.pt")
-    torch.save(data, tmp_path / "test.pt")
-
-    from lit_gpt.config import name_to_config
-
     model_config = dict(
         block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True
     )
     monkeypatch.setitem(name_to_config, "tmp", model_config)

+    tokenizer_mock = Mock()
+    tokenizer_mock.return_value = tokenizer_mock
+    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
+    monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
+
     monkeypatch.setattr(module, "load_checkpoint", Mock())
     train_mock = Mock()
     monkeypatch.setattr(module, "fit", train_mock)
@@ -252,11 +251,15 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir):
     stdout = StringIO()
     with redirect_stdout(stdout):
         module.setup(
+            data=Alpaca(
+                download_dir=alpaca_path.parent,
+                data_file_name=alpaca_path.name,
+                test_split_fraction=0.5,
+                num_workers=0
+            ),
             precision="16-true",
             quantize="bnb.nf4-dq",
-            io=IOArgs(
-                train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path
-            ),
+            io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path),
         )

     args, kwargs = train_mock.call_args
diff --git a/tests/test_full.py b/tests/test_full.py
index f543c27809..023a32db8d 100644
--- a/tests/test_full.py
+++ b/tests/test_full.py
@@ -10,17 +10,10 @@
 @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
-def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch):
+def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     import finetune.full as module
     from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
-
-    data = [
-        {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([1, 2, 3])},
-        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([2, 3, 4])},
-    ]
-    torch.save(data, tmp_path / "train.pt")
-    torch.save(data, tmp_path / "test.pt")
-
+    from lit_gpt.data import Alpaca
     from lit_gpt.config import name_to_config

     model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
@@ -29,17 +22,21 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     tokenizer_mock = Mock()
     tokenizer_mock.return_value = tokenizer_mock
-    tokenizer_mock.encode = lambda *_, **kwargs: torch.tensor([3, 2, 1], **kwargs)
+    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
     monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

     stdout = StringIO()
     with redirect_stdout(stdout):
         module.setup(
-            io=IOArgs(
-                train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path
+            data=Alpaca(
+                download_dir=alpaca_path.parent,
+                data_file_name=alpaca_path.name,
+                test_split_fraction=0.5,
+                num_workers=0
             ),
+            io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path),
             precision="32-true",
-            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, epoch_size=6, micro_batch_size=1),
+            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
             eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
         )
diff --git a/tests/test_lora.py b/tests/test_lora.py
index 88fe72c08c..68780dd47a 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -1,10 +1,11 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
+import os
 import sys
 from contextlib import redirect_stdout
 from io import StringIO
 from itertools import product
 from pathlib import Path
+from unittest import mock
 from unittest.mock import Mock

 import pytest
@@ -179,17 +180,11 @@ def test_lora_filter(tmp_path):
     assert set(saved) == expected


-def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch):
+@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
+def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     import finetune.lora as module
     from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
-
-    data = [
-        {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([1, 2, 3])},
-        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([2, 3, 4])},
-    ]
-    torch.save(data, tmp_path / "train.pt")
-    torch.save(data, tmp_path / "test.pt")
-
+    from lit_gpt.data import Alpaca
     from lit_gpt.config import name_to_config

     model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
@@ -198,17 +193,21 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     tokenizer_mock = Mock()
     tokenizer_mock.return_value = tokenizer_mock
-    tokenizer_mock.encode = lambda *_, **kwargs: torch.tensor([3, 2, 1], **kwargs)
+    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
     monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

     stdout = StringIO()
     with redirect_stdout(stdout):
         module.setup(
-            io=IOArgs(
-                train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path
+            data=Alpaca(
+                download_dir=alpaca_path.parent,
+                data_file_name=alpaca_path.name,
+                test_split_fraction=0.5,
+                num_workers=0
             ),
+            io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path),
             precision="32-true",
-            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, epoch_size=6, micro_batch_size=1),
+            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
             eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
         )
@@ -582,22 +581,17 @@ def test_against_hf_mixtral():

 @RunIf(min_cuda_gpus=1)
-def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir):
+def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
     from lit_gpt.args import IOArgs
+    from lit_gpt.config import name_to_config
+    from lit_gpt.data import Alpaca
+    import finetune.lora as module

     if not _BITSANDBYTES_AVAILABLE:
         pytest.skip("BNB not available")

     from bitsandbytes.optim import PagedAdamW

-    import finetune.lora as module
-
-    data = []
-    torch.save(data, tmp_path / "train.pt")
-    torch.save(data, tmp_path / "test.pt")
-
-    from lit_gpt.config import name_to_config
-
     model_config = dict(
         block_size=128,
         n_layer=2,
@@ -614,6 +608,11 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir):
     )
     monkeypatch.setitem(name_to_config, "tmp", model_config)

+    tokenizer_mock = Mock()
+    tokenizer_mock.return_value = tokenizer_mock
+    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
+    monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
+
     monkeypatch.setattr(module, "load_checkpoint", Mock())
     train_mock = Mock()
     monkeypatch.setattr(module, "fit", train_mock)
@@ -621,9 +620,13 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir):
     stdout = StringIO()
     with redirect_stdout(stdout):
         module.setup(
-            io=IOArgs(
-                train_data_dir=tmp_path, val_data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path
+            data=Alpaca(
+                download_dir=alpaca_path.parent,
+                data_file_name=alpaca_path.name,
+                test_split_fraction=0.5,
+                num_workers=0,
             ),
+            io=IOArgs(checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path),
             precision="16-true",
             quantize="bnb.nf4-dq",
         )
diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py
deleted file mode 100644
index 36914489df..0000000000
--- a/tests/test_packed_dataset.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
-import os
-from unittest.mock import MagicMock
-
-import pytest
-from torch.utils.data import IterableDataset
-
-from scripts.prepare_alpaca import download_if_missing
-
-
-def test_packed_dataset(tmp_path):
-    tmp_path.mkdir(parents=True, exist_ok=True)
-
-    vocabulary_path = tmp_path / "tokenizer.json"
-    download_if_missing(
-        vocabulary_path, "https://huggingface.co/stabilityai/stablelm-base-alpha-3b/raw/main/tokenizer.json"
-    )
-
-    tokenizer_path = tmp_path / "tokenizer_config.json"
-    download_if_missing(
-        tokenizer_path, "https://huggingface.co/stabilityai/stablelm-base-alpha-3b/raw/main/tokenizer_config.json"
-    )
-
-    from lit_gpt import Tokenizer
-
-    tokenizer = Tokenizer(tmp_path)
-
-    texts = ["The moment of truth is upon us. " * 4, "Time to open the fridge. " * 4]
-
-    from lit_gpt.packed_dataset import HDR_SIZE, PackedDataset, PackedDatasetBuilder
-
-    block_size = 10
-    n_blocks = 2
-    chunk_size = block_size * n_blocks
-
-    builder = PackedDatasetBuilder(
-        outdir=tmp_path,
-        prefix="packed_dataset",
-        chunk_size=chunk_size,
-        sep_token=tokenizer.eos_id,
-        dtype="auto",
-        vocab_size=tokenizer.vocab_size,
-    )
-
-    for text in texts:
-        text_ids = tokenizer.encode(text)
-        print(len(text_ids))
-        builder.add_array(text_ids)
-
-    filenames = builder.filenames
-
-    assert len(filenames) == 2
-    assert os.path.basename(filenames[0]) == "packed_dataset_0000000000.bin"
-    assert os.path.basename(filenames[1]) == "packed_dataset_0000000001.bin"
-
-    import numpy as np
-
-    ex_tokenized = [tokenizer.encode(text).numpy().astype(builder.dtype) for text in texts]
-    ex_tokenized = np.concatenate(ex_tokenized)
-    ex_tokenized = ex_tokenized[: 2 * chunk_size]
-
-    for filename, el in zip(filenames, np.array_split(ex_tokenized, 2)):
-        mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
-        count = len(mmap) // np.dtype(builder.dtype).itemsize
-        arr = np.frombuffer(mmap, dtype=builder.dtype, count=count, offset=0)
-        where_eos = np.where(arr == tokenizer.eos_id)
-        # we expect two EOS tokens, one per file
-        assert len(where_eos) == 1
-        assert np.array_equal(arr, el)
-
-    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, shuffle=False)
-
-    ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size)
-
-    for item, el in zip(dataset, ex_split):
-        assert np.array_equal(item, el)
-
-    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345)
-
-    for i, item in enumerate(dataset):
-        block_idxs = iter(dataset)._block_idxs
-        assert np.array_equal(item, ex_split[block_idxs[i]])
-
-    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345, wrap=True)
-
-    for i, item in enumerate(dataset):
-        if i > 24:
-            break
-
-    dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, seed=12345)
-
-    for i, item in enumerate(dataset):
-        block_idxs = iter(dataset)._block_idxs
-        chunk_idx = i // n_blocks * n_blocks
-        assert np.array_equal(item, ex_split[chunk_idx + block_idxs[i % n_blocks]])
-
-    block_size_ = block_size // 2
-    ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size_)
-    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size_, seed=12345)
-
-    for i, item in enumerate(dataset):
-        block_idxs = iter(dataset)._block_idxs
-        assert np.array_equal(item, ex_split[block_idxs[i]])
-
-    block_size_ = block_size // 3
-    n_chunks = 2
-    ex_chunks = np.split(ex_tokenized, n_chunks)
-    n_splits = ex_tokenized.shape[0] // n_chunks // block_size_
-    ex_splits = [np.split(el[: n_splits * block_size_], n_splits) for el in ex_chunks]
-    ex_split = sum(ex_splits, [])
-
-    dataset = PackedDataset(filenames=filenames, n_chunks=n_chunks, block_size=block_size_, seed=12345)
-
-    for i, item in enumerate(dataset):
-        block_idxs = iter(dataset)._block_idxs
-        assert np.array_equal(item, ex_split[block_idxs[i]])
-
-
-class SimpleDataset(IterableDataset):
-    def __init__(self, start, end):
-        super().__init__()
-        self._start = start
-        self._end = end
-
-    def __iter__(self):
-        return iter(range(self._start, self._end))
-
-
-def test_combined_dataset():
-    from lit_gpt.packed_dataset import CombinedDataset
-
-    dataset1 = SimpleDataset(0, 10)
-    dataset2 = SimpleDataset(10, 20)
-    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
-
-    res = list(dataset)
-    assert res == list(range(0, 10))
-
-    dataset1 = SimpleDataset(0, 10)
-    dataset2 = SimpleDataset(10, 20)
-    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
-
-    res = list(dataset)
-    assert res == list(range(10, 20))
-
-    dataset1 = SimpleDataset(0, 10)
-    dataset2 = SimpleDataset(10, 20)
-    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
-
-    res = list(dataset)
-    assert 9 in res or 19 in res
-    if len(res) > 10:
-        assert 0 in res
-        assert 10 in res
-
-
-def test_sharded_packed_dataset(monkeypatch):
-    import lit_gpt.packed_dataset
-    from lit_gpt.packed_dataset import PackedDataset
-
-    dataset_iterator_mock = MagicMock()
-    monkeypatch.setattr(lit_gpt.packed_dataset, "PackedDatasetIterator", dataset_iterator_mock)
-    filenames = [str(i) for i in range(10)]
-
-    # world_size = 1, rank = 0
-    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2))
-    assert dataset_iterator_mock.call_args[1]["filenames"] == filenames
-    dataset_iterator_mock.reset_mock()
-    # world_size = 2, rank = 0
-    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=0))
-    assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "2", "4", "6", "8"]
-    dataset_iterator_mock.reset_mock()
-    # world_size = 2, rank = 1
-    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=1))
-    assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "3", "5", "7", "9"]
-    dataset_iterator_mock.reset_mock()
-
-    # world_size = 3, rank = 0 (dataset size not cleanly divisible by world size)
-    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=0))
-    assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "3", "6"]
-    dataset_iterator_mock.reset_mock()
-    # world_size = 3, rank = 1 (dataset size not cleanly divisible by world size)
-    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=1))
-    assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "4", "7"]
-    dataset_iterator_mock.reset_mock()
-    # world_size = 3, rank = 2 (dataset size not cleanly divisible by world size)
-    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=2))
-    assert dataset_iterator_mock.call_args[1]["filenames"] == ["2", "5", "8"]
-
-
-@pytest.mark.parametrize(
-    ("weights", "expected"),
-    [
-        ([1], [1]),
-        ([2], [1]),
-        ([2, 0.5], [0.8, 0.2]),
-        ([1, 1, 1], [1 / 3, 1 / 3, 1 / 3]),
-        ([0.3, 0, 0], [1.0, 0, 0]),
-        (None, [0.5, 0.5]),
-    ],
-)
-def test_combined_dataset_normalizes_weights(weights, expected):
-    from lit_gpt.packed_dataset import CombinedDataset
-
-    combined_dataset = CombinedDataset([[1], [2, 3]], weights=weights, seed=1)
-    assert combined_dataset._weights == expected
diff --git a/tests/test_prepare_csv.py b/tests/test_prepare_csv.py
deleted file mode 100644
index 5d8c3dc065..0000000000
--- a/tests/test_prepare_csv.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
-import json
-import subprocess
-import sys
-from pathlib import Path
-from unittest import mock
-from unittest.mock import ANY, call
-
-
-def test_prepare_csv(tmp_path, fake_checkpoint_dir):
-    with mock.patch("lit_gpt.tokenizer.Tokenizer"):
-        from scripts.prepare_csv import prepare
-
-    # create fake data
-    config = dict(block_size=128, padded_vocab_size=256, n_layer=3, n_head=8, n_embd=16)
-    with open(fake_checkpoint_dir / "lit_config.json", "w") as fp:
-        json.dump(config, fp)
-    csv_path = tmp_path / "data.csv"
-    mock_data = (
-        "instruction,input,output\n"
-        "Add,2+2,4\n"
-        "Subtract,5-3,2\n"
-        "Multiply,6*4,24\n"
-        "Divide,10/2,5\n"
-        "Exponentiate,2^3,8\n"
-        "Square root,√9,3\n"
-    )
-    with open(csv_path, "w", encoding="utf-8") as fp:
-        fp.write(mock_data)
-
-    with mock.patch("torch.save") as save_mock:
-        prepare(csv_path, destination_path=tmp_path, checkpoint_dir=fake_checkpoint_dir, test_split_fraction=0.5)
-
-    assert len(save_mock.mock_calls) == 2
-    train_calls, test_calls = save_mock.mock_calls
-    assert train_calls == call(
-        [
-            {"instruction": "Add", "input": "2+2", "output": "4", "input_ids": ANY, "labels": ANY},
-            {"instruction": "Divide", "input": "10/2", "output": "5", "input_ids": ANY, "labels": ANY},
-            {"instruction": "Multiply", "input": "6*4", "output": "24", "input_ids": ANY, "labels": ANY},
-        ],
-        tmp_path / "train.pt",
-    )
-    assert test_calls == call(
-        [
-            {"instruction": "Exponentiate", "input": "2^3", "output": "8", "input_ids": ANY, "labels": ANY},
-            {"instruction": "Subtract", "input": "5-3", "output": "2", "input_ids": ANY, "labels": ANY},
-            {"instruction": "Square root", "input": "√9", "output": "3", "input_ids": ANY, "labels": ANY},
-        ],
-        tmp_path / "test.pt",
-    )
-
-
-def test_cli():
-    cli_path = Path(__file__).parent.parent / "scripts" / "prepare_csv.py"
-    output = subprocess.check_output([sys.executable, cli_path, "-h"])
-    output = str(output.decode())
-    assert "Prepare a CSV dataset" in output
diff --git a/tests/test_prepare_redpajama.py b/tests/test_prepare_redpajama.py
deleted file mode 100644
index 059e47e518..0000000000
--- a/tests/test_prepare_redpajama.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
-import json
-import os
-import subprocess
-import sys
-from pathlib import Path
-from unittest import mock
-
-from scripts.prepare_alpaca import download_if_missing
-
-
-def test_prepare_sample(tmp_path):
-    vocabulary_path = tmp_path / "tokenizer.json"
-    download_if_missing(
-        vocabulary_path, "https://huggingface.co/stabilityai/stablelm-base-alpha-3b/raw/main/tokenizer.json"
-    )
-    tokenizer_path = tmp_path / "tokenizer_config.json"
-    download_if_missing(
-        tokenizer_path, "https://huggingface.co/stabilityai/stablelm-base-alpha-3b/raw/main/tokenizer_config.json"
-    )
-    with open(tmp_path / "lit_config.json", "w") as f:
-        json.dump({"block_size": 2048}, f)
-
-    sample_path = tmp_path / "sample"
-    source_path = sample_path / "source"
-    dest_path = sample_path / "dest"
-
-    source_path.mkdir(parents=True)
-
-    sample = {"meta": {"some": "info"}, "text": "some text"}
-
-    jsonl_sample = "\n".join([json.dumps(el) for el in [sample] * 2])
-
-    import scripts.prepare_redpajama as prepare_redpajama
-
-    for filename in prepare_redpajama.filenames_sample:
-        with open(source_path / filename, "w") as f:
-            f.write(jsonl_sample)
-
-    prepare_redpajama.prepare(source_path=source_path, checkpoint_dir=tmp_path, destination_path=dest_path)
-
-    bin_files = [el.replace(".jsonl", "_0000000000.bin") for el in prepare_redpajama.filenames_sample]
-
-    assert set(os.listdir(dest_path)) == set(bin_files)
-
-    from lit_gpt import Tokenizer
-    from lit_gpt.packed_dataset import PackedDataset
-
-    tokenizer = Tokenizer(tmp_path)
-
-    # artificially set block_size to fit the text
-    block_size = len(tokenizer.encode("some text"))
-
-    for filename in bin_files:
-        filenames = [os.path.join(dest_path, filename)]
-        dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, shuffle=False)
-        dataset_iter = iter(dataset)
-        assert tokenizer.decode(next(dataset_iter)) == "some text"
-        assert tokenizer.decode(next(dataset_iter)) == "some text"
-
-
-def test_prepare_full(tmp_path):
-    vocabulary_path = tmp_path / "tokenizer.json"
-    download_if_missing(
-        vocabulary_path, "https://huggingface.co/stabilityai/stablelm-base-alpha-3b/raw/main/tokenizer.json"
-    )
-    tokenizer_path = tmp_path / "tokenizer_config.json"
-    download_if_missing(
-        tokenizer_path, "https://huggingface.co/stabilityai/stablelm-base-alpha-3b/raw/main/tokenizer_config.json"
-    )
-    with open(tmp_path / "lit_config.json", "w") as f:
-        json.dump({"block_size": 2048}, f)
-
-    full_path = tmp_path / "full"
-    source_path = full_path / "source"
-    dest_path = full_path / "dest"
-
-    source_path.mkdir(parents=True)
-
-    sample = {"meta": {"some": "info"}, "text": "some text"}
-
-    jsonl_sample = "\n".join([json.dumps(el) for el in [sample] * 2])
-
-    import scripts.prepare_redpajama as prepare_redpajama
-
-    arxiv_file = source_path / "arxiv" / "arxiv_0.jsonl"
-    arxiv_file.parent.mkdir(parents=True)
-    with open(arxiv_file, "w") as f:
-        f.write(jsonl_sample)
-
-    import zstandard as zstd
-
-    cc_file = source_path / "common_crawl" / "cc_0.jsonl"
-    cc_file.parent.mkdir(parents=True)
-    with zstd.open(cc_file, "wt", encoding="utf-8") as f:
-        f.write(jsonl_sample)
-
-    filename_sets = {"arxiv": "arxiv/arxiv*", "common_crawl": "common_crawl/*"}
-
-    with mock.patch.object(prepare_redpajama, "filename_sets", filename_sets):
-        prepare_redpajama.prepare(
-            source_path=source_path, checkpoint_dir=tmp_path, destination_path=dest_path, sample=False
-        )
-
-    all_names = prepare_redpajama.filename_sets.keys()
-    bin_files = [el + "_0000000000.bin" for el in all_names]
-
-    assert set(os.listdir(dest_path)) == set(bin_files)
-
-    from lit_gpt import Tokenizer
-    from lit_gpt.packed_dataset import PackedDataset
-
-    tokenizer = Tokenizer(tmp_path)
-
-    # artificially set block_size to fit the text
-    block_size = len(tokenizer.encode("some text"))
-
-    filenames = [os.path.join(dest_path, el) for el in bin_files]
-
-    for filename in filenames:
-        dataset = PackedDataset(filenames=[filename], n_chunks=1, block_size=block_size, shuffle=False)
-        dataset_iter = iter(dataset)
-        assert tokenizer.decode(next(dataset_iter)) == "some text"
-        assert tokenizer.decode(next(dataset_iter)) == "some text"
-
-
-def test_cli():
-    cli_path = Path(__file__).parent.parent / "scripts" / "prepare_redpajama.py"
-    output = subprocess.check_output([sys.executable, cli_path, "-h"])
-    output = str(output.decode())
-    assert 'Prepare the "Red Pajama"' in output
diff --git a/tests/test_pretrain_tinyllama.py b/tests/test_pretrain_tinyllama.py
index f781f1db76..830abd300f 100644
--- a/tests/test_pretrain_tinyllama.py
+++ b/tests/test_pretrain_tinyllama.py
@@ -30,7 +30,7 @@ def test_pretrain_tiny_llama(tmp_path, monkeypatch):
         module.setup(
             devices=2,
             model=model_config,
-            io=IOArgs(out_dir=tmp_path, train_data_dir=None),
+            io=IOArgs(out_dir=tmp_path),
             train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
             eval=EvalArgs(interval=1, max_iters=1),
         )
diff --git a/xla/generate/adapter.py b/xla/generate/adapter.py
index bac6c79a56..f7150103ba 100644
--- a/xla/generate/adapter.py
+++ b/xla/generate/adapter.py
@@ -16,7 +16,9 @@
 from lit_gpt import Tokenizer
 from lit_gpt.adapter import GPT, Block, Config
 from lit_gpt.utils import check_valid_checkpoint_dir, lazy_load
-from scripts.prepare_alpaca import generate_prompt
+from lit_gpt.data import apply_prompt_template
+from lit_gpt.data.alpaca import prompt_template
+
 from xla.generate.base import generate
 from xla.utils import rank_print
@@ -88,7 +90,7 @@ def main(
     tokenizer = Tokenizer(checkpoint_dir)
     sample = {"instruction": prompt, "input": input}
-    prompt = generate_prompt(sample)
+    prompt = apply_prompt_template(prompt_template, sample)
     encoded = tokenizer.encode(prompt, device=fabric.device)
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens