diff --git a/d2go/modeling/api.py b/d2go/modeling/api.py index 1d7e1c03..8124fe5e 100644 --- a/d2go/modeling/api.py +++ b/d2go/modeling/api.py @@ -63,7 +63,7 @@ def build_d2go_model( and "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS and hasattr(cfg, "FSDP") and hasattr(cfg.FSDP, "DISTRIBUTED_INIT") - and cfg.FSDP.DISTRIBUTED_INIT + and cfg.FSDP.DISTRIBUTED_INIT.ENABLED ): logger.info("Using distributed initialization path.") import torch.distributed as dist @@ -72,13 +72,18 @@ def build_d2go_model( from d2go.trainer.fsdp import CpuOverrideMode from torch._subclasses import FakeTensorMode - # NOTE (global) rank 0 will build the whole model on cpu - # other ranks will build the model on fake tensors - if dist.get_rank() == 0: - with CpuOverrideMode(): - model = build_meta_arch(cfg) + if cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST: + # NOTE (global) rank 0 will build the whole model on cpu + # other ranks will build the model on fake tensors + if dist.get_rank() == 0: + with CpuOverrideMode(): + model = build_meta_arch(cfg) + else: + with FakeTensorMode(allow_non_fake_inputs=True): + model = build_meta_arch(cfg) else: - with FakeTensorMode(allow_non_fake_inputs=True): + # all ranks will build the model on cpu first + with CpuOverrideMode(): model = build_meta_arch(cfg) else: raise RuntimeError( diff --git a/d2go/trainer/fsdp.py b/d2go/trainer/fsdp.py index d3a46e8e..e311b0b1 100644 --- a/d2go/trainer/fsdp.py +++ b/d2go/trainer/fsdp.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional, Set, Tuple import torch +import torch.distributed as dist import torch.nn as nn from d2go.config import CfgNode as CN from d2go.modeling.modeling_hook import ModelingHook @@ -71,7 +72,12 @@ def add_fsdp_configs(_C: CN): # if False, this allows the CPU thread to schedule all-gathers without any extra synchronization _C.FSDP.LIMIT_ALL_GATHERS = False # flag for distributed FSDP model initialization - _C.FSDP.DISTRIBUTED_INIT = False + _C.FSDP.DISTRIBUTED_INIT = CN() + _C.FSDP.DISTRIBUTED_INIT.ENABLED = False + # whether to build full model on rank 0 cpu only + # and meta model on all other ranks + _C.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST = False + _C.FSDP.DISTRIBUTED_INIT.VERBOSE = False class ShardingAlgorithm(str, Enum): @@ -111,9 +117,9 @@ def get_grad_scaler(cfg): return ShardedGradScaler() if is_fsdp_enabled(cfg) else GradScaler() -def bottom_up_nested_fsdp(root_module, fsdp_kwargs: Dict[str, Any]): - import torch.distributed as dist - +def bottom_up_nested_fsdp( + root_module, fsdp_kwargs: Dict[str, Any], rank0_broadcast: bool, verbose: bool +): modules_to_fsdp: Tuple = tuple(fsdp_kwargs["auto_wrap_policy"]._module_classes) del fsdp_kwargs["auto_wrap_policy"] modules_not_to_fsdp: List = fsdp_kwargs["ignored_modules"] @@ -129,6 +135,8 @@ def postorder_fsdp_wrap( fqn: str, parent_module: Optional[nn.Module], ignore_branch: bool, + rank0_broadcast: bool, + verbose: bool, ): rank = dist.get_rank() @@ -144,58 +152,71 @@ def postorder_fsdp_wrap( f"{fqn}.{child_name}", module, ignore_branch, + rank0_broadcast, + verbose, ) - logger.info( - f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}" - ) + if verbose: + logger.info( + f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}" + ) # regardless of wrapping, we need to transfer all # module params and buffers to device, and if not rank 0, # need to retreive data from rank 0 with torch.no_grad(): - if rank != 0: - with no_dispatch(): - for name, param in module.named_parameters(recurse=False): - setattr( - module, - name, - torch.nn.Parameter( - torch.empty_like(param, device=cuda_device), - requires_grad=param.requires_grad, - ), - ) - for name, buffer in module.named_buffers(recurse=False): - setattr( - module, - name, - torch.empty_like( - buffer, - device=cuda_device, - requires_grad=buffer.requires_grad, - ), - ) + if rank0_broadcast: + if rank != 0: + with no_dispatch(): + for name, param in module.named_parameters(recurse=False): + setattr( + module, + name, + torch.nn.Parameter( + torch.empty_like(param, device=cuda_device), + requires_grad=param.requires_grad, + ), + ) + for name, buffer in module.named_buffers(recurse=False): + setattr( + module, + name, + torch.empty_like( + buffer, + device=cuda_device, + requires_grad=buffer.requires_grad, + ), + ) + else: + for _, param in module.named_parameters(recurse=False): + param.data = param.to(cuda_device) + for _, buffer in module.named_buffers(recurse=False): + buffer.data = buffer.to(cuda_device) + for _, param in module.named_parameters(recurse=False): + dist.broadcast(param, 0) + for _, buffer in module.named_buffers(recurse=False): + dist.broadcast(buffer, 0) else: for _, param in module.named_parameters(recurse=False): param.data = param.to(cuda_device) for _, buffer in module.named_buffers(recurse=False): buffer.data = buffer.to(cuda_device) - for _, param in module.named_parameters(recurse=False): - dist.broadcast(param, 0) - for _, buffer in module.named_buffers(recurse=False): - dist.broadcast(buffer, 0) # if module is marked for FSDP, wrap it # AND if not in ignored branch if not ignore_branch and isinstance(module, modules_to_fsdp): - logger.info( - f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}" - ) + if verbose: + logger.info( + f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}" + ) setattr(parent_module, module_name, FSDP(module, **fsdp_kwargs)) - logger.info( - f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}" - ) + if verbose: + logger.info( + f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}" + ) - postorder_fsdp_wrap(root_module, "root", "root", None, False) + postorder_fsdp_wrap( + root_module, "root", "root", None, False, rank0_broadcast, verbose + ) class FSDPWrapper(FSDP): @@ -208,6 +229,8 @@ def __init__( state_dict_cpu_offload: bool = True, state_dict_rank0_only: bool = True, distributed_init: bool = False, + distributed_init_rank0_broadcast: bool = False, + distributed_init_verbose: bool = False, **fsdp_kwargs, ): self.precision = amp_autocast_dtype @@ -220,7 +243,14 @@ def __init__( if self.distributed_init: # NOTE traverse and apply all non-root level FSDP # and then wrap root level FSDP - bottom_up_nested_fsdp(model, fsdp_kwargs) + logger.info(f"(Distributed FSDP init) Rank {dist.get_rank()} Beginning") + bottom_up_nested_fsdp( + model, + fsdp_kwargs, + rank0_broadcast=distributed_init_rank0_broadcast, + verbose=distributed_init_verbose, + ) + logger.info(f"(Distributed FSDP init) Rank {dist.get_rank()} Finished") super().__init__(model, **fsdp_kwargs) logger.info(f"FSDP Wrapped model architecture: {self}") @@ -289,6 +319,8 @@ def build_fsdp( device_id: Optional[int] = None, limit_all_gathers: bool = False, distributed_init: bool = False, + distributed_init_rank0_broadcast: bool = False, + distributed_init_verbose: bool = False, ): if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP: sharding_strategy = ShardingStrategy.SHARD_GRAD_OP @@ -365,6 +397,8 @@ def build_fsdp( "state_dict_cpu_offload": state_dict_cpu_offload, "state_dict_rank0_only": state_dict_rank0_only, "distributed_init": distributed_init, + "distributed_init_rank0_broadcast": distributed_init_rank0_broadcast, + "distributed_init_verbose": distributed_init_verbose, } return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs) @@ -422,7 +456,9 @@ def apply(self, model: nn.Module) -> FSDPWrapper: use_orig_params=self.cfg.FSDP.USE_ORIG_PARAMS, device_id=torch.cuda.current_device(), limit_all_gathers=self.cfg.FSDP.LIMIT_ALL_GATHERS, - distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT, + distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT.ENABLED, + distributed_init_rank0_broadcast=self.cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST, + distributed_init_verbose=self.cfg.FSDP.DISTRIBUTED_INIT.VERBOSE, ) return wrapped_model