diff --git a/README.md b/README.md
index 193fd94..470f1f3 100644
--- a/README.md
+++ b/README.md
@@ -88,3 +88,9 @@ This repository contains a variety of Determined examples that are not actively
 | Example                                                          | Dataset | Framework  |
 |:----------------------------------------------------------------:|:-------:|:----------:|
 | [asha\_search\_method](custom_search_method/asha_search_method)  | MNIST   | PyTorch    |
+
+## Fully Sharded Data Parallel
+
+| Example                                                          | Framework  |
+|:----------------------------------------------------------------:|:----------:|
+| [minimal\_fsdp](fsdp/minimal_fsdp)                               | PyTorch    |
diff --git a/fsdp/minimal_fsdp/README.md b/fsdp/minimal_fsdp/README.md
new file mode 100644
index 0000000..06423d3
--- /dev/null
+++ b/fsdp/minimal_fsdp/README.md
@@ -0,0 +1,39 @@
+# FSDP + Core API for LLM Training
+
+This example shows how to use Fully Sharded Data Parallel ([FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)) with Determined and Core API. A (relatively) simple transformer model, adapted from [GPT-fast](https://github.com/pytorch-labs/gpt-fast), is trained on fake data.
+
+## Files
+* **fsdp.py**: Training setup and loop, including checkpointing, reporting, and profiling.
+* **model.py**: Model architecture.
+* **config.yaml**: Experiment configuration file.
+
+## Configuration
+Settings can be changed in the `hyperparameters` section of `config.yaml`.
+
+### Hyperparameters
+* `batch_size`: Per-device batch size. The global batch size is `batch_size * slots_per_trial`.
+* `lr`: Learning rate.
+* `d_model`, `max_seq_len`, `n_heads`, `n_layers`, `vocab_size`: Model architecture parameters. Check the code for more details.
+* `report_rate`: Number of training steps to take between metric reports.
+* `checkpoint_rate`: Number of training steps to take between checkpoint saves.
+* `amp_dtype`: Whether to use torch automatic mixed precision, and which dtype to use. Options are `'auto'`, `'bfloat16'`, `'float16'`, and `null`.
+* `validation_batches`: Number of batches to use when calculating validation metrics.
+* `core_api_profiler`: Set to `true` to enable the Core API profiler. Results are visible in the Web UI.
+* `torch_profiler`: Set to `true` to enable the `torch` profiler. Results are visible in TensorBoard, which can be launched through the Web UI.
+
+### Other configuration
+Users might want to change `resources.slots_per_trial`, `workspace`, `project`, and `searcher.max_length` in `config.yaml`.
+
+## Data
+This example uses a synthetically generated random dataset for simplicity.
+
+## To Run
+If you have not yet installed Determined, installation instructions can be found at https://docs.determined.ai/latest/index.html
+
+Change any desired configuration variables as outlined in the **Configuration** section, then run the following command: `det experiment create config.yaml .` (add `-m <master-host:port>` to target a specific master).
+
+## Results
+With the default settings run for 100 steps, training loss should decrease from ~10.5 to ~8.5, while validation loss remains roughly constant because the validation data is drawn from a separate random dataset.
\ No newline at end of file
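To make the `batch_size * slots_per_trial` note above concrete, here is a minimal sketch (not part of the diff) of how the `hyperparameters` section reaches the training script. It mirrors the `det.get_cluster_info()` call used in `fsdp.py` below; the `slots_per_trial` value is an assumption and should match `resources.slots_per_trial` in `config.yaml`.

```python
import determined as det

# Minimal sketch: this only works inside a running Determined trial, where
# get_cluster_info() returns trial metadata (as in the __main__ block of fsdp.py).
info = det.get_cluster_info()
assert info is not None, "must be run on a Determined cluster"

hparams = info.trial.hparams            # the `hyperparameters` section of config.yaml
per_device_batch_size = hparams["batch_size"]

slots_per_trial = 2                     # assumed to match resources.slots_per_trial
global_batch_size = per_device_batch_size * slots_per_trial
print(f"global batch size: {global_batch_size}")
```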
diff --git a/fsdp/minimal_fsdp/config.yaml b/fsdp/minimal_fsdp/config.yaml
new file mode 100644
index 0000000..8d1dafc
--- /dev/null
+++ b/fsdp/minimal_fsdp/config.yaml
@@ -0,0 +1,26 @@
+name: fsdp example
+entrypoint: python3 -m determined.launch.torch_distributed -- python3 fsdp.py
+searcher:
+  name: single
+  metric: loss
+  max_length: 100
+resources:
+  slots_per_trial: 2
+environment:
+  image:
+    gpu: determinedai/environments:cuda-11.8-pytorch-2.0-gpu-mpi-0.31.1
+hyperparameters:
+  batch_size: 1
+  lr: 1e-4
+  d_model: 512
+  max_seq_len: 2048
+  n_heads: 8
+  n_layers: 4
+  vocab_size: 32000
+  report_rate: 10
+  checkpoint_rate: 50
+  amp_dtype: float16
+  validation_batches: 10
+  core_api_profiler: false
+  torch_profiler: false
+max_restarts: 0
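As a reading aid for the configuration above, the following stripped-down outline (a sketch only; it assumes it runs inside a Determined task launched with this config) shows how `searcher.max_length` becomes `op.length` in the Core API loop that `fsdp.py` builds below. Only calls that appear in `fsdp.py` are used, and the training step itself is elided.

```python
import determined as det

# Sketch of the searcher loop structure in fsdp.py (single searcher, max_length: 100).
with det.core.init() as core_context:
    steps_completed = 0
    for op in core_context.searcher.operations():
        while steps_completed < op.length:  # op.length == searcher.max_length
            # ... run one training step; report metrics and save checkpoints periodically ...
            steps_completed += 1
        # Report the final searcher metric (the validation loss in fsdp.py).
        op.report_completed(0.0)
```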
diff --git a/fsdp/minimal_fsdp/fsdp.py b/fsdp/minimal_fsdp/fsdp.py
new file mode 100644
index 0000000..f9917db
--- /dev/null
+++ b/fsdp/minimal_fsdp/fsdp.py
@@ -0,0 +1,368 @@
+import json
+import logging
+import os
+import random
+from typing import Any, Dict, Generator, Optional, TypedDict
+
+import determined as det
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+from model import EmbedAndEncode, LMHead, Transformer, TransformerBlock
+
+"""
+Minimal transformer model FSDP script with Core API.
+"""
+
+
+def get_fake_data_iter(
+    batch_size: int,
+    vocab_size: int,
+    max_seq_len: int,
+    rank: int,
+    device: torch.device,
+    is_validation: bool,
+    simulated_size_in_batches: int = 10,
+) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]:
+    """
+    Fake dataloader. Yields a different set of data for each rank, and for train vs validation.
+    This data would usually come from a tokenized dataset.
+    """
+    generator = torch.Generator(device=device)
+    next_idx = 0
+    while True:
+        if next_idx == 0:
+            generator.manual_seed(42 + rank + 100000 * is_validation)
+        fake_sequence = torch.randint(
+            vocab_size,
+            (batch_size, max_seq_len + 1),
+            device=device,
+            generator=generator,
+        )
+        inputs, targets = fake_sequence[..., :-1], fake_sequence[..., 1:]
+        yield inputs, targets
+        next_idx = (next_idx + 1) % simulated_size_in_batches
+
+
+def get_loss(
+    fsdp_model: FSDP, batch: tuple[torch.Tensor, torch.Tensor], use_amp: bool
+) -> torch.Tensor:
+    inputs, labels = batch
+    with torch.cuda.amp.autocast(enabled=use_amp):
+        outputs = fsdp_model(inputs)
+        outputs_flat = outputs.reshape(-1, outputs.shape[-1])
+        labels_flat = labels.reshape(-1)
+        loss = F.cross_entropy(outputs_flat, labels_flat)
+    return loss
+
+
+def get_reduced_loss_and_report(
+    loss_history: list[torch.Tensor],
+    steps_completed: int,
+    core_context: det.core.Context,
+    validation: bool,
+) -> Optional[float]:
+    """
+    Average the most recent losses across all processes and report the result. Returns the reduced
+    loss on rank 0 and None on all other ranks.
+    """
+
+    loss_history_t = torch.stack(loss_history).mean()
+    dist.reduce(loss_history_t, 0, op=dist.ReduceOp.AVG)
+    if core_context.distributed.rank == 0:
+        reduced_loss = loss_history_t.item()
+        # TypedDict pattern to satisfy mypy.
+        ReportArgs = TypedDict(
+            "ReportArgs", {"steps_completed": int, "metrics": Dict[str, float]}
+        )
+        report_args: ReportArgs = {
+            "steps_completed": steps_completed,
+            "metrics": {"loss": reduced_loss},
+        }
+        if validation:
+            core_context.train.report_validation_metrics(**report_args)
+        else:
+            core_context.train.report_training_metrics(**report_args)
+        return reduced_loss
+    return None
+
+
+def save_checkpoint(
+    fsdp_model: FSDP,
+    optimizer: torch.optim.Optimizer,
+    scaler: ShardedGradScaler,
+    use_scaler: bool,
+    core_context: det.core.Context,
+    steps_completed: int,
+) -> None:
+    # All ranks collectively build the checkpoint on rank 0:
+
+    with FSDP.state_dict_type(
+        fsdp_model,
+        StateDictType.FULL_STATE_DICT,
+        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+    ):
+        model_state_dict = fsdp_model.state_dict()
+        optim_state_dict = FSDP.optim_state_dict(fsdp_model, optimizer)
+
+    if core_context.distributed.rank == 0:
+        with core_context.checkpoint.store_path(
+            metadata={"steps_completed": steps_completed}
+        ) as (
+            path,
+            _,
+        ):
+            torch.save(model_state_dict, path.joinpath("model.bin"))
+            torch.save(optim_state_dict, path.joinpath("optim.bin"))
+            if use_scaler:
+                # Scaler state is automatically the same across ranks.
+                scaler_state_dict = scaler.state_dict()
+                torch.save(scaler_state_dict, path.joinpath("scaler.bin"))
+
+
+def load_checkpoint(
+    fsdp_model: FSDP,
+    optimizer: torch.optim.Optimizer,
+    scaler: ShardedGradScaler,
+    use_scaler: bool,
+    core_context: det.core.Context,
+    device: torch.device,
+    uuid: str,
+) -> int:
+    with core_context.checkpoint.restore_path(uuid) as path:
+        with FSDP.state_dict_type(
+            fsdp_model,
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+        ):
+            fsdp_model.load_state_dict(
+                torch.load(path.joinpath("model.bin"), map_location=device)
+            )
+            optim_state_dict = torch.load(
+                path.joinpath("optim.bin"), map_location=device
+            )
+            optim_state_dict_to_load = FSDP.optim_state_dict_to_load(
+                model=fsdp_model,
+                optim=optimizer,
+                optim_state_dict=optim_state_dict,
+            )
+            optimizer.load_state_dict(optim_state_dict_to_load)
+        scaler_path = path.joinpath("scaler.bin")
+        if use_scaler and os.path.isfile(scaler_path):
+            scaler.load_state_dict(torch.load(scaler_path))
+
+        with open(path.joinpath("metadata.json"), "r") as f:
+            metadata = json.load(f)
+
+        last_step_completed = metadata["steps_completed"]
+    return last_step_completed
+
+
+def get_amp_dtype(amp_dtype_str: Optional[str]) -> Optional[torch.dtype]:
+    if amp_dtype_str is None:
+        return None
+    elif amp_dtype_str == "auto":
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            return torch.float16
+        else:
+            return torch.bfloat16
+    elif amp_dtype_str in ["float16", "bfloat16"]:
+        return getattr(torch, amp_dtype_str)
+    else:
+        raise Exception(
+            f"Unknown amp_dtype {amp_dtype_str}. Please set to one of "
+            "'auto'/'bfloat16'/'float16'/null."
+        )
+
+
+def main(
+    core_context: det.core.Context,
+    hparams: dict[str, Any],
+    checkpoint_uuid: Optional[str] = None,
+) -> None:
+    # Fix the random seed on all devices
+    seed = 42
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    # Get and set the device for this process
+    device = torch.device(f"cuda:{core_context.distributed.local_rank}")
+    torch.cuda.set_device(device)
+
+    amp_dtype = get_amp_dtype(hparams["amp_dtype"])
+    use_amp = amp_dtype is not None
+    use_scaler = amp_dtype == torch.float16
+
+    # Build the unsharded model directly on the device.
+    model = Transformer(
+        d_model=hparams["d_model"],
+        n_heads=hparams["n_heads"],
+        vocab_size=hparams["vocab_size"],
+        n_layers=hparams["n_layers"],
+        max_seq_len=hparams["max_seq_len"],
+        device=device,
+    )
+
+    # Inspect the model:
+    if core_context.distributed.rank == 0:
+        print("Model before FSDP:")
+        print(model, flush=True)
+
+    # Wrap the embedding layer, the lm head, and each transformer block into its own FSDP unit:
+    auto_wrap_policy = ModuleWrapPolicy([TransformerBlock, EmbedAndEncode, LMHead])
+
+    # Let FSDP know to use mixed precision settings.
+    mixed_precision = (
+        MixedPrecision(param_dtype=amp_dtype, reduce_dtype=amp_dtype)
+        if use_amp
+        else None
+    )
+
+    # The fsdp model:
+    fsdp_model = FSDP(
+        model,
+        auto_wrap_policy=auto_wrap_policy,
+        mixed_precision=mixed_precision,
+        sharding_strategy=ShardingStrategy.FULL_SHARD,
+        device_id=device,
+        use_orig_params=True,
+    )
+
+    # Inspect the model post-FSDP
+    if core_context.distributed.rank == 0:
+        print("Model after FSDP:")
+        print(fsdp_model, flush=True)
+
+    # The optimizer must be created post-FSDP
+    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=hparams["lr"])
+
+    steps_completed = 0
+    report_rate = hparams["report_rate"]
+    checkpoint_rate = hparams["checkpoint_rate"]
+    validation_batches = hparams["validation_batches"]
+    use_torch_profiler = hparams["torch_profiler"]
+    train_loss_history = []
+
+    data_iter_arguments = {
+        "batch_size": hparams["batch_size"],
+        "vocab_size": hparams["vocab_size"],
+        "max_seq_len": hparams["max_seq_len"],
+        "rank": core_context.distributed.rank,
+        "device": device,
+    }
+    train_data_iter = get_fake_data_iter(is_validation=False, **data_iter_arguments)
+    scaler = ShardedGradScaler(enabled=use_scaler)
+    # If a previous checkpoint exists, load it now and correct the steps_completed:
+    if checkpoint_uuid is not None:
+        steps_completed = load_checkpoint(
+            fsdp_model,
+            optimizer,
+            scaler,
+            use_scaler,
+            core_context,
+            device,
+            checkpoint_uuid,
+        )
+    # If torch profiler enabled, write profiling results to TensorBoard accessible through WebUI.
+    if use_torch_profiler:
+        torch_profiler = torch.profiler.profile(
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                str(core_context.train.get_tensorboard_path())
+            ),
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
+        )
+    for op in core_context.searcher.operations():
+        # Train for the number of steps specified in searcher.max_length in config.yaml
+        while steps_completed < op.length:
+            batch = next(train_data_iter)
+            loss = get_loss(fsdp_model, batch, use_amp)
+            train_loss_history.append(loss.detach().clone())
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad(set_to_none=True)
+            if use_torch_profiler:
+                torch_profiler.step()
+
+            steps_completed += 1
+            this_is_the_last_step = steps_completed == op.length
+
+            if steps_completed % report_rate == 0 or this_is_the_last_step:
+                # Report the average training loss.
+                get_reduced_loss_and_report(
+                    train_loss_history, steps_completed, core_context, validation=False
+                )
+                train_loss_history.clear()
+                # Compute and report an average validation loss.
+                validation_data_iter = get_fake_data_iter(
+                    is_validation=True, **data_iter_arguments
+                )
+                validation_loss_history = []
+                with torch.inference_mode():
+                    for i in range(validation_batches):
+                        batch = next(validation_data_iter)
+                        loss = get_loss(fsdp_model, batch, use_amp)
+                        validation_loss_history.append(loss)
+                last_validation_loss = get_reduced_loss_and_report(
+                    validation_loss_history,
+                    steps_completed,
+                    core_context,
+                    validation=True,
+                )
+
+            if steps_completed % checkpoint_rate == 0 or this_is_the_last_step:
+                save_checkpoint(
+                    fsdp_model,
+                    optimizer,
+                    scaler,
+                    use_scaler,
+                    core_context,
+                    steps_completed,
+                )
+                # Since should_preempt is blocking, we only check at checkpoint_rate to
+                # maintain performance.
+                if core_context.preempt.should_preempt():
+                    return
+
+        # Tell the master we're done
+        if core_context.distributed.rank == 0:
+            op.report_completed(last_validation_loss)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG, format=det.LOG_FORMAT)
+    info = det.get_cluster_info()
+    assert info, "This script must run on a determined cluster."
+    assert torch.cuda.is_available(), "This script assumes cuda."
+
+    checkpoint_uuid = info.latest_checkpoint
+    hparams = info.trial.hparams
+    core_api_profiler = hparams["core_api_profiler"]
+    try:
+        dist.init_process_group("nccl")
+        distributed = det.core.DistributedContext.from_torch_distributed()
+        with det.core.init(distributed=distributed) as core_context:
+            if core_api_profiler:
+                core_context.profiler.on()
+            main(
+                core_context=core_context,
+                hparams=hparams,
+                checkpoint_uuid=checkpoint_uuid,
+            )
+            if core_api_profiler:
+                core_context.profiler.off()
+    finally:
+        dist.destroy_process_group()
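Two details of `fsdp.py` above can be checked standalone, with no GPU or cluster: the next-token shift performed by `get_fake_data_iter`, and the fact that cross-entropy over near-uniform logits is about `log(vocab_size) ≈ 10.4`, which is why the README quotes a starting training loss of roughly 10.5. The snippet below is an illustration only, with small shapes chosen for quick execution.

```python
import math

import torch
import torch.nn.functional as F

# 1) The next-token shift from get_fake_data_iter: a (batch, seq_len + 1) sequence
#    yields inputs (all but the last token) and targets (all but the first token).
vocab_size, batch_size, seq_len = 32000, 1, 16
fake_sequence = torch.randint(vocab_size, (batch_size, seq_len + 1))
inputs, targets = fake_sequence[..., :-1], fake_sequence[..., 1:]
assert torch.equal(inputs[..., 1:], targets[..., :-1])  # targets are inputs shifted by one

# 2) With uniform logits, the flattened cross-entropy used in get_loss is exactly
#    log(vocab_size), which matches the reported starting loss.
uniform_logits = torch.zeros(batch_size, seq_len, vocab_size)
loss = F.cross_entropy(uniform_logits.reshape(-1, vocab_size), targets.reshape(-1))
print(f"{loss.item():.2f} vs log(vocab_size) = {math.log(vocab_size):.2f}")  # both ~10.37
```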
diff --git a/fsdp/minimal_fsdp/model.py b/fsdp/minimal_fsdp/model.py
new file mode 100644
index 0000000..9bfe8b9
--- /dev/null
+++ b/fsdp/minimal_fsdp/model.py
@@ -0,0 +1,184 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+"""
+Minimal transformer model code adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast
+"""
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        self.w0 = nn.Linear(d_model, 4 * d_model, device=device)
+        self.relu = nn.ReLU()
+        self.w1 = nn.Linear(4 * d_model, d_model, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w1(self.relu(self.w0(x)))
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        device: Optional[torch.device] = None,
+    ):
+        super().__init__()
+        assert d_model % n_heads == 0, "n_heads must divide d_model evenly"
+        self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False, device=device)
+        self.wo = nn.Linear(d_model, d_model, bias=False, device=device)
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.head_dim = d_model // n_heads
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        bsz, seqlen, _ = inputs.shape
+
+        # Get queries, keys, and values
+        q, k, v = self.wqkv(inputs).split(
+            [self.d_model, self.d_model, self.d_model], dim=-1
+        )
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
+        q, k, v = map(lambda inputs: inputs.transpose(1, 2), (q, k, v))
+
+        # Compute attention
+        y = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.0)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.d_model)
+        y = self.wo(y)
+
+        return y
+
+
+class TransformerBlock(nn.Module):
+    """
+    The transformer blocks.
+
+    Forward pass schematic:
+
+    ┌──────┐
+    │inputs│
+    └┬─┬───┘
+     │┌▽─────────┐
+     ││norm, attn│
+     │└┬─────────┘
+    ┌▽─▽──┐
+    │add  │
+    └┬─┬──┘
+     │┌▽────────┐
+     ││norm, ffn│
+     │└┬────────┘
+    ┌▽─▽──┐
+    │add  │
+    └┬────┘
+    ┌▽──────┐
+    │outputs│
+    └───────┘
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        self.attention = Attention(d_model=d_model, n_heads=n_heads, device=device)
+        self.feed_forward = FeedForward(d_model=d_model, device=device)
+        self.ffn_norm = nn.LayerNorm(d_model, device=device)
+        self.attention_norm = nn.LayerNorm(d_model, device=device)
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+    ) -> torch.Tensor:
+        h = inputs + self.attention(self.attention_norm(inputs))
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+class EmbedAndEncode(nn.Module):
+    """
+    Embedding layer with learned positional encodings.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        vocab_size: int,
+        max_seq_len: int,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        # Learned positional encoding and embedding layer:
+        self.max_seq_len = max_seq_len
+        self.learned_pos_enc = nn.Parameter(
+            torch.zeros(max_seq_len, d_model, device=device)
+        )
+        self.tok_embeddings = nn.Embedding(vocab_size, d_model, device=device)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        _, seq_len = inputs.shape
+        assert seq_len <= self.max_seq_len
+        outputs = self.tok_embeddings(inputs) + self.learned_pos_enc[None, :seq_len]
+        return outputs
+
+
+class LMHead(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        vocab_size: int,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+
+        self.norm = nn.LayerNorm(d_model, device=device)
+        self.output = nn.Linear(d_model, vocab_size, bias=False, device=device)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        logits = self.output(self.norm(inputs))
+        return logits
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        vocab_size: int,
+        n_layers: int,
+        max_seq_len: int,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+
+        # Embed/encode
+        self.embed_and_encode = EmbedAndEncode(
+            d_model, vocab_size, max_seq_len, device=device
+        )
+
+        # Transformer blocks
+        self.layers = nn.ModuleList(
+            TransformerBlock(d_model, n_heads, device=device) for _ in range(n_layers)
+        )
+
+        # Final norm and language model head:
+        self.lm_head = LMHead(d_model, vocab_size, device=device)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        x = self.embed_and_encode(inputs)
+        for layer in self.layers:
+            x = layer(x)
+        logits = self.lm_head(x)
+        return logits
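Finally, the model above can be smoke-tested on CPU, independent of FSDP and Determined: `Transformer` maps token ids of shape `(batch, seq_len)` to logits of shape `(batch, seq_len, vocab_size)`. A minimal check, assuming it is run from `fsdp/minimal_fsdp/` so that `model.py` is importable, with deliberately tiny dimensions:

```python
import torch

from model import Transformer  # assumes the working directory is fsdp/minimal_fsdp/

# Tiny CPU-only smoke test: token ids in, per-position vocabulary logits out.
model = Transformer(d_model=64, n_heads=4, vocab_size=1000, n_layers=2, max_seq_len=32)
tokens = torch.randint(1000, (2, 16))  # batch of 2 sequences, 16 tokens each
with torch.no_grad():
    logits = model(tokens)
assert logits.shape == (2, 16, 1000)
print(logits.shape)  # torch.Size([2, 16, 1000])
```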