diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 84666031..2a66f472 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -34,6 +34,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
 
 
 def parallelize_llama(
@@ -79,22 +80,31 @@ def parallelize_llama(
     if (
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-
-        # TODO: instead of flattening the mesh twice, we could've done in a batter way:
-        # dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
-        # However, this leads to an error in `DeviceMesh.__get_item__` which I believe is
-        # a bug in DeviceMesh. We should fix it and then use the above line.
-        dp_mesh_dim_names = (
-            ("dp_replicate", "dp_shard")
-            if parallel_dims.dp_replicate_enabled
-            else ("dp",)
-        )
-        # note that mesh can only be flattened from the finest-grained mesh dimensions
-        dp_mesh = (
-            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-            if parallel_dims.cp_enabled
-            else world_mesh[dp_mesh_dim_names]
-        )
+        try:
+            dp_mesh = (
+                world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+            )
+        except IndexError:
+            # note: this is a workaround of the above logic for older PyTorch versions
+            # where https://github.com/pytorch/pytorch/pull/138945 is not included;
+            # throw a warning to encourage users to upgrade to a newer PyTorch version
+            check_if_feature_in_pytorch(
+                "DeviceMesh flattening over 3D+ meshes",
+                "https://github.com/pytorch/pytorch/pull/138945",
+                "2.6.0.dev20241030",
+            )
+            # TODO: remove this workaround once PyTorch 2.6 is released
+            dp_mesh_dim_names = (
+                ("dp_replicate", "dp_shard")
+                if parallel_dims.dp_replicate_enabled
+                else ("dp",)
+            )
+            # note that mesh can only be flattened from the finest-grained mesh dimensions
+            dp_mesh = (
+                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+                if parallel_dims.cp_enabled
+                else world_mesh[dp_mesh_dim_names]
+            )
 
         apply_fsdp(
             model,
diff --git a/torchtitan/parallelisms/utils.py b/torchtitan/parallelisms/utils.py
new file mode 100644
index 00000000..a84af798
--- /dev/null
+++ b/torchtitan/parallelisms/utils.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Optional
+
+import torch
+from torchtitan.logging import logger
+
+
+def check_if_feature_in_pytorch(
+    feature_name: str,
+    pull_request: str,
+    min_nightly_version: Optional[str] = None,
+) -> None:
+    if "git" in torch.__version__:  # pytorch is built from source
+        # notify users to check if the pull request is included in their PyTorch build
+        logger.warning(
+            "Detected that PyTorch is built from source. Please make sure the PR "
+            f"({pull_request}) is included in PyTorch for correct {feature_name}."
+        )
+    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
+        logger.warning(
+            f"Detected that the PyTorch version {torch.__version__} is older than "
+            f"{min_nightly_version}. Please upgrade to a newer version that includes "
+            f"the change in ({pull_request}) for correct {feature_name}."
+        )
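
Below is a minimal usage sketch (not part of the patch) of the new helper; it mirrors the call made in the except-branch above and assumes the torchtitan.parallelisms.utils module path introduced by this patch:

    from torchtitan.parallelisms.utils import check_if_feature_in_pytorch

    # Warn (without raising) if the running PyTorch build may predate the
    # DeviceMesh flattening change from https://github.com/pytorch/pytorch/pull/138945.
    check_if_feature_in_pytorch(
        "DeviceMesh flattening over 3D+ meshes",
        "https://github.com/pytorch/pytorch/pull/138945",
        "2.6.0.dev20241030",
    )

The helper only logs a warning for source builds or for versions older than the given nightly; it never raises, so the except-branch still falls back to the manual mesh flattening.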