[Not for land] Settings to make Llama3-8B on 8 GPUs faster
ghstack-source-id: bdaa49373bb992258483b4a6c5ceb37b826c0d86
Pull Request resolved: #615
awgu committed Oct 22, 2024
1 parent 0edd2fb commit 98e6528
Showing 4 changed files with 33 additions and 32 deletions.
16 changes: 7 additions & 9 deletions torchtitan/models/llama/model.py
@@ -198,16 +198,14 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use causal mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=True, enable_gqa=self.n_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
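The hunk above swaps the explicit `repeat_kv` expansion for SDPA's `enable_gqa` flag, so the kernel consumes `n_kv_heads < n_heads` directly instead of materializing `n_rep` copies of the key/value tensors (with `enable_gqa=self.n_rep > 1`, the non-GQA case is unchanged). A minimal sketch of the difference, assuming a PyTorch build new enough that `scaled_dot_product_attention` accepts `enable_gqa`; the shapes and head counts are made up for illustration:

```python
import torch
import torch.nn.functional as F

bs, seqlen, head_dim = 2, 128, 64
n_heads, n_kv_heads = 32, 8            # GQA: 4 query heads share each KV head
n_rep = n_heads // n_kv_heads

xq = torch.randn(bs, n_heads, seqlen, head_dim)
xk = torch.randn(bs, n_kv_heads, seqlen, head_dim)
xv = torch.randn(bs, n_kv_heads, seqlen, head_dim)

# Old path: materialize n_rep copies of K and V so the head counts match.
out_repeat = F.scaled_dot_product_attention(
    xq,
    xk.repeat_interleave(n_rep, dim=1),
    xv.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)

# New path: pass the smaller KV head count and let the kernel broadcast it
# (older builds may only support enable_gqa on the CUDA backends).
out_gqa = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True, enable_gqa=True)

torch.testing.assert_close(out_repeat, out_gqa)  # should agree up to tolerance
```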
@@ -373,7 +371,7 @@ def __init__(self, model_args: ModelArgs):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
 
         self.norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
+            "fused_rmsnorm", dim=model_args.dim, eps=model_args.norm_eps
         )
 
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
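This hunk pins the final norm to `build_norm("fused_rmsnorm", ...)` instead of reading `model_args.norm_type` from the job config, forcing the fused RMSNorm kernel no matter what the TOML says (part of why the commit is marked not-for-land; the landable route would presumably be `norm_type = "fused_rmsnorm"` under `[model]`). For reference, an unfused sketch of the computation an RMSNorm kernel performs; this is the generic formulation, not torchtitan's implementation:

```python
import torch
import torch.nn as nn

class RMSNormRef(nn.Module):
    """Unfused reference for what a (fused) RMSNorm kernel computes."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dim, then rescale.
        rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x.float() * rms).type_as(x) * self.weight
```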
@@ -438,7 +436,7 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
+        output = self.output(h) if self.output else h
         return output
 
     @classmethod
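The last hunk drops the `.float()` upcast on the output projection, so the `(bs, seqlen, vocab_size)` logits stay in the parameter dtype (bf16 here) rather than being materialized in float32 inside the model's forward; the matching train.py change further down moves the upcast into the (optionally compiled) loss. A rough sketch of the intent with toy sizes; the names and shapes are illustrative only:

```python
import torch
import torch.nn.functional as F

# Toy sizes; the real run is bs=1, seqlen=8192, vocab_size=128256, dim=4096.
bs, seqlen, vocab_size, dim = 2, 128, 1024, 256
h = torch.randn(bs, seqlen, dim, dtype=torch.bfloat16)
output_proj = torch.nn.Linear(dim, vocab_size, bias=False, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (bs, seqlen))

# Before: `self.output(h).float()` upcast the whole logits tensor in forward.
# After: logits stay in bf16 and only the loss does the float32 upcast.
logits = output_proj(h)                      # bf16, (bs, seqlen, vocab_size)
loss = F.cross_entropy(
    logits.flatten(0, 1).float(),            # upcast where the numerics matter
    labels.flatten(0, 1),
)
```

At Llama 3 8B scale (seq_len 8192, vocab 128256), float32 logits are roughly 4 GB per sample, so where and how that upcast happens is significant.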
13 changes: 5 additions & 8 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -313,19 +313,16 @@ def apply_fsdp(
         check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
-        if pp_enabled:
-            # For PP, do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
-            reshard_after_forward = False
-        else:
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+        reshard_after_forward = False
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
+    fully_shard(model.tok_embeddings, **fsdp_config)
+    # Embedding weight is not needed for embedding backward
+    model.tok_embeddings.set_unshard_in_backward(False)
+    fully_shard([model.output, model.norm], **fsdp_config, reshard_after_forward=False)
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


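The rewritten `apply_fsdp` makes three trades: every transformer block keeps its parameters gathered after forward (`reshard_after_forward=False`), eliminating per-block all-gathers in backward at the cost of holding unsharded weights; the token embedding gets its own FSDP group and is told not to unshard in backward, since the embedding weight itself is not needed to compute its weight gradient; and the output projection is grouped with the final norm and likewise left gathered. A minimal sketch of the same plan against the FSDP2 `fully_shard` API; the import path, the `fsdp_kwargs` placeholder, and the attribute names are assumptions mirroring torchtitan, and the real code also passes a device mesh and mixed-precision policy:

```python
# Minimal sketch of the sharding plan above (FSDP2 / fully_shard). `model` is
# assumed to expose .layers (a ModuleDict of blocks), .tok_embeddings, .norm,
# and .output; fsdp_kwargs (mesh, mp_policy, ...) are elided.
from torch.distributed._composable.fsdp import fully_shard

def apply_fsdp_fast(model, **fsdp_kwargs):
    for transformer_block in model.layers.values():
        # Keep each block gathered after forward: more memory, but no
        # per-block all-gather in backward.
        fully_shard(transformer_block, reshard_after_forward=False, **fsdp_kwargs)

    # Separate group for the embedding; its weight is not needed to compute
    # its own gradient, so skip re-gathering it in backward.
    fully_shard(model.tok_embeddings, **fsdp_kwargs)
    model.tok_embeddings.set_unshard_in_backward(False)

    # Output projection and final norm run back-to-back at the end of forward
    # and the start of backward, so group them and keep them gathered.
    fully_shard([model.output, model.norm], reshard_after_forward=False, **fsdp_kwargs)

    # Root wrap for any remaining parameters.
    fully_shard(model, **fsdp_kwargs)
    return model
```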
16 changes: 10 additions & 6 deletions train.py
@@ -10,6 +10,7 @@
 from datetime import timedelta
 
 import torch
+import torch._inductor.config as inductor_config
 from torch.distributed.elastic.multiprocessing.errors import record
 
 from torchtitan import utils
@@ -142,11 +143,13 @@ def main(job_config: JobConfig):
         f"{color.blue}Model {model_name} {job_config.model.flavor} "
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
+    if job_config.training.compile:
+        inductor_config.coordinate_descent_tuning = True
 
     # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
+            pred.flatten(0, 1).float(), labels.flatten(0, 1)
         )
 
     # apply parallelisms and initialization
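`coordinate_descent_tuning` is a TorchInductor autotuning knob: with it on, Inductor searches kernel launch configurations (block sizes and the like) by coordinate descent instead of only trying a fixed candidate set, which lengthens compile time but usually produces faster Triton kernels, so the hunk enables it only when `training.compile` is set. A small, hedged sketch of flipping the flag around `torch.compile` (it can also be set through the `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING` environment variable):

```python
# Hedged sketch: enable Inductor's coordinate-descent autotuning before
# compiling anything. Expect longer compile times and (usually) faster kernels.
import torch
import torch._inductor.config as inductor_config

inductor_config.coordinate_descent_tuning = True

@torch.compile
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # A stand-in compiled region; any compiled function or module picks up the flag.
    return torch.nn.functional.gelu(x + bias)

out = fused_bias_gelu(torch.randn(1024, 1024), torch.randn(1024))
```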
@@ -271,6 +274,8 @@ def loss_fn(pred, labels):
         ntokens_since_last_log += labels.numel()
         data_loading_times.append(time.perf_counter() - data_load_start)
 
+        model.tok_embeddings.unshard(async_op=True)
+
         input_ids = input_ids.cuda()
         labels = labels.cuda()
         optimizers.zero_grad()
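The new `model.tok_embeddings.unshard(async_op=True)` line is a manual prefetch: it launches the all-gather for the embedding's sharded weight while the host is still issuing the `.cuda()` copies and `zero_grad`, so the weight should be gathered (or nearly so) by the time forward reaches the embedding. FSDP2 exposes `unshard()` on every `fully_shard`-wrapped module for this purpose; with `async_op=True` it returns a handle that forward waits on when the module is first used. A hedged sketch of the pattern, reusing the loop's names rather than being standalone:

```python
# Hedged sketch: overlap the embedding all-gather with host-side work.
# `model`, `input_ids`, `labels`, and `optimizers` are assumed to be the same
# objects as in the training loop above.
handle = model.tok_embeddings.unshard(async_op=True)  # starts the all-gather

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()

# Forward should block on the pending unshard when it reaches tok_embeddings;
# handle.wait() can also be called explicitly if needed.
pred = model(input_ids)
```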
@@ -297,11 +302,10 @@ def loss_fn(pred, labels):
         else:
             # Non-PP forward / backward
             with train_context():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
-                # pred.shape=(bs, seq_len, vocab_size)
-                # need to free pred before bwd to avoid peak memory
-                del pred
+                if job_config.training.compile:
+                    loss = torch.compile(loss_fn)(model(input_ids), labels)
+                else:
+                    loss = loss_fn(model(input_ids), labels)
                 loss.backward()
 
         # clip gradients
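Rather than compiling the whole model, the loop now wraps only `loss_fn` in `torch.compile`, so the flatten, the float32 upcast, and the cross-entropy can fuse into one compiled region instead of the fp32 logits being materialized as a separate tensor. (Re-wrapping inside the loop looks wasteful, but Dynamo caches compiled code per function object, so later iterations should hit the cache; hoisting the wrapper out of the loop would be the tidier landable form.) A standalone sketch with toy shapes:

```python
# Hedged sketch: compile just the loss so flatten + .float() + cross_entropy
# fuse, rather than compiling (or eagerly running) the whole model.
import torch
import torch.nn.functional as F

def loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))

compiled_loss_fn = torch.compile(loss_fn)

pred = torch.randn(2, 128, 1024, dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 1024, (2, 128))

loss = compiled_loss_fn(pred, labels)
loss.backward()
print(pred.grad.shape)  # gradients flow through the compiled loss
```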
20 changes: 11 additions & 9 deletions train_configs/llama3_8b.toml
@@ -6,13 +6,15 @@ dump_folder = "./outputs"
 description = "Llama 3 8B training"
 
 [profiling]
-enable_profiling = true
+enable_profiling = false
 save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = false
 
 [metrics]
-log_freq = 10
-enable_tensorboard = true
+log_freq = 1
+enable_color_printing = true
+enable_tensorboard = false
 save_tb_folder = "tb"
 
 [model]
@@ -24,18 +26,18 @@ tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 [optimizer]
 name = "AdamW"
 lr = 3e-4
+fused = true
 
 [training]
 batch_size = 1
 seq_len = 8192
 warmup_steps = 200 # lr scheduler warm up
 max_norm = 1.0 # grad norm clipping
 steps = 1000
-data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
+data_parallel_degree = -1
 tensor_parallel_degree = 1
-compile = false
-dataset = "c4"
+compile = true
+dataset = "c4_test"
 
 [experimental]
 pipeline_parallel_degree = 1
@@ -50,7 +52,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = 'selective' # ['none', 'selective', 'full']
+mode = 'none' # ['none', 'selective', 'full']
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
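On the config side, the throughput-oriented toggles are: profiling and TensorBoard off, activation checkpointing set to 'none', `compile = true`, the small `c4_test` dataset, and `fused = true` for the optimizer. The last presumably maps to the `fused=True` argument of `torch.optim.AdamW`, which applies the step with fused multi-tensor CUDA kernels instead of the default for-loop/foreach implementations. A small sketch (requires CUDA parameters; sizes and the toy loss are arbitrary):

```python
# Hedged sketch of what `fused = true` in [optimizer] corresponds to: AdamW's
# fused CUDA implementation.
import torch

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```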
