(4/n) Data Refactor - Finetuning Scripts #950

Merged 76 commits on Feb 29, 2024

Commits (76)
1655751  alpaca (awaelchli, Feb 15, 2024)
0b65b09  fixes (awaelchli, Feb 23, 2024)
1e273fe  fixes (awaelchli, Feb 23, 2024)
719ac2f  separate (awaelchli, Feb 23, 2024)
8b31c77  lima (awaelchli, Feb 23, 2024)
e4daef9  Merge branch 'main' into refactor/data (awaelchli, Feb 24, 2024)
a2047f6  integrate (awaelchli, Feb 24, 2024)
815e03b  remove converted datasets (awaelchli, Feb 24, 2024)
34bf71c  tinyllama (awaelchli, Feb 24, 2024)
1768d56  update (awaelchli, Feb 24, 2024)
0b9eca8  small typo fix: laoder -> loader (rasbt, Feb 26, 2024)
c83f468  refactor base class (awaelchli, Feb 26, 2024)
1f4a4ee  args stuff (awaelchli, Feb 26, 2024)
a8e6ce4  Merge branch 'refactor/data' of ssh://github.com/Lightning-AI/lit-gpt… (awaelchli, Feb 26, 2024)
65d4a21  max_seq_length needs to be specified differently (awaelchli, Feb 27, 2024)
19cbc37  fix for max steps (awaelchli, Feb 27, 2024)
095cc7b  fix init (awaelchli, Feb 27, 2024)
35537c0  tinyllama (awaelchli, Feb 27, 2024)
2f5658e  model config (awaelchli, Feb 27, 2024)
7f81bbe  remove epoch size (awaelchli, Feb 27, 2024)
da0710f  simplify (awaelchli, Feb 27, 2024)
9638b0b  fix (awaelchli, Feb 27, 2024)
edbd4e0  refactor (awaelchli, Feb 27, 2024)
ff0ba0e  init (awaelchli, Feb 27, 2024)
0d36d09  revert (awaelchli, Feb 27, 2024)
eae63aa  docs (awaelchli, Feb 27, 2024)
e475962  docs (awaelchli, Feb 27, 2024)
b759d48  fix test (awaelchli, Feb 27, 2024)
0c18563  Update tests/test_pretrain_tinyllama.py (awaelchli, Feb 27, 2024)
a5d1ae5  Update pretrain/tinyllama.py (awaelchli, Feb 27, 2024)
bb250f2  update gitnignore (awaelchli, Feb 27, 2024)
066b776  tests (awaelchli, Feb 27, 2024)
097a58b  no test loader (awaelchli, Feb 27, 2024)
7edd888  rename base (awaelchli, Feb 27, 2024)
175f223  remove name arg (awaelchli, Feb 27, 2024)
6bb4ec3  datasets collides with hf datasets import :( (awaelchli, Feb 27, 2024)
463cb56  Merge branch 'refactor/data-tinyllama' into refactor/data (awaelchli, Feb 27, 2024)
a13dfb6  move (awaelchli, Feb 27, 2024)
b30771c  Merge branch 'main' into refactor/data (awaelchli, Feb 27, 2024)
ec45af5  restore (awaelchli, Feb 28, 2024)
abce4e0  restore (awaelchli, Feb 28, 2024)
1e5bf65  tests (awaelchli, Feb 28, 2024)
7835dea  test (awaelchli, Feb 28, 2024)
c990ce5  test (awaelchli, Feb 28, 2024)
006c09d  update (awaelchli, Feb 28, 2024)
adc99b5  csv (awaelchli, Feb 28, 2024)
b4c9b71  test csv (awaelchli, Feb 28, 2024)
63ba730  remove old test (awaelchli, Feb 28, 2024)
216da43  dolly (awaelchli, Feb 28, 2024)
f28ffd7  longform (awaelchli, Feb 28, 2024)
49b5b0a  fixes (awaelchli, Feb 28, 2024)
5b99057  flan (awaelchli, Feb 28, 2024)
c47785f  fix (awaelchli, Feb 28, 2024)
d9e035f  update (awaelchli, Feb 28, 2024)
e1c2766  optional data (awaelchli, Feb 28, 2024)
3e90c4a  fix test split (awaelchli, Feb 28, 2024)
0b0fa20  todos (awaelchli, Feb 28, 2024)
1315b3f  Merge branch 'main' into refactor/data (awaelchli, Feb 28, 2024)
e8a7677  update test (awaelchli, Feb 28, 2024)
efd5b7e  tinyllama (awaelchli, Feb 28, 2024)
0cd28d6  update (awaelchli, Feb 29, 2024)
9c6135c  lora (awaelchli, Feb 29, 2024)
2014b18  adapter (awaelchli, Feb 29, 2024)
e4c6396  adapter v2 (awaelchli, Feb 29, 2024)
8fcfe26  update (awaelchli, Feb 29, 2024)
eac6bd6  update tests (awaelchli, Feb 29, 2024)
563c580  update (awaelchli, Feb 29, 2024)
ebb5b7b  update (awaelchli, Feb 29, 2024)
2053cc0  tests (awaelchli, Feb 29, 2024)
e5637e3  tests (awaelchli, Feb 29, 2024)
d69c73a  reset (awaelchli, Feb 29, 2024)
490c51e  Merge branch 'main' into refactor/data (awaelchli, Feb 29, 2024)
6bcff6a  Run CI on wip branch (carmocca, Feb 29, 2024)
8d7a2b3  Merge branch 'main' into refactor/data (awaelchli, Feb 29, 2024)
7215ac5  require either epochs or max_steps to be set (awaelchli, Feb 29, 2024)
5e10c9e  don't inline max_steps redefinition (awaelchli, Feb 29, 2024)

Files changed
2 changes: 2 additions & 0 deletions .github/azure-gpu-test.yml
@@ -2,11 +2,13 @@ trigger:
branches:
include:
- "main"
- "wip"

pr:
branches:
include:
- "main"
- "wip"
- "carmocca/*"

jobs:
4 changes: 2 additions & 2 deletions .github/workflows/cpu-tests.yml
@@ -2,9 +2,9 @@ name: CPU tests

on:
push:
branches: [main]
branches: [main, wip]
pull_request:
branches: [main, "carmocca/*"]
branches: [main, "carmocca/*", wip]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
117 changes: 49 additions & 68 deletions finetune/adapter.py
@@ -8,6 +8,7 @@

import lightning as L
import torch
from torch.utils.data import DataLoader
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
@@ -20,6 +21,7 @@
from generate.base import generate
from lit_gpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.data import Alpaca, LitDataModule, apply_prompt_template
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import (
CLI,
@@ -28,20 +30,19 @@
get_default_supported_precision,
load_checkpoint,
num_parameters,
CycleIterator,
)
from scripts.prepare_alpaca import generate_prompt


def setup(
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: int = 1,
seed: int = 1337,
data: Optional[LitDataModule] = None,
io: IOArgs = IOArgs(
train_data_dir=Path("data/alpaca"),
val_data_dir=Path("data/alpaca"),
checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir=Path("out/adapter/alpaca"),
out_dir=Path("out/adapter"),
),
train: TrainArgs = TrainArgs(
save_interval=1000,
@@ -50,13 +51,16 @@ def setup(
micro_batch_size=4,
lr_warmup_steps=100,
epochs=5,
epoch_size=50000,
learning_rate=1e-3,
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
) -> None:

print(locals())
if data is None:
data = Alpaca()

precision = precision or get_default_supported_precision(training=True)

plugins = None
@@ -85,25 +89,24 @@

logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), io, train, eval)
fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), data, io, train, eval)


