Skip to content

Commit

Permalink
Update on "[BE] replace the extra DeviceMesh _flatten with mesh access"
Browse files Browse the repository at this point in the history
**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access on a flattened mesh that is constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of directly accessing the flattened mesh. Since the fix has been merged in PyTorch, we want to switch back to direct mesh access, which is more straightforward.


[ghstack-poisoned]
  • Loading branch information
XilunWu committed Oct 31, 2024
2 parents 47356dc + 03d27ce commit e35e17d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
35 changes: 26 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.utils import check_strided_sharding_enabled
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch


def parallelize_llama(
Expand Down Expand Up @@ -80,8 +80,31 @@ def parallelize_llama(
if (
parallel_dims.dp_shard_enabled
): # apply FSDP or HSDP, potentially with Context Parallel

dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
try:
dp_mesh = (
world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
)
except IndexError:
# note: this is a workaround of the above logic for old pytorch version
# where https://github.com/pytorch/pytorch/pull/138945 is not included
# throw a warning to encourage users to upgrade to a newer pytorch version
check_if_feature_in_pytorch(
"DeviceMesh flattening over 3D+ meshes",
"https://github.com/pytorch/pytorch/pull/138945",
"2.6.0.dev20241030",
)
# TODO: remove this workaround once PyTorch 2.6 is released
dp_mesh_dim_names = (
("dp_replicate", "dp_shard")
if parallel_dims.dp_replicate_enabled
else ("dp",)
)
# note that mesh can only be flattened from the finest-grained mesh dimensions
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)

apply_fsdp(
model,
Expand Down Expand Up @@ -316,12 +339,6 @@ def apply_fsdp(
if cpu_offload:
fsdp_config["offload_policy"] = CPUOffloadPolicy()

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
if tp_enabled:
# check if strided sharding is enabled, which is necessary for 2D/3D DCP
check_strided_sharding_enabled()

for layer_id, transformer_block in model.layers.items():
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
Expand Down
26 changes: 12 additions & 14 deletions torchtitan/parallelisms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,26 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import Optional

import torch
from torchtitan.logging import logger


def check_strided_sharding_enabled() -> None:
# Correct 2D/3D DCP usage requires DTensor's strided sharding in PR
# https://github.com/pytorch/pytorch/pull/130760. This function checks if users'
# PyTorch nightly-build version is newer than 2024-08-09 to make sure this PR is
# included when 2D/3D DCP is used.
def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    """Warn if the installed PyTorch may not include a required upstream change.

    Args:
        feature_name: Human-readable name of the feature, used in warning messages.
        pull_request: Link to the upstream PyTorch pull request that introduced it.
        min_nightly_version: Earliest nightly version string known to contain the
            change. If ``None``, no version comparison is performed for nightly/
            release builds.

    Returns:
        None. This function only emits warnings; it never raises.
    """
    if "git" in torch.__version__:  # pytorch is built from source
        # A source build carries no comparable version number, so we can only
        # ask the user to verify that the pull request is included.
        # BUGFIX: the original referenced an undefined name `pull_request_link`,
        # which raised NameError whenever this branch was taken.
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in pytorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        # NOTE(review): this is a lexicographic string comparison, which matches
        # the original behavior but can misorder versions like "2.10" vs "2.6".
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )

0 comments on commit e35e17d

Please sign in to comment.