[BE] remove old pytorch version warning on strided sharding since 2.5 is officially released (#665)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #665

#507 added a PyTorch version check for users running FSDP+TP, to make sure the installed PyTorch build includes DTensor strided sharding, which ensures correct DTensor checkpointing. Since PyTorch 2.5 has been officially released and includes strided sharding, we can safely remove this warning.
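
For context, the removed `check_strided_sharding_enabled` gate compared the installed PyTorch build against the nightly that introduced strided sharding. The deleted implementation is not reproduced in this diff, so the sketch below is only an illustration of such a version gate; the date cutoff and the version-string parsing are assumptions based on the TODO comment removed in `parallelize_llama.py`.

```python
# Minimal sketch of a strided-sharding version gate (assumed logic; the
# deleted torchtitan/parallelisms/utils.py is not shown in this diff).
import torch

from torchtitan.logging import logger


def check_strided_sharding_enabled() -> None:
    # Nightly builds carry versions like "2.5.0.dev20240809+cu121"; strided
    # sharding landed around 2024-08-09 (assumed cutoff, per the removed TODO)
    # and ships in the official 2.5 release.
    nightly_date = torch.__version__.partition("dev")[2][:8]
    if nightly_date and nightly_date < "20240809":
        logger.warning(
            "DTensor strided sharding requires PyTorch >= 2.5 or a nightly "
            "newer than 2024-08-09; 2D/3D DCP checkpoints saved with this "
            "build may not reshard correctly."
        )
```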
XilunWu authored Oct 30, 2024
1 parent 53d0f69 commit 2a785e9
Showing 2 changed files with 0 additions and 37 deletions.
7 changes: 0 additions & 7 deletions torchtitan/parallelisms/parallelize_llama.py
```diff
@@ -34,7 +34,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_strided_sharding_enabled


 def parallelize_llama(
@@ -330,12 +329,6 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()

-    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
-    # that users won't use a nightly build which is older than 20240809 by then.
-    if tp_enabled:
-        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-        check_strided_sharding_enabled()
-
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
```
30 changes: 0 additions & 30 deletions torchtitan/parallelisms/utils.py

This file was deleted.
