Fix issue in ckpt saver. #1297

Merged
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -3,6 +3,7 @@ name: CI

on:
pull_request:
workflow_dispatch:
push:
branches: [master]

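Note: the added workflow_dispatch: trigger lets the CI workflow be started manually from the repository's Actions tab or through the GitHub CLI, in addition to the existing pull_request and push-to-master triggers; it needs no further configuration here.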
37 changes: 31 additions & 6 deletions dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -414,11 +414,7 @@ def __init__(
)
self._master_client = None

# remove the history temp path if exists
self.storage.safe_rmtree(
os.path.join(self.checkpoint_dir, self._STAGE_DIR)
)
logger.info("AsyncSaver initialized.")
logger.info(f"AsyncSaver({self.__class__.__name__}) initialized.")

def __del__(self):
self.close()
@@ -781,10 +777,17 @@ def _sync_node_checkpoint(
if elapsed_time > timeout:
logger.info(
"It is timeout to sync checkpoint "
"bacause some nodes may fail."
"because some nodes may fail."
)
return False

def _remove_sub_dir_of_target_path(self, path):
if os.path.exists(path):
for entry in os.listdir(path):
full_path = os.path.join(path, entry)
if os.path.isdir(full_path):
self.storage.safe_rmtree(full_path)
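Unlike the constructor-time safe_rmtree on the whole stage directory that this PR removes, the new helper deletes only the sub-directories sitting directly under the given path; regular files under the path, and the path itself, are left in place. A minimal standalone sketch of the same behavior, using plain os/shutil in place of the saver's storage backend (the function name is illustrative):

import os
import shutil

def remove_sub_dirs_only(path):
    # Delete only directories nested directly under `path`;
    # files under `path`, and `path` itself, are kept.
    if not os.path.exists(path):
        return
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path, ignore_errors=True)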

@classmethod
def reset(cls):
"""Reset the shared memory of all shards."""
@@ -999,6 +1002,28 @@ class TempDirCheckpointSaver(AsyncCheckpointSaver):
by users.
"""

def __init__(
self,
checkpoint_dir,
storage_meta: ClassMeta,
local_shard_num=1,
global_shard_num=1,
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
) -> None:
super().__init__(
checkpoint_dir,
storage_meta,
local_shard_num,
global_shard_num,
save_timeout,
)

if self._node_rank == 0:
# remove the history temp path if exists
self._remove_sub_dir_of_target_path(
os.path.join(self.checkpoint_dir, self._STAGE_DIR)
)
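With this addition, the cleanup of stale staging data moves from the base saver's constructor into TempDirCheckpointSaver, runs only on node rank 0, and clears only the sub-directories under checkpoint_dir/_STAGE_DIR rather than the staging directory itself, presumably so that multiple nodes no longer race to remove a directory the saver is about to reuse.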

def save_step_checkpoint(self, step):
"""
Save the checkpoint of a step into the storage.
16 changes: 16 additions & 0 deletions dlrover/python/tests/test_ckpt_saver.py
@@ -157,6 +157,22 @@ def test_create_checkpoint_saver(self):
id(AsyncCheckpointSaver._saver_instance._master_client),
)

# Verify _remove_sub_dir_of_target_path: only sub-directories
# under the target path are removed; regular files are kept.
test_path = "/tmp/test_ckpt"
AsyncCheckpointSaver._saver_instance.checkpoint_dir = test_path
os.makedirs(test_path, exist_ok=True)
os.makedirs(os.path.join(test_path, "td1"), exist_ok=True)
with open(
os.path.join(test_path, "tf1"), "w", encoding="utf-8"
) as file:
file.write("test")
AsyncCheckpointSaver._saver_instance._remove_sub_dir_of_target_path(
test_path
)
self.assertTrue(os.path.exists(test_path))
self.assertTrue(os.path.exists(os.path.join(test_path, "tf1")))
self.assertFalse(os.path.exists(os.path.join(test_path, "td1")))

def test_close_saver(self):
saver = DdpCheckpointSaver("test_ckpt", self.storage.get_class_meta())
try:
2 changes: 1 addition & 1 deletion dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -174,7 +174,7 @@ def __init__(
self.storage = storage
self._save_timeout = save_timeout
self._local_rank = env_utils.get_local_rank()
self._cached_step = 0
self._cached_step = -1
self._restart_count = env_utils.get_torch_restart_count()

# init saver
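Initializing _cached_step to -1 instead of 0 means the very first checkpoint, typically taken at step 0, is not mistaken for one that is already cached. A minimal sketch of the kind of step-deduplication guard this sentinel presumably feeds (the class and method below are illustrative, not the engine's actual code):

class _CacheGuardSketch:
    def __init__(self):
        # -1 can never equal a real training step, so a save at step 0
        # is never skipped by the equality check below.
        self._cached_step = -1

    def save_to_memory(self, step, state_dict):
        if step == self._cached_step:
            # This step is already in the cache; skip the copy.
            return False
        # ... copy state_dict into the shared-memory cache here ...
        self._cached_step = step
        return True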