70B training + other tweaks (#295)
Adding 70B training config, plus a few small fixes/tweaks for finetune.py
jacob-morrison authored Aug 27, 2024
1 parent dcae531 commit 082b400
Showing 4 changed files with 68 additions and 7 deletions.
configs/beaker_configs/default_finetune_multinode.yaml (4 changes: 0 additions & 4 deletions)
@@ -60,10 +60,6 @@ tasks:
         value: false
       - name: WANDB_DISABLED
         value: true
-      - name: NCCL_NET
-        value: IB
-      - name: NCCL_DEBUG
-        value: INFO
       - name: HF_TOKEN
         secret: HF_TOKEN
     result:
configs/train_configs/sft/tulu3_L3.1_70b_preview_mix_v3.3.yaml (30 changes: 30 additions & 0 deletions)
@@ -0,0 +1,30 @@
+model_name_or_path: meta-llama/Meta-Llama-3.1-70B
+model_revision: main
+use_flash_attn: true
+tokenizer_name: meta-llama/Meta-Llama-3.1-70B
+use_slow_tokenizer: true
+dataset_mixer:
+  # Tulu V2 datasets
+  ai2-adapt-dev/llama-3-tulu-v2-sft-mixture-with-subset-llama-405b-completions-code_alpaca-open_orca-gpt4_alpaca: 326154
+  # Tulu V3 datasets (WIP)
+  HuggingFaceH4/no_robots: 9500 # all
+  ai2-adapt-dev/metamath-qa-reformat: 100000
+  ai2-adapt-dev/codefeedback-single-turn-reformat: 156526 # all
+  nvidia/Daring-Anteater: 99532 # all
+max_seq_length: 4096
+preprocessing_num_workers: 128
+per_device_train_batch_size: 1 # note, this is set up for 8 GPUs
+gradient_accumulation_steps: 4 # effective batch size 128 with 4 nodes
+learning_rate: 5.0e-06 # best LR so far
+lr_scheduler_type: linear
+warmup_ratio: 0.03
+weight_decay: 0.0
+num_train_epochs: 2
+output_dir: /output/
+with_tracking: true
+report_to:
+  - wandb
+logging_steps: 1
+checkpointing_steps: epoch
+dataset_mix_dir: /output/
+gradient_checkpointing: true
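
The two batch-size comments in this config imply a calculation worth spelling out: with a per-device batch of 1, gradient accumulation of 4, 8 GPUs per node, and 4 nodes (the GPU and node counts come only from the comments, not from the YAML itself), the effective batch size is 128. A quick sanity-check sketch under those assumptions:

# Effective batch size implied by the comments in the config above.
# gpus_per_node and num_nodes are taken from the comments, not from the YAML.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
gpus_per_node = 8
num_nodes = 4

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * gpus_per_node * num_nodes
)
assert effective_batch_size == 128  # matches the "effective batch size 128" comment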
open_instruct/finetune.py (13 changes: 10 additions & 3 deletions)
@@ -318,6 +318,12 @@ class FlatArguments:
     """The url of the saved model in the Hugging Face Hub (will be autoset)"""
     try_launch_beaker_eval_jobs: bool = True
     """Whether to launch beaker evaluation jobs after training"""
+    fused_optimizer: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to use fused AdamW or not.",
+        },
+    )

     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
@@ -598,7 +604,7 @@ def main(args: FlatArguments):
                 device_map=device_map,
                 trust_remote_code=args.trust_remote_code,
                 torch_dtype=torch.bfloat16,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
                 revision=args.model_revision,
                 token=os.getenv("HF_TOKEN", None),
             )
@@ -609,7 +615,8 @@
                 config=config,
                 trust_remote_code=args.trust_remote_code,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
                 revision=args.model_revision,
                 token=os.getenv("HF_TOKEN", None),
             )
@@ -780,7 +787,7 @@ def main(args: FlatArguments):
             is_paged=True,
         )
     else:
-        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, fused=args.fused_optimizer)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
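
The two attention edits above replace the deprecated use_flash_attention_2 flag with the attn_implementation argument that newer transformers releases expect, and the new fused_optimizer field is threaded into the AdamW constructor. A minimal sketch of the optimizer side, with a toy module standing in for the finetuned model (illustrative only; the real script passes optimizer_grouped_parameters and its own args.learning_rate / args.fused_optimizer):

# Toy stand-in for the causal LM; the fused AdamW implementation targets CUDA
# tensors, so this sketch only enables it when a GPU is available.
import torch

model = torch.nn.Linear(16, 16)
use_fused = torch.cuda.is_available()
if use_fused:
    model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-06, fused=use_fused)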
scripts/submit_finetune_job.py (28 changes: 28 additions & 0 deletions)
@@ -170,6 +170,34 @@ def parse_args(args):
 d['description'] = exp_name
 d['tasks'][0]['name'] = exp_name

+# add cluster-specific env vars
+if args.cluster == "ai2/jupiter-cirrascale-2":
+    d['tasks'][0]['envVars'] += [
+        {
+            "name": "NCCL_SOCKET_IFNAME",
+            "value": "ib",
+        },
+        {
+            "name": "NCCL_IB_HCA",
+            "value": "^=mlx5_bond_0",
+        },
+        {
+            "name": "NCCL_DEBUG",
+            "value": "INFO",
+        },
+    ]
+elif args.cluster == "ai2/pluto-cirrascale":
+    d['tasks'][0]['envVars'] += [
+        {
+            "name": "NCCL_IB_HCA",
+            "value": "^=mlx5_1,mlx5_2",
+        },
+        {
+            "name": "NCCL_DEBUG",
+            "value": "INFO",
+        },
+    ]
+
 # WANDB settings
 for env in d['tasks'][0]['envVars']:
     if env['name'] == "WANDB_DISABLED":
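
A rough gloss on the NCCL variables injected above, based on NCCL's documented environment-variable conventions rather than anything stated in the commit: NCCL_SOCKET_IFNAME=ib restricts bootstrap/socket traffic to interfaces whose names start with "ib"; in NCCL_IB_HCA, a leading ^ excludes the listed HCAs and = forces exact-name matching, so ^=mlx5_bond_0 skips only that device on jupiter and ^=mlx5_1,mlx5_2 skips those two on pluto; NCCL_DEBUG=INFO turns on verbose NCCL logging. A sketch of exporting the jupiter values by hand, equivalent in effect to the injected envVars:

# Manually setting the jupiter-cirrascale-2 NCCL variables before a run; the
# values are copied from the diff above, and the comments reflect my reading
# of the NCCL docs rather than the commit itself.
import os

os.environ["NCCL_SOCKET_IFNAME"] = "ib"      # bootstrap traffic over ib* interfaces
os.environ["NCCL_IB_HCA"] = "^=mlx5_bond_0"  # "^" = exclude, "=" = exact match
os.environ["NCCL_DEBUG"] = "INFO"            # verbose NCCL logging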