def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
validate_args(io, train, eval)

steps_per_epoch = train.epoch_size // devices // train.batch_size(devices)
lr_max_steps = train.epochs * steps_per_epoch

check_valid_checkpoint_dir(io.checkpoint_dir)

tokenizer = Tokenizer(io.checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

fabric.seed_everything(seed) # same seed for every process to init model (FSDP)

if fabric.global_rank == 0:
os.makedirs(io.out_dir, exist_ok=True)

train_data = torch.load(io.train_data_dir / "train.pt")
val_data = torch.load(io.val_data_dir / "test.pt")

checkpoint_path = io.checkpoint_dir / "lit_model.pth"
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=(devices > 1)):
@@ -131,10 +134,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs,
# strict=False because missing keys due to Adapter weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)

fabric.seed_everything(1337 + fabric.global_rank)

train_time = time.perf_counter()
fit(fabric, model, optimizer, scheduler, train_data, val_data, devices, io, train, eval)
fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, io, train, eval)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
@@ -149,34 +150,37 @@ def fit(
model: GPT,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_data: List[Dict],
val_data: List[Dict],
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
io: IOArgs,
train: TrainArgs,
eval: EvalArgs,
) -> None:
tokenizer = Tokenizer(io.checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)

validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check
validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2)) # sanity check

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
total_lengths = 0
total_t0 = time.perf_counter()

for iter_num in range(1, train.max_iters(devices) + 1):
while step_count < max_steps and train_iterator.epoch < train.epochs:
iter_num += 1
iter_t0 = time.perf_counter()

input_ids, targets = get_batch(
fabric, train_data, train.micro_batch_size, train.max_seq_length, longest_seq_ix if iter_num == 1 else None
)
batch = next(train_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -207,7 +211,7 @@ def fit(

if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_data, tokenizer, eval, train)
val_loss = validate(fabric, model, val_dataloader, tokenizer, eval)
t1 = time.perf_counter() - t0
fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
fabric.barrier()
@@ -219,13 +223,15 @@
# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs,
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length)
batch = next(val_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
@@ -234,7 +240,7 @@ def validate(
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
fabric.print(instruction)
sample = {"instruction": instruction, "input": ""}
prompt = generate_prompt(sample)
prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, sample)
encoded = tokenizer.encode(prompt, device=fabric.device)
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
@@ -250,51 +256,24 @@
return val_loss


def get_batch(
fabric: L.Fabric,
data: List[Dict],
micro_batch_size: int,
max_seq_length: Optional[int],
longest_seq_ix: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
ix = torch.randint(len(data), (micro_batch_size,))
if longest_seq_ix is not None:
# force the longest sample at the beginning so potential OOMs happen right away
ix[0] = longest_seq_ix

input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
labels = [data[i]["labels"].type(torch.int64) for i in ix]

# this could be `longest_seq_length` to have a fixed size for all batches
max_len = max(len(s) for s in input_ids)

def pad_right(x, pad_id):
# pad right based on the longest sequence
n = max_len - len(x)
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])

# Truncate if needed
if max_seq_length:
x = x[:, :max_seq_length]
y = y[:, :max_seq_length]

if fabric.device.type == "cuda" and x.device.type == "cpu":
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
else:
x, y = fabric.to_device((x, y))
return x, y


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
@@ -316,14 +295,16 @@ def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [
(io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]),
(train, ["epoch_size", "epochs"]),
(io, ["checkpoint_dir"]),
(train, ["epochs"]),
(eval, ["max_new_tokens"]),
]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))
