diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index a09703d7..2a66f472 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -34,7 +34,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_strided_sharding_enabled
+from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
 
 
 def parallelize_llama(
@@ -80,8 +80,31 @@
     if (
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-
-        dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+        try:
+            dp_mesh = (
+                world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+            )
+        except IndexError:
+            # note: this is a workaround of the above logic for old pytorch version
+            # where https://github.com/pytorch/pytorch/pull/138945 is not included
+            # throw a warning to encourage users to upgrade to a newer pytorch version
+            check_if_feature_in_pytorch(
+                "DeviceMesh flattening over 3D+ meshes",
+                "https://github.com/pytorch/pytorch/pull/138945",
+                "2.6.0.dev20241030",
+            )
+            # TODO: remove this workaround once PyTorch 2.6 is released
+            dp_mesh_dim_names = (
+                ("dp_replicate", "dp_shard")
+                if parallel_dims.dp_replicate_enabled
+                else ("dp",)
+            )
+            # note that mesh can only be flattened from the finest-grained mesh dimensions
+            dp_mesh = (
+                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+                if parallel_dims.cp_enabled
+                else world_mesh[dp_mesh_dim_names]
+            )
 
         apply_fsdp(
             model,
@@ -316,12 +339,6 @@
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
-    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
-    # that users won't use a nightly build which is older than 20240809 by then.
-    if tp_enabled:
-        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-        check_strided_sharding_enabled()
-
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
diff --git a/torchtitan/parallelisms/utils.py b/torchtitan/parallelisms/utils.py
index a82ace7a..a84af798 100644
--- a/torchtitan/parallelisms/utils.py
+++ b/torchtitan/parallelisms/utils.py
@@ -3,28 +3,26 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
+from typing import Optional
 
+import torch
 from torchtitan.logging import logger
 
 
-def check_strided_sharding_enabled() -> None:
-    # Correct 2D/3D DCP usage requires DTensor's strided sharding in PR
-    # https://github.com/pytorch/pytorch/pull/130760. This function checks if users'
-    # PyTorch nightly-build version is newer than 2024-08-09 to make sure this PR is
-    # included when 2D/3D DCP is used.
+def check_if_feature_in_pytorch(
+    feature_name: str,
+    pull_request: str,
+    min_nightly_version: Optional[str] = None,
+) -> None:
+    # Warn (never raise) when the running PyTorch may lack the PR that
+    # `feature_name` depends on. `pull_request` is the PR URL shown to users;
+    # `min_nightly_version` is the first nightly known to include it.
     if "git" in torch.__version__:  # pytorch is built from source
-        # notify users to check if the commit hash is newer than 2024-08-09
+        # notify users to check if the pull request is included in their pytorch
         logger.warning(
             "detected that the pytorch is built from source. Please make sure the PR "
-            "(https://github.com/pytorch/pytorch/pull/130760) is included in pytorch "
-            "for correct 2D/3D DCP usage."
+            f"({pull_request}) is included in pytorch for correct {feature_name}."
         )
-    elif torch.__version__ < "2.5.0.dev20240809":
-        # the nightly build pytorch was built before 2024-08-09
+    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
         logger.warning(
             f"detected that the pytorch version {torch.__version__} is older than "
-            "2.5.0.dev20240809. Please upgrade a newer version to include the change "
-            "made in https://github.com/pytorch/pytorch/pull/130760 for correct 2D/3D "
-            "DCP usage."
+            f"{min_nightly_version}. Please upgrade to a newer version to include the "
+            f"change in ({pull_request}) for correct {feature_name}."
         )