
Commit

Uriel's task_ddp_backend setting (#678)
* fsdp wrap task

* bf16 cfg fix

* split out TASK_DDP_BACKEND_CHOICES

* add one more comment

* remove extra is_moe arg
suchenzang authored Mar 13, 2023
1 parent 79785e4 commit 0714529
Showing 4 changed files with 41 additions and 9 deletions.
33 changes: 26 additions & 7 deletions metaseq/cli/train.py
@@ -113,23 +113,42 @@ def main(cfg: DictConfig) -> None:
     logger.info(cfg)
 
     # Setup task, e.g., translation, language modeling, etc.
-    task = tasks.setup_task(cfg.task)
+    if cfg.distributed_training.task_ddp_backend == "fully_sharded":
+        # As the task is non-trainable, we switch flags to more optimized ones.
+        # See https://github.com/facebookresearch/metaseq/pull/668 for when/why this was added.
+        orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16
+        orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter
+        # Clobber memory_efficient_fp16 and fp32_reduce_scatter
+        cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16
+        cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16
+
+        with fsdp_enable_wrap(
+            cfg.distributed_training,
+            use_sharded_state=cfg.distributed_training.use_sharded_state,
+        ):
+            task = tasks.setup_task(cfg.task)
+
+        # Reset memory_efficient_fp16 and fp32_reduce_scatter values.
+        cfg.distributed_training.memory_efficient_fp16 = orig_memory_efficient_fp16
+        cfg.distributed_training.fp32_reduce_scatter = orig_fp32_reduce_scatter
+    else:
+        task = tasks.setup_task(cfg.task)
 
-    # Build model and criterion
     assert cfg.criterion, "Please specify criterion to train a model"
 
+    # Build model and criterion
     if cfg.distributed_training.ddp_backend == "fully_sharded":
-        extra = {
-            "use_sharded_state": cfg.distributed_training.use_sharded_state,
-        }
-
-        with fsdp_enable_wrap(cfg.distributed_training, **extra):
+        with fsdp_enable_wrap(
+            cfg.distributed_training,
+            use_sharded_state=cfg.distributed_training.use_sharded_state,
+        ):
             model = fsdp_wrap(
                 task.build_model(cfg.model),
                 process_group=distributed_utils.get_data_parallel_group(),
             )
     else:
         model = task.build_model(cfg.model)
 
+    # TODO[Susan]: FSDP on criterion?
     criterion = task.build_criterion(cfg.criterion)
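Note: the new task_ddp_backend == "fully_sharded" branch saves the user's memory_efficient_fp16 and fp32_reduce_scatter values, clobbers them with more aggressive settings while the (non-trainable) task is built under fsdp_enable_wrap, then restores the originals. A minimal, self-contained sketch of that save/clobber/restore pattern, using a hypothetical SimpleNamespace config rather than metaseq's real dataclasses:

    from contextlib import contextmanager
    from types import SimpleNamespace

    # Hypothetical stand-in for cfg.distributed_training; not metaseq's dataclass.
    dist_cfg = SimpleNamespace(
        fp16=True,
        memory_efficient_fp16=False,
        fp32_reduce_scatter=True,
    )

    @contextmanager
    def clobbered_task_flags(cfg):
        """Temporarily switch to the flags used while building a non-trainable
        task, then restore whatever the user originally configured."""
        orig_me_fp16 = cfg.memory_efficient_fp16
        orig_fp32_rs = cfg.fp32_reduce_scatter
        cfg.memory_efficient_fp16 = cfg.fp16          # same clobber as the diff
        cfg.fp32_reduce_scatter = not cfg.fp16
        try:
            yield cfg
        finally:
            cfg.memory_efficient_fp16 = orig_me_fp16  # restore originals
            cfg.fp32_reduce_scatter = orig_fp32_rs

    with clobbered_task_flags(dist_cfg):
        # task = tasks.setup_task(cfg.task) would run here under fsdp_enable_wrap
        print(dist_cfg.memory_efficient_fp16, dist_cfg.fp32_reduce_scatter)  # True False
    print(dist_cfg.memory_efficient_fp16, dist_cfg.fp32_reduce_scatter)      # False True

A context manager restores the flags even if task setup raises; the diff restores them inline after the with block instead, which keeps the change local to train.py.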
6 changes: 6 additions & 0 deletions metaseq/dataclass/configs.py
@@ -13,6 +13,7 @@
 from metaseq.dataclass.constants import (
     DATASET_IMPL_CHOICES,
     DDP_BACKEND_CHOICES,
+    TASK_DDP_BACKEND_CHOICES,
     LOG_FORMAT_CHOICES,
     CLIP_GRAD_NORM_TYPE_CHOICES,
 )
@@ -247,6 +248,11 @@ class DistributedTrainingConfig(MetaseqDataclass):
     ddp_backend: DDP_BACKEND_CHOICES = field(
         default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
     )
+    # Reference: https://github.com/facebookresearch/metaseq/pull/668
+    task_ddp_backend: TASK_DDP_BACKEND_CHOICES = field(
+        default="none",
+        metadata={"help": "If set to fully_sharded, will fsdp wrap task."},
+    )
     bucket_cap_mb: int = field(
         default=25, metadata={"help": "bucket size for reduction"}
     )
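Note: the new field sits in the Hydra/OmegaConf config tree, so it can be flipped with a dotted override. A sketch using a stripped-down config namespace (illustrative only, not the real DistributedTrainingConfig dataclass):

    from omegaconf import OmegaConf

    # Tiny stand-in mirroring only the two backend fields touched by this commit.
    base = OmegaConf.create(
        {
            "distributed_training": {
                "ddp_backend": "fully_sharded",
                "task_ddp_backend": "none",  # new default
            }
        }
    )

    # Equivalent of passing distributed_training.task_ddp_backend=fully_sharded
    # as a command-line override.
    override = OmegaConf.from_dotlist(
        ["distributed_training.task_ddp_backend=fully_sharded"]
    )
    cfg = OmegaConf.merge(base, override)

    assert cfg.distributed_training.task_ddp_backend == "fully_sharded"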
7 changes: 7 additions & 0 deletions metaseq/dataclass/constants.py
@@ -42,6 +42,13 @@ def ChoiceEnum(choices: List[str]):
         "pytorch_ddp",
     ]
 )
+
+TASK_DDP_BACKEND_CHOICES = ChoiceEnum(
+    [
+        "none",  # default
+        "fully_sharded",  # FSDP wraps task. See https://github.com/facebookresearch/metaseq/pull/668/
+    ]
+)
 DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"])
 CLIP_GRAD_NORM_TYPE_CHOICES = ChoiceEnum(["l2", "inf"])
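Note: ChoiceEnum turns a list of strings into an Enum so invalid values fail fast when the config is parsed. A simplified stand-in (not metaseq's actual ChoiceEnum implementation) shows the effect of the new constant:

    from enum import Enum
    from typing import List

    def choice_enum(choices: List[str]):
        """Simplified stand-in for metaseq's ChoiceEnum helper."""
        return Enum("Choices", {c: c for c in choices})

    TASK_DDP_BACKEND = choice_enum(["none", "fully_sharded"])

    print(TASK_DDP_BACKEND["fully_sharded"].value)  # fully_sharded
    print("none" in TASK_DDP_BACKEND.__members__)   # True

    try:
        TASK_DDP_BACKEND["pytorch_ddp"]             # not a valid task backend
    except KeyError:
        print("rejected")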
4 changes: 2 additions & 2 deletions metaseq/trainer.py
@@ -254,7 +254,6 @@ def _build_ema(self):
         if self.is_fsdp:
             # Build FSDP model
             extra = {
-                "is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0,
                 "use_sharded_state": self.use_sharded_state,
             }
             with fsdp_enable_wrap(self.cfg.distributed_training, **extra):
@@ -1212,11 +1211,12 @@ def _prepare_sample(self, sample, is_dummy=False):
         def lower_precision(t):
             """Converts a tensor to the desired dtype based on our cfg."""
             if t.dtype is torch.float32:
-                if self.cfg.common.bf16 or self.cfg.bf16:
+                if self.cfg.common.bf16:
                     return t.bfloat16()
                 return t.half()
             return t
 
+        # TODO[Susan]: sample dict is full of int64 tensors - check this.
        if self.cfg.common.fp16:
             sample = utils.apply_to_sample(lower_precision, sample)
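Note: the second hunk fixes the bf16 check to read the flag from cfg.common only, and the new TODO flags that the downcast only ever touches float32 tensors, so int64 token tensors in a sample pass through unchanged. A small self-contained sketch of that behavior; apply_to_tensors here is a simplified stand-in for metaseq's utils.apply_to_sample, not the library function itself:

    import torch

    def lower_precision(t, use_bf16):
        """Downcast only float32 tensors; leave integer tensors (token ids) alone."""
        if t.dtype is torch.float32:
            return t.bfloat16() if use_bf16 else t.half()
        return t

    def apply_to_tensors(fn, obj):
        """Simplified stand-in for utils.apply_to_sample: walk nested containers."""
        if torch.is_tensor(obj):
            return fn(obj)
        if isinstance(obj, dict):
            return {k: apply_to_tensors(fn, v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(apply_to_tensors(fn, v) for v in obj)
        return obj

    sample = {
        "src_tokens": torch.ones(2, 4, dtype=torch.int64),  # untouched
        "scores": torch.rand(2, 4),                          # float32 -> bf16
    }
    out = apply_to_tensors(lambda t: lower_precision(t, use_bf16=True), sample)
    print(out["src_tokens"].dtype, out["scores"].dtype)  # torch.int64 torch.bfloat16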
