
Commit

Update base for Update on "[BE] replace the extra DeviceMesh _flatten with mesh access"


**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access on flattened meshes that are constructed from more than two meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of accessing the flattened mesh directly. Now that the fix has been merged in PyTorch, we want to switch back to mesh access, which is more straightforward.
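
As a minimal sketch of the two access patterns (the mesh dimension names, the `(2, 2, 2)` shape, and the 8-rank launch are illustrative assumptions, not the exact torchtitan setup; requires a PyTorch build that includes the fix):

```python
# Illustrative only: dim names and shape are hypothetical.
# Run with 8 ranks, e.g. `torchrun --nproc_per_node=8 example.py`.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# Flatten three mesh dims into one; this registers "dp_cp" on the root mesh.
mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")

# Workaround from #592: call _flatten again wherever the flattened mesh is
# needed (repeated calls hand back the already-created flattened mesh).
dp_cp_mesh = mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")

# Pattern this stack returns to, now that pytorch/pytorch#138945 is merged:
# slice the flattened mesh by name.
dp_cp_mesh = mesh["dp_cp"]
```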


[ghstack-poisoned]
XilunWu committed Oct 31, 2024
2 parents 53d0f69 + 2a785e9 commit 03d27ce
Showing 2 changed files with 0 additions and 37 deletions.
torchtitan/parallelisms/parallelize_llama.py (7 changes: 0 additions & 7 deletions)
@@ -34,7 +34,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_strided_sharding_enabled


 def parallelize_llama(
@@ -330,12 +329,6 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()

-    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
-    # that users won't use a nightly build which is older than 20240809 by then.
-    if tp_enabled:
-        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-        check_strided_sharding_enabled()
-
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
torchtitan/parallelisms/utils.py (30 changes: 0 additions & 30 deletions)

This file was deleted.
