Commit
Merge pull request #1297 from BalaBalaYi/fix_unsave_rm_in_common_dir_ckpt_saver

Fix unsafe removal of the shared checkpoint temp directory in the ckpt saver: stale sub-directories under the stage dir are now cleaned up only by node rank 0 of the temp-dir saver.
samplise authored Oct 17, 2024
2 parents d26b965 + 3d86a8d commit d0bbbbc
Showing 4 changed files with 49 additions and 7 deletions.
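
In short, the saver no longer wipes the shared staging directory from AsyncCheckpointSaver.__init__ on every node; the temp-dir saver now cleans up only the stale sub-directories under the stage dir, and only on node rank 0. A minimal standalone sketch of that cleanup behavior, with shutil.rmtree standing in for storage.safe_rmtree and a hypothetical rank-0 gate:

import os
import shutil


def remove_sub_dirs(path: str) -> None:
    # Remove only the sub-directories of `path`; plain files and
    # `path` itself are left untouched.
    if not os.path.exists(path):
        return
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)  # stand-in for storage.safe_rmtree


# Illustrative use, mirroring TempDirCheckpointSaver.__init__:
# if node_rank == 0:
#     remove_sub_dirs(os.path.join(checkpoint_dir, stage_dir))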
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -3,6 +3,7 @@ name: CI

on:
pull_request:
workflow_dispatch:
push:
branches: [master]

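The CI change only adds a workflow_dispatch trigger, which lets the workflow be started manually from the GitHub Actions UI in addition to the existing pull_request and push events.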
37 changes: 31 additions & 6 deletions dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -414,11 +414,7 @@ def __init__(
)
self._master_client = None

# remove the history temp path if exists
self.storage.safe_rmtree(
os.path.join(self.checkpoint_dir, self._STAGE_DIR)
)
logger.info("AsyncSaver initialized.")
logger.info(f"AsyncSaver({self.__class__.__name__}) initialized.")

def __del__(self):
self.close()
@@ -781,10 +777,17 @@ def _sync_node_checkpoint(
if elapsed_time > timeout:
logger.info(
"It is timeout to sync checkpoint "
"bacause some nodes may fail."
"because some nodes may fail."
)
return False

def _remove_sub_dir_of_target_path(self, path):
if os.path.exists(path):
for entry in os.listdir(path):
full_path = os.path.join(path, entry)
if os.path.isdir(full_path):
self.storage.safe_rmtree(full_path)

@classmethod
def reset(cls):
"""Reset the shared memory of all shards."""
@@ -999,6 +1002,28 @@ class TempDirCheckpointSaver(AsyncCheckpointSaver):
by users.
"""

def __init__(
self,
checkpoint_dir,
storage_meta: ClassMeta,
local_shard_num=1,
global_shard_num=1,
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
) -> None:
super().__init__(
checkpoint_dir,
storage_meta,
local_shard_num,
global_shard_num,
save_timeout,
)

if self._node_rank == 0:
# remove the history temp path if exists
self._remove_sub_dir_of_target_path(
os.path.join(self.checkpoint_dir, self._STAGE_DIR)
)

def save_step_checkpoint(self, step):
"""
Save the checkpoint of a step into the storage.
16 changes: 16 additions & 0 deletions dlrover/python/tests/test_ckpt_saver.py
@@ -157,6 +157,22 @@ def test_create_checkpoint_saver(self):
id(AsyncCheckpointSaver._saver_instance._master_client),
)

# test _remove_sub_dir_of_target_path: files are kept, sub-dirs are removed
test_path = "/tmp/test_ckpt"
AsyncCheckpointSaver._saver_instance.checkpoint_dir = test_path
os.makedirs(test_path, exist_ok=True)
os.makedirs(os.path.join(test_path, "td1"), exist_ok=True)
with open(
os.path.join(test_path, "tf1"), "w", encoding="utf-8"
) as file:
file.write("test")
AsyncCheckpointSaver._saver_instance._remove_sub_dir_of_target_path(
test_path
)
self.assertTrue(os.path.exists(test_path))
self.assertTrue(os.path.exists(os.path.join(test_path, "tf1")))
self.assertFalse(os.path.exists(os.path.join(test_path, "td1")))

def test_close_saver(self):
saver = DdpCheckpointSaver("test_ckpt", self.storage.get_class_meta())
try:
2 changes: 1 addition & 1 deletion dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -174,7 +174,7 @@ def __init__(
self.storage = storage
self._save_timeout = save_timeout
self._local_rank = env_utils.get_local_rank()
self._cached_step = 0
self._cached_step = -1
self._restart_count = env_utils.get_torch_restart_count()

# init saver
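
The engine change is a single line: the cached step now starts at -1 instead of 0. A plausible reading, sketched below under the assumption that a save request is skipped when its step equals the cached step (the guard is illustrative, not the engine's actual code), is that starting at 0 would silently drop a checkpoint taken at step 0:

class _EngineSketch:
    """Hypothetical sketch of the cached-step guard."""

    def __init__(self) -> None:
        # Start at -1 so a legitimate save at step 0 is not
        # mistaken for an already-cached step.
        self._cached_step = -1

    def save(self, step: int) -> bool:
        if step == self._cached_step:
            return False  # this step was already saved; skip
        self._cached_step = step
        return True


engine = _EngineSketch()
assert engine.save(0)  # would be skipped if _cached_step started at 0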
