From 2a785e9956ca4cb3fd2e4f3736413c4bd1669a22 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:06:53 -0700 Subject: [PATCH] [BE] remove old pytorch version warning on strided sharding since 2.5 is officially released (#665) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #665 #507 added a PyTorch version check when users try to use FSDP+TP, to make sure the right PT version includes DTensor strided sharding which assures correct DTensor checkpoint. Since PyTorch 2.5 is officially released and strided sharding is included in 2.5, we can safely remove this warning. --- torchtitan/parallelisms/parallelize_llama.py | 7 ----- torchtitan/parallelisms/utils.py | 30 -------------------- 2 files changed, 37 deletions(-) delete mode 100644 torchtitan/parallelisms/utils.py diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index ed23936b..84666031 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -34,7 +34,6 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import logger from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.utils import check_strided_sharding_enabled def parallelize_llama( @@ -330,12 +329,6 @@ def apply_fsdp( if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() - # TODO: remove this check once PyTorch 2.5 is released. We can safely assume - # that users won't use a nightly build which is older than 20240809 by then. 
- if tp_enabled: - # check if strided sharding is enabled, which is necessary for 2D/3D DCP - check_strided_sharding_enabled() - for layer_id, transformer_block in model.layers.items(): if pp_enabled: # For PP, do not reshard after forward to avoid per-microbatch diff --git a/torchtitan/parallelisms/utils.py b/torchtitan/parallelisms/utils.py deleted file mode 100644 index a82ace7a..00000000 --- a/torchtitan/parallelisms/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from torchtitan.logging import logger - - -def check_strided_sharding_enabled() -> None: - # Correct 2D/3D DCP usage requires DTensor's strided sharding in PR - # https://github.com/pytorch/pytorch/pull/130760. This function checks if users' - # PyTorch nightly-build version is newer than 2024-08-09 to make sure this PR is - # included when 2D/3D DCP is used. - if "git" in torch.__version__: # pytorch is built from source - # notify users to check if the commit hash is newer than 2024-08-09 - logger.warning( - "detected that the pytorch is built from source. Please make sure the PR " - "(https://github.com/pytorch/pytorch/pull/130760) is included in pytorch " - "for correct 2D/3D DCP usage." - ) - elif torch.__version__ < "2.5.0.dev20240809": - # the nightly build pytorch was built before 2024-08-09 - logger.warning( - f"detected that the pytorch version {torch.__version__} is older than " - "2.5.0.dev20240809. Please upgrade a newer version to include the change " - "made in https://github.com/pytorch/pytorch/pull/130760 for correct 2D/3D " - "DCP usage." - )