distributed init updates (addition of non-meta-tensor/rank0-broadcast path) #674

Open · wants to merge 1 commit into base: main
d2go/modeling/api.py (19 changes: 12 additions & 7 deletions)
@@ -63,7 +63,7 @@ def build_d2go_model(
and "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS
and hasattr(cfg, "FSDP")
and hasattr(cfg.FSDP, "DISTRIBUTED_INIT")
and cfg.FSDP.DISTRIBUTED_INIT
and cfg.FSDP.DISTRIBUTED_INIT.ENABLED
):
logger.info("Using distributed initialization path.")
import torch.distributed as dist
@@ -72,13 +72,18 @@ def build_d2go_model(
from d2go.trainer.fsdp import CpuOverrideMode
from torch._subclasses import FakeTensorMode

# NOTE (global) rank 0 will build the whole model on cpu
# other ranks will build the model on fake tensors
if dist.get_rank() == 0:
with CpuOverrideMode():
model = build_meta_arch(cfg)
if cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST:
# NOTE (global) rank 0 will build the whole model on cpu
# other ranks will build the model on fake tensors
if dist.get_rank() == 0:
with CpuOverrideMode():
model = build_meta_arch(cfg)
else:
with FakeTensorMode(allow_non_fake_inputs=True):
model = build_meta_arch(cfg)
else:
with FakeTensorMode(allow_non_fake_inputs=True):
# all ranks will build the model on cpu first
with CpuOverrideMode():
model = build_meta_arch(cfg)
else:
raise RuntimeError(
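For reference, the hunk above boils down to the following standalone sketch of the two initialization paths in build_d2go_model; the helper name build_model_for_distributed_init is invented for illustration, and build_meta_arch / CpuOverrideMode are passed in rather than imported (illustrative only, not part of the diff).

import torch.distributed as dist
from torch._subclasses import FakeTensorMode


def build_model_for_distributed_init(cfg, build_meta_arch, CpuOverrideMode):
    # Mirrors the branching added above in build_d2go_model.
    if cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST:
        # Rank 0 materializes real weights on CPU; every other rank builds a
        # fake-tensor shell and receives the weights later, via broadcast,
        # during FSDP wrapping.
        if dist.get_rank() == 0:
            with CpuOverrideMode():
                return build_meta_arch(cfg)
        with FakeTensorMode(allow_non_fake_inputs=True):
            return build_meta_arch(cfg)
    # Non-broadcast path: every rank builds the full model on CPU first.
    with CpuOverrideMode():
        return build_meta_arch(cfg)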
d2go/trainer/fsdp.py (118 changes: 77 additions & 41 deletions)
@@ -6,6 +6,7 @@
from typing import Any, Dict, Generator, List, Optional, Set, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.modeling.modeling_hook import ModelingHook
@@ -71,7 +72,12 @@ def add_fsdp_configs(_C: CN):
# if False, this allows the CPU thread to schedule all-gathers without any extra synchronization
_C.FSDP.LIMIT_ALL_GATHERS = False
# flag for distributed FSDP model initialization
_C.FSDP.DISTRIBUTED_INIT = False
_C.FSDP.DISTRIBUTED_INIT = CN()
_C.FSDP.DISTRIBUTED_INIT.ENABLED = False
# whether to build the full model on rank 0 (CPU) only
# and a fake-tensor model on all other ranks
_C.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST = False
_C.FSDP.DISTRIBUTED_INIT.VERBOSE = False


class ShardingAlgorithm(str, Enum):
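A minimal usage sketch for the new config block (not part of the diff), assuming add_fsdp_configs can be applied to a freshly created CfgNode:

from d2go.config import CfgNode as CN
from d2go.trainer.fsdp import add_fsdp_configs

cfg = CN()
add_fsdp_configs(cfg)  # registers the cfg.FSDP.* defaults, including DISTRIBUTED_INIT

# Opt in to distributed initialization using the rank0-broadcast path.
cfg.FSDP.DISTRIBUTED_INIT.ENABLED = True
cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST = True
cfg.FSDP.DISTRIBUTED_INIT.VERBOSE = False  # keep per-module init logging off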
@@ -111,9 +117,9 @@ def get_grad_scaler(cfg):
return ShardedGradScaler() if is_fsdp_enabled(cfg) else GradScaler()


def bottom_up_nested_fsdp(root_module, fsdp_kwargs: Dict[str, Any]):
import torch.distributed as dist

def bottom_up_nested_fsdp(
root_module, fsdp_kwargs: Dict[str, Any], rank0_broadcast: bool, verbose: bool
):
modules_to_fsdp: Tuple = tuple(fsdp_kwargs["auto_wrap_policy"]._module_classes)
del fsdp_kwargs["auto_wrap_policy"]
modules_not_to_fsdp: List = fsdp_kwargs["ignored_modules"]
@@ -129,6 +135,8 @@ def postorder_fsdp_wrap(
fqn: str,
parent_module: Optional[nn.Module],
ignore_branch: bool,
rank0_broadcast: bool,
verbose: bool,
):

rank = dist.get_rank()
@@ -144,58 +152,71 @@ postorder_fsdp_wrap(
f"{fqn}.{child_name}",
module,
ignore_branch,
rank0_broadcast,
verbose,
)

logger.info(
f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}"
)
if verbose:
logger.info(
f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}"
)
# regardless of wrapping, we need to transfer all
# module params and buffers to device, and if not rank 0,
# need to retrieve data from rank 0
with torch.no_grad():
if rank != 0:
with no_dispatch():
for name, param in module.named_parameters(recurse=False):
setattr(
module,
name,
torch.nn.Parameter(
torch.empty_like(param, device=cuda_device),
requires_grad=param.requires_grad,
),
)
for name, buffer in module.named_buffers(recurse=False):
setattr(
module,
name,
torch.empty_like(
buffer,
device=cuda_device,
requires_grad=buffer.requires_grad,
),
)
if rank0_broadcast:
if rank != 0:
with no_dispatch():
for name, param in module.named_parameters(recurse=False):
setattr(
module,
name,
torch.nn.Parameter(
torch.empty_like(param, device=cuda_device),
requires_grad=param.requires_grad,
),
)
for name, buffer in module.named_buffers(recurse=False):
setattr(
module,
name,
torch.empty_like(
buffer,
device=cuda_device,
requires_grad=buffer.requires_grad,
),
)
else:
for _, param in module.named_parameters(recurse=False):
param.data = param.to(cuda_device)
for _, buffer in module.named_buffers(recurse=False):
buffer.data = buffer.to(cuda_device)
for _, param in module.named_parameters(recurse=False):
dist.broadcast(param, 0)
for _, buffer in module.named_buffers(recurse=False):
dist.broadcast(buffer, 0)
else:
for _, param in module.named_parameters(recurse=False):
param.data = param.to(cuda_device)
for _, buffer in module.named_buffers(recurse=False):
buffer.data = buffer.to(cuda_device)
for _, param in module.named_parameters(recurse=False):
dist.broadcast(param, 0)
for _, buffer in module.named_buffers(recurse=False):
dist.broadcast(buffer, 0)

# if module is marked for FSDP, wrap it
# AND if not in ignored branch
if not ignore_branch and isinstance(module, modules_to_fsdp):
logger.info(
f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}"
)
if verbose:
logger.info(
f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}"
)
setattr(parent_module, module_name, FSDP(module, **fsdp_kwargs))
logger.info(
f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}"
)
if verbose:
logger.info(
f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}"
)

postorder_fsdp_wrap(root_module, "root", "root", None, False)
postorder_fsdp_wrap(
root_module, "root", "root", None, False, rank0_broadcast, verbose
)


class FSDPWrapper(FSDP):
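The core of the rank0-broadcast path above is: non-zero ranks replace their fake parameters with empty GPU storage, rank 0 moves its real CPU weights to the GPU, and every tensor is then broadcast from rank 0. A simplified per-module sketch of just that step (buffers, the no_dispatch() escape from FakeTensorMode, and the FSDP wrapping itself are omitted; the function name is invented):

import torch
import torch.distributed as dist
import torch.nn as nn


def materialize_and_broadcast(module: nn.Module, device: torch.device) -> None:
    # Simplified version of the per-module logic in postorder_fsdp_wrap.
    rank = dist.get_rank()
    with torch.no_grad():
        if rank != 0:
            # Allocate storage only; the values arrive from rank 0 below.
            for name, param in list(module.named_parameters(recurse=False)):
                setattr(
                    module,
                    name,
                    nn.Parameter(
                        torch.empty_like(param, device=device),
                        requires_grad=param.requires_grad,
                    ),
                )
        else:
            # Rank 0 moves its fully built CPU weights onto the target device.
            for _, param in module.named_parameters(recurse=False):
                param.data = param.to(device)
        # All ranks now hold same-shaped real tensors; sync values from rank 0.
        for _, param in module.named_parameters(recurse=False):
            dist.broadcast(param, 0)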
@@ -208,6 +229,8 @@ def __init__(
state_dict_cpu_offload: bool = True,
state_dict_rank0_only: bool = True,
distributed_init: bool = False,
distributed_init_rank0_broadcast: bool = False,
distributed_init_verbose: bool = False,
**fsdp_kwargs,
):
self.precision = amp_autocast_dtype
@@ -220,7 +243,14 @@
if self.distributed_init:
# NOTE traverse and apply all non-root level FSDP
# and then wrap root level FSDP
bottom_up_nested_fsdp(model, fsdp_kwargs)
logger.info(f"(Distributed FSDP init) Rank {dist.get_rank()} Beginning")
bottom_up_nested_fsdp(
model,
fsdp_kwargs,
rank0_broadcast=distributed_init_rank0_broadcast,
verbose=distributed_init_verbose,
)
logger.info(f"(Distributed FSDP init) Rank {dist.get_rank()} Finished")
super().__init__(model, **fsdp_kwargs)
logger.info(f"FSDP Wrapped model architecture: {self}")

@@ -289,6 +319,8 @@ def build_fsdp(
device_id: Optional[int] = None,
limit_all_gathers: bool = False,
distributed_init: bool = False,
distributed_init_rank0_broadcast: bool = False,
distributed_init_verbose: bool = False,
):
if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
@@ -365,6 +397,8 @@ def build_fsdp(
"state_dict_cpu_offload": state_dict_cpu_offload,
"state_dict_rank0_only": state_dict_rank0_only,
"distributed_init": distributed_init,
"distributed_init_rank0_broadcast": distributed_init_rank0_broadcast,
"distributed_init_verbose": distributed_init_verbose,
}

return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs)
@@ -422,7 +456,9 @@ def apply(self, model: nn.Module) -> FSDPWrapper:
use_orig_params=self.cfg.FSDP.USE_ORIG_PARAMS,
device_id=torch.cuda.current_device(),
limit_all_gathers=self.cfg.FSDP.LIMIT_ALL_GATHERS,
distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT,
distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT.ENABLED,
distributed_init_rank0_broadcast=self.cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST,
distributed_init_verbose=self.cfg.FSDP.DISTRIBUTED_INIT.VERBOSE,
)
return wrapped_model

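Putting the pieces together, a hypothetical end-to-end sketch of how the new flags reach the wrapper through the modeling hook; it assumes an already initialized process group (e.g. a torchrun launch), a mutable d2go config, and that FSDPModelingHook is constructed from the config like other d2go modeling hooks:

import torch.distributed as dist
from d2go.trainer.fsdp import FSDPModelingHook


def wrap_model_with_fsdp(cfg, model):
    # Distributed init only makes sense once the process group exists.
    assert dist.is_initialized()
    cfg.defrost()  # ensure the config is mutable before toggling the new flags
    cfg.FSDP.DISTRIBUTED_INIT.ENABLED = True
    cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST = True
    cfg.FSDP.DISTRIBUTED_INIT.VERBOSE = True  # log per-module progress on each rank
    hook = FSDPModelingHook(cfg)
    return hook.apply(model)  # returns an FSDPWrapper around `model`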