diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d50b66656..d222d1cc7 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -3,6 +3,7 @@ name: CI
 
 on:
   pull_request:
+  workflow_dispatch:
   push:
     branches: [master]
 
diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py
index 58ca4c718..3c5211982 100644
--- a/dlrover/python/elastic_agent/torch/ckpt_saver.py
+++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -414,11 +414,7 @@ def __init__(
         )
         self._master_client = None
 
-        # remove the history temp path if exists
-        self.storage.safe_rmtree(
-            os.path.join(self.checkpoint_dir, self._STAGE_DIR)
-        )
-        logger.info("AsyncSaver initialized.")
+        logger.info(f"AsyncSaver({self.__class__.__name__}) initialized.")
 
     def __del__(self):
         self.close()
@@ -781,10 +777,17 @@ def _sync_node_checkpoint(
             if elapsed_time > timeout:
                 logger.info(
                     "It is timeout to sync checkpoint "
-                    "bacause some nodes may fail."
+                    "because some nodes may fail."
                 )
                 return False
 
+    def _remove_sub_dir_of_target_path(self, path):
+        if os.path.exists(path):
+            for entry in os.listdir(path):
+                full_path = os.path.join(path, entry)
+                if os.path.isdir(full_path):
+                    self.storage.safe_rmtree(full_path)
+
     @classmethod
     def reset(cls):
         """Reset the shared memory of all shards."""
@@ -999,6 +1002,28 @@ class TempDirCheckpointSaver(AsyncCheckpointSaver):
     by users.
     """
 
+    def __init__(
+        self,
+        checkpoint_dir,
+        storage_meta: ClassMeta,
+        local_shard_num=1,
+        global_shard_num=1,
+        save_timeout=CheckpointConstant.SAVE_TIMEOUT,
+    ) -> None:
+        super().__init__(
+            checkpoint_dir,
+            storage_meta,
+            local_shard_num,
+            global_shard_num,
+            save_timeout,
+        )
+
+        if self._node_rank == 0:
+            # remove the history temp path if exists
+            self._remove_sub_dir_of_target_path(
+                os.path.join(self.checkpoint_dir, self._STAGE_DIR)
+            )
+
     def save_step_checkpoint(self, step):
         """
         Save the checkpoint of a step into the storage.
diff --git a/dlrover/python/tests/test_ckpt_saver.py b/dlrover/python/tests/test_ckpt_saver.py
index 9470c30fb..3586ea177 100644
--- a/dlrover/python/tests/test_ckpt_saver.py
+++ b/dlrover/python/tests/test_ckpt_saver.py
@@ -157,6 +157,22 @@ def test_create_checkpoint_saver(self):
             id(AsyncCheckpointSaver._saver_instance._master_client),
         )
 
+        # test
+        test_path = "/tmp/test_ckpt"
+        AsyncCheckpointSaver._saver_instance.checkpoint_dir = test_path
+        os.makedirs(test_path, exist_ok=True)
+        os.makedirs(os.path.join(test_path, "td1"), exist_ok=True)
+        with open(
+            os.path.join(test_path, "tf1"), "w", encoding="utf-8"
+        ) as file:
+            file.write("test")
+        AsyncCheckpointSaver._saver_instance._remove_sub_dir_of_target_path(
+            test_path
+        )
+        self.assertTrue(os.path.exists(test_path))
+        self.assertTrue(os.path.exists(os.path.join(test_path, "tf1")))
+        self.assertFalse(os.path.exists(os.path.join(test_path, "td1")))
+
     def test_close_saver(self):
         saver = DdpCheckpointSaver("test_ckpt", self.storage.get_class_meta())
         try:
diff --git a/dlrover/trainer/torch/flash_checkpoint/engine.py b/dlrover/trainer/torch/flash_checkpoint/engine.py
index 1574f5c8c..56249aad7 100644
--- a/dlrover/trainer/torch/flash_checkpoint/engine.py
+++ b/dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -174,7 +174,7 @@ def __init__(
         self.storage = storage
         self._save_timeout = save_timeout
         self._local_rank = env_utils.get_local_rank()
-        self._cached_step = 0
+        self._cached_step = -1
         self._restart_count = env_utils.get_torch_restart_count()
 
         # init saver
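
For reference, below is a minimal standalone sketch (not part of the patch) of the cleanup behavior that the new _remove_sub_dir_of_target_path helper introduces: only the first-level sub-directories of the target path are deleted, while regular files in it and the path itself are kept. The function name remove_sub_dirs and the /tmp/test_ckpt path are illustrative, and plain shutil.rmtree stands in for the storage-backed self.storage.safe_rmtree used in ckpt_saver.py.

import os
import shutil


def remove_sub_dirs(path: str) -> None:
    # Delete only first-level sub-directories of `path`; keep regular files.
    # shutil.rmtree is an assumption standing in for storage.safe_rmtree.
    if os.path.exists(path):
        for entry in os.listdir(path):
            full_path = os.path.join(path, entry)
            if os.path.isdir(full_path):
                shutil.rmtree(full_path, ignore_errors=True)


if __name__ == "__main__":
    # Mirrors the new unit test: the "td1" sub-directory is removed,
    # while the "tf1" file and the target path itself survive.
    os.makedirs("/tmp/test_ckpt/td1", exist_ok=True)
    with open("/tmp/test_ckpt/tf1", "w", encoding="utf-8") as f:
        f.write("test")
    remove_sub_dirs("/tmp/test_ckpt")
    assert os.path.exists("/tmp/test_ckpt/tf1")
    assert not os.path.exists("/tmp/test_ckpt/td1")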