From 4f729c298cddc17dfe06f8a36561e8f66718e734 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Wed, 30 Oct 2024 22:35:11 -0700
Subject: [PATCH] Revert "[BE] replace the extra DeviceMesh _flatten with mesh access (#666)"

This reverts commit 3653bf290c62e312f31ddc89df75bc18d0e163ad.
---
 torchtitan/parallelisms/parallelize_llama.py | 42 ++++++++------------
 torchtitan/parallelisms/utils.py             | 28 -------
 2 files changed, 16 insertions(+), 54 deletions(-)
 delete mode 100644 torchtitan/parallelisms/utils.py

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 2a66f472..84666031 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -34,7 +34,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
 
 
 def parallelize_llama(
@@ -80,31 +79,22 @@ def parallelize_llama(
     if (
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-        try:
-            dp_mesh = (
-                world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
-            )
-        except IndexError:
-            # note: this is a workaround of the above logic for old pytorch version
-            # where https://github.com/pytorch/pytorch/pull/138945 is not included
-            # throw a warning to encourage users to upgrade to a newer pytorch version
-            check_if_feature_in_pytorch(
-                "DeviceMesh flattening over 3D+ meshes",
-                "https://github.com/pytorch/pytorch/pull/138945",
-                "2.6.0.dev20241030",
-            )
-            # TODO: remove this workaround once PyTorch 2.6 is released
-            dp_mesh_dim_names = (
-                ("dp_replicate", "dp_shard")
-                if parallel_dims.dp_replicate_enabled
-                else ("dp",)
-            )
-            # note that mesh can only be flattened from the finest-grained mesh dimensions
-            dp_mesh = (
-                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-                if parallel_dims.cp_enabled
-                else world_mesh[dp_mesh_dim_names]
-            )
+
+        # TODO: instead of flattening the mesh twice, we could've done it in a better way:
+        #       dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+        #       However, this leads to an error in `DeviceMesh.__getitem__` which I believe is
+        #       a bug in DeviceMesh. We should fix it and then use the above line.
+        dp_mesh_dim_names = (
+            ("dp_replicate", "dp_shard")
+            if parallel_dims.dp_replicate_enabled
+            else ("dp",)
+        )
+        # note that mesh can only be flattened from the finest-grained mesh dimensions
+        dp_mesh = (
+            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+            if parallel_dims.cp_enabled
+            else world_mesh[dp_mesh_dim_names]
+        )
 
         apply_fsdp(
             model,
diff --git a/torchtitan/parallelisms/utils.py b/torchtitan/parallelisms/utils.py
deleted file mode 100644
index a84af798..00000000
--- a/torchtitan/parallelisms/utils.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-from typing import Optional
-
-import torch
-from torchtitan.logging import logger
-
-
-def check_if_feature_in_pytorch(
-    feature_name: str,
-    pull_request: str,
-    min_nightly_version: Optional[str] = None,
-) -> None:
-    if "git" in torch.__version__:  # pytorch is built from source
-        # notify users to check if the pull request is included in their pytorch
-        logger.warning(
-            "detected that the pytorch is built from source. Please make sure the PR "
-            f"({pull_request}) is included in pytorch for correct {feature_name}."
-        )
-    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
-        logger.warning(
-            f"detected that the pytorch version {torch.__version__} is older than "
-            f"{min_nightly_version}. Please upgrade to a newer version to include the "
-            f"change in ({pull_request}) for correct {feature_name}."
-        )
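
For reviewers, a minimal standalone sketch of the mesh flattening that the
reinstated branch performs. The 8-rank 2 x 2 x 2 shape and the torchrun launch
are illustrative assumptions, not part of this patch, and `_flatten` is a
private DeviceMesh API that requires a recent PyTorch nightly (see the PR
cited in the reverted fallback above):

    # sketch.py -- assumed launch: torchrun --nproc_per_node 8 sketch.py
    from torch.distributed.device_mesh import init_device_mesh

    # 2 (dp_replicate) x 2 (dp_shard) x 2 (cp); the shape is illustrative
    world_mesh = init_device_mesh(
        "cuda",
        (2, 2, 2),
        mesh_dim_names=("dp_replicate", "dp_shard", "cp"),
    )

    # a mesh can only be flattened from its finest-grained dims, so slice out
    # all three and collapse them into a single "dp_cp" dim for FSDP to use
    dp_mesh = world_mesh[("dp_replicate", "dp_shard", "cp")]._flatten("dp_cp")
    print(dp_mesh.size())  # 8: every rank belongs to one flat dp_cp group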
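
And a hypothetical call site for the restored helper, reusing the exact
arguments from the deleted fallback path. It compares `torch.__version__`
against the nightly string and only logs a warning; it never raises:

    from torchtitan.parallelisms.utils import check_if_feature_in_pytorch

    check_if_feature_in_pytorch(
        "DeviceMesh flattening over 3D+ meshes",
        "https://github.com/pytorch/pytorch/pull/138945",
        min_nightly_version="2.6.0.dev20241030",
    )
    # e.g. on torch 2.5.0 this logs a warning that 2.5.0 is older than
    # 2.6.0.dev20241030 and points at the PR above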