
Commit

Revert "[BE] replace the extra DeviceMesh _flatten with mesh access (#…
Browse files Browse the repository at this point in the history
…666)"

This reverts commit 3653bf2.
XilunWu authored Oct 31, 2024
1 parent 3653bf2 commit 4f729c2
Showing 2 changed files with 16 additions and 54 deletions.
42 changes: 16 additions & 26 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -34,7 +34,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_if_feature_in_pytorch


def parallelize_llama(
@@ -80,31 +79,22 @@ def parallelize_llama(
     if (
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-        try:
-            dp_mesh = (
-                world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
-            )
-        except IndexError:
-            # note: this is a workaround of the above logic for old pytorch version
-            # where https://github.com/pytorch/pytorch/pull/138945 is not included
-            # throw a warning to encourage users to upgrade to a newer pytorch version
-            check_if_feature_in_pytorch(
-                "DeviceMesh flattening over 3D+ meshes",
-                "https://github.com/pytorch/pytorch/pull/138945",
-                "2.6.0.dev20241030",
-            )
-            # TODO: remove this workaround once PyTorch 2.6 is released
-            dp_mesh_dim_names = (
-                ("dp_replicate", "dp_shard")
-                if parallel_dims.dp_replicate_enabled
-                else ("dp",)
-            )
-            # note that mesh can only be flattened from the finest-grained mesh dimensions
-            dp_mesh = (
-                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-                if parallel_dims.cp_enabled
-                else world_mesh[dp_mesh_dim_names]
-            )

+        # TODO: instead of flattening the mesh twice, we could've done it in a better way:
+        #   dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+        # However, this leads to an error in `DeviceMesh.__getitem__`, which I believe is
+        # a bug in DeviceMesh. We should fix it and then use the above line.
+        dp_mesh_dim_names = (
+            ("dp_replicate", "dp_shard")
+            if parallel_dims.dp_replicate_enabled
+            else ("dp",)
+        )
+        # note that the mesh can only be flattened from the finest-grained mesh dimensions
+        dp_mesh = (
+            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+            if parallel_dims.cp_enabled
+            else world_mesh[dp_mesh_dim_names]
+        )

         apply_fsdp(
             model,
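
For context (not part of this commit), the sketch below shows the two ways of obtaining the flattened data-parallel mesh that this revert switches between: the restored `_flatten` workaround versus the direct `world_mesh["dp_cp"]` access that needs pytorch/pytorch#138945. The 2 x 2 x 2 mesh shape and the torchrun launch are illustrative assumptions, and `DeviceMesh._flatten` is a private API.

# Minimal sketch, assuming an 8-rank job launched with `torchrun --nproc_per_node 8 ...`
# and CUDA devices available; the mesh shape below is purely illustrative.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "cp"),
)

# Restored workaround: slice out the data-parallel dims plus "cp" and flatten them
# into a single "dp_cp" dim.
dp_mesh = world_mesh[("dp_replicate", "dp_shard", "cp")]._flatten("dp_cp")

# The reverted change instead accessed an already-flattened "dp_cp" dim directly
# (created elsewhere, e.g. when the mesh is built), which on 3D+ meshes requires
# https://github.com/pytorch/pytorch/pull/138945:
#     dp_mesh = world_mesh["dp_cp"]
print(dp_mesh.size())  # all 8 ranks collapsed into one flattened dim
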
28 changes: 0 additions & 28 deletions torchtitan/parallelisms/utils.py

This file was deleted.
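
The body of the deleted torchtitan/parallelisms/utils.py is not shown in this view. Purely as a hypothetical illustration of a helper matching the call signature seen in the removed code above (feature name, upstream PR link, minimum nightly version), it could look roughly like the following; this is an assumption, not the actual deleted file:

# Hypothetical sketch of a check_if_feature_in_pytorch helper; the real deleted
# file may differ.
from typing import Optional

import torch

from torchtitan.logging import logger


def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request_link: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    if "git" in torch.__version__:
        # PyTorch built from source: ask the user to verify the PR is included.
        logger.warning(
            f"Detected a from-source PyTorch build. Please make sure "
            f"{pull_request_link} is included for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        # Installed version predates the nightly that contains the feature.
        logger.warning(
            f"PyTorch {torch.__version__} is older than {min_nightly_version}; "
            f"please upgrade to include {pull_request_link} for correct {feature_name}."
        )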
