Commit
[Not for land] Settings to make Llama3-8B on 8 GPUs faster
ghstack-source-id: c5767bb4f3d7ad3330953ab97b9f06ff5c6917f5
Pull Request resolved: #615
awgu committed Oct 14, 2024
1 parent 25ec560 commit 8b3677a
Showing 4 changed files with 25 additions and 25 deletions.
4 changes: 2 additions & 2 deletions torchtitan/models/llama/model.py
@@ -373,7 +373,7 @@ def __init__(self, model_args: ModelArgs):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

self.norm = build_norm(
- model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
+ "fused_rmsnorm", dim=model_args.dim, eps=model_args.norm_eps
)

self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
@@ -438,7 +438,7 @@ def forward(self, tokens: torch.Tensor):
h = layer(h, self.freqs_cis)

h = self.norm(h) if self.norm else h
- output = self.output(h).float() if self.output else h
+ output = self.output(h) if self.output else h
return output

@classmethod
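For context on the two changes above: the norm layer is hard-coded to the fused RMSNorm variant instead of reading model_args.norm_type from the config, and the .float() upcast of the logits is dropped from the model's forward (it moves into the loss function; see the train.py hunk below). As a rough reference only, not torchtitan's build_norm implementation, an unfused RMSNorm in plain PyTorch looks like the sketch below; a fused variant performs the same normalization and scaling in a single kernel.

import torch
import torch.nn as nn

class RefRMSNorm(nn.Module):
    """Reference (unfused) RMSNorm; a fused kernel computes the same result."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for numerical stability, then cast back to the input dtype.
        x_fp32 = x.float()
        rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_fp32 * rms).type_as(x) * self.weight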
13 changes: 5 additions & 8 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -313,19 +313,16 @@ def apply_fsdp(
check_strided_sharding_enabled()

for layer_id, transformer_block in model.layers.items():
- if pp_enabled:
- # For PP, do not reshard after forward to avoid per-microbatch
- # all-gathers, which can be expensive and non-overlapped
- reshard_after_forward = False
- else:
- # As an optimization, do not reshard after forward for the last
- # transformer block since FSDP would prefetch it immediately
- reshard_after_forward = int(layer_id) < len(model.layers) - 1
+ reshard_after_forward = False
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
+ fully_shard(model.tok_embeddings, **fsdp_config)
+ # Embedding weight is not needed for embedding backward
+ model.tok_embeddings.set_unshard_in_backward(False)
+ fully_shard([model.output, model.norm], **fsdp_config, reshard_after_forward=False)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


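The wrapping above leans on a few FSDP2 (fully_shard) knobs: reshard_after_forward=False keeps each block's parameters gathered between forward and backward, trading memory for skipping the backward all-gathers, and set_unshard_in_backward(False) tells FSDP not to gather the embedding weight during backward at all, since that weight is not needed to compute the embedding's gradient. The following is a standalone sketch of the same pattern, assuming torch.distributed is already initialized, a recent PyTorch where fully_shard accepts a list of modules, and a model exposing layers / tok_embeddings / norm / output as in the Llama model above.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard  # exposed as torch.distributed.fsdp.fully_shard in newer releases

def apply_fsdp_sketch(model: nn.Module) -> nn.Module:
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
    fsdp_config = {"mesh": mesh}
    for transformer_block in model.layers.values():
        # Keep parameters gathered after forward: more memory, but no
        # re-all-gather at the start of each block's backward.
        fully_shard(transformer_block, **fsdp_config, reshard_after_forward=False)
    # Wrap the embedding on its own and skip unsharding it in backward,
    # since the embedding weight is not needed for its backward computation.
    fully_shard(model.tok_embeddings, **fsdp_config)
    model.tok_embeddings.set_unshard_in_backward(False)
    # Group the final norm and output projection into a single all-gather.
    fully_shard([model.output, model.norm], **fsdp_config, reshard_after_forward=False)
    # Root wrap (the diff keeps reshard_after_forward=not pp_enabled here).
    fully_shard(model, **fsdp_config, reshard_after_forward=False)
    return model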
13 changes: 7 additions & 6 deletions train.py
@@ -146,7 +146,7 @@ def main(job_config: JobConfig):
# loss function to be shared by Pipeline Parallel and SPMD training
def loss_fn(pred, labels):
return torch.nn.functional.cross_entropy(
- pred.flatten(0, 1), labels.flatten(0, 1)
+ pred.flatten(0, 1).float(), labels.flatten(0, 1)
)

# apply parallelisms and initialization
@@ -271,6 +271,8 @@ def loss_fn(pred, labels):
ntokens_since_last_log += labels.numel()
data_loading_times.append(time.perf_counter() - data_load_start)

+ model.tok_embeddings.unshard(async_op=True)

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()
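The added unshard(async_op=True) issues the embedding's all-gather early so it overlaps with the host-side work that follows (host-to-device copies and zeroing gradients); FSDP waits on the pending all-gather when tok_embeddings runs in forward. A small sketch of that ordering, reusing the loop's names (model, input_ids, labels, optimizers) as placeholders:

def train_step_prologue(model, input_ids, labels, optimizers):
    # Kick off the all-gather for the separately wrapped embedding now...
    model.tok_embeddings.unshard(async_op=True)
    # ...so the communication overlaps with these host-driven steps.
    input_ids = input_ids.cuda()
    labels = labels.cuda()
    optimizers.zero_grad()
    # tok_embeddings' pre-forward waits on the pending all-gather.
    return input_ids, labels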
@@ -297,11 +299,10 @@ def loss_fn(pred, labels):
else:
# Non-PP forward / backward
with train_context():
- pred = model(input_ids)
- loss = loss_fn(pred, labels)
- # pred.shape=(bs, seq_len, vocab_size)
- # need to free to before bwd to avoid peaking memory
- del pred
+ if job_config.training.compile:
+ loss = torch.compile(loss_fn)(model(input_ids), labels)
+ else:
+ loss = loss_fn(model(input_ids), labels)
loss.backward()

# clip gradients
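Compiling just the loss lets the fp32 upcast and the cross-entropy be fused into generated kernels, so a separate fp32 copy of the logits (batch x seq_len x vocab_size, which is large with Llama 3's 128K-token vocabulary) need not be materialized, and the manual del pred before backward is no longer needed. A minimal sketch of the pattern, with the torch.compile call hoisted out of the loop (torch.compile caches compiled code, so calling it inside the loop as the diff does only compiles on the first step in practice):

import torch
import torch.nn.functional as F

def loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast inside the loss; under torch.compile the cast and the reduction
    # can be fused so the bf16 logits are consumed directly.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))

compiled_loss_fn = torch.compile(loss_fn)
# Per step: loss = compiled_loss_fn(model(input_ids), labels); loss.backward()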
20 changes: 11 additions & 9 deletions train_configs/llama3_8b.toml
@@ -6,13 +6,15 @@ dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
- enable_profiling = true
+ enable_profiling = false
save_traces_folder = "profile_trace"
- profile_freq = 100
+ profile_freq = 10
+ enable_memory_snapshot = false

[metrics]
- log_freq = 10
- enable_tensorboard = true
+ log_freq = 1
+ enable_color_printing = true
+ enable_tensorboard = false
save_tb_folder = "tb"

[model]
@@ -24,18 +26,18 @@ tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
[optimizer]
name = "AdamW"
lr = 3e-4
+ fused = true

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
- data_parallel_replicate_degree = 1
- data_parallel_shard_degree = -1
+ data_parallel_degree = -1
tensor_parallel_degree = 1
- compile = false
- dataset = "c4"
+ compile = true
+ dataset = "c4_test"

[experimental]
pipeline_parallel_degree = 1
@@ -50,7 +52,7 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
- mode = 'selective' # ['none', 'selective', 'full']
+ mode = 'none' # ['none', 'selective', 'full']
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
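In the [optimizer] section, fused = true selects PyTorch's fused AdamW, which updates parameters with a few multi-tensor CUDA kernels per step instead of a per-parameter Python loop; a minimal illustration with a placeholder module and the config's learning rate:

import torch
import torch.nn as nn

model = nn.Linear(4096, 4096, device="cuda")  # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

The remaining TOML edits trade observability and fidelity for step time: profiling, TensorBoard logging, and activation checkpointing are disabled, torch.compile is enabled, and the dataset is switched to c4_test.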
