From 98e6528345b3fa85ad6848dd97dd50d6e2d038dd Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Mon, 21 Oct 2024 21:28:40 -0700
Subject: [PATCH] [Not for land] Settings to make Llama3-8B on 8 GPUs faster

ghstack-source-id: bdaa49373bb992258483b4a6c5ceb37b826c0d86
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/615
---
 torchtitan/models/llama/model.py             | 16 +++++++---------
 torchtitan/parallelisms/parallelize_llama.py | 13 +++++--------
 train.py                                     | 16 ++++++++++------
 train_configs/llama3_8b.toml                 | 20 +++++++++++---------
 4 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 7f102a80..5cb24db2 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -198,16 +198,14 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use casual mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=True, enable_gqa=self.n_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
@@ -373,7 +371,7 @@ def __init__(self, model_args: ModelArgs):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
 
         self.norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
+            "fused_rmsnorm", dim=model_args.dim, eps=model_args.norm_eps
         )
 
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
@@ -438,7 +436,7 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
+        output = self.output(h) if self.output else h
         return output
 
     @classmethod
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703d..09b21ee8 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -313,19 +313,16 @@ def apply_fsdp(
         check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
-        if pp_enabled:
-            # For PP, do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
-            reshard_after_forward = False
-        else:
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+        reshard_after_forward = False
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
+    fully_shard(model.tok_embeddings, **fsdp_config)
+    # Embedding weight is not needed for embedding backward
+    model.tok_embeddings.set_unshard_in_backward(False)
+    fully_shard([model.output, model.norm], **fsdp_config, reshard_after_forward=False)
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
 
diff --git a/train.py b/train.py
index 3e8994a3..cdae3952 100644
--- a/train.py
+++ b/train.py
@@ -10,6 +10,7 @@
 from datetime import timedelta
 
 import torch
+import torch._inductor.config as inductor_config
 from torch.distributed.elastic.multiprocessing.errors import record
 
 from torchtitan import utils
@@ -142,11 +143,13 @@ def main(job_config: JobConfig):
         f"{color.blue}Model {model_name} {job_config.model.flavor} "
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
+    if job_config.training.compile:
+        inductor_config.coordinate_descent_tuning = True
 
     # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
+            pred.flatten(0, 1).float(), labels.flatten(0, 1)
         )
 
     # apply parallelisms and initialization
@@ -271,6 +274,8 @@ def loss_fn(pred, labels):
             ntokens_since_last_log += labels.numel()
             data_loading_times.append(time.perf_counter() - data_load_start)
 
+            model.tok_embeddings.unshard(async_op=True)
+
             input_ids = input_ids.cuda()
             labels = labels.cuda()
             optimizers.zero_grad()
@@ -297,11 +302,10 @@ def loss_fn(pred, labels):
             else:
                 # Non-PP forward / backward
                 with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
+                    if job_config.training.compile:
+                        loss = torch.compile(loss_fn)(model(input_ids), labels)
+                    else:
+                        loss = loss_fn(model(input_ids), labels)
                     loss.backward()
 
             # clip gradients
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index e0c5bd03..cbdd461e 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -6,13 +6,15 @@ dump_folder = "./outputs"
 description = "Llama 3 8B training"
 
 [profiling]
-enable_profiling = true
+enable_profiling = false
 save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = false
 
 [metrics]
-log_freq = 10
-enable_tensorboard = true
+log_freq = 1
+enable_color_printing = true
+enable_tensorboard = false
 save_tb_folder = "tb"
 
 [model]
@@ -24,6 +26,7 @@ tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 [optimizer]
 name = "AdamW"
 lr = 3e-4
+fused = true
 
 [training]
 batch_size = 1
@@ -31,11 +34,10 @@ seq_len = 8192
 warmup_steps = 200  # lr scheduler warm up
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
+data_parallel_degree = -1
 tensor_parallel_degree = 1
-compile = false
-dataset = "c4"
+compile = true
+dataset = "c4_test"
 
 [experimental]
 pipeline_parallel_degree = 1
@@ -50,7 +52,7 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = 'selective'  # ['none', 'selective', 'full']
+mode = 'none'  # ['none', 'selective', 'full']
 selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]