From 257233ae05a6423c20a565666ee5c60295187ed2 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 17 Oct 2024 11:23:35 +0800 Subject: [PATCH 1/3] add workflow dispatch --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d50b66656..d222d1cc7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -3,6 +3,7 @@ name: CI on: pull_request: + workflow_dispatch: push: branches: [master] From b2183269aaeda122145a8d2caaa36c8e463788a6 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 17 Oct 2024 16:34:10 +0800 Subject: [PATCH 2/3] torch 2.4(2.3) compatible --- .../torch/flash_checkpoint/fsdp_engine.py | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py b/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py index 2ea51cda4..0946b34a5 100644 --- a/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py +++ b/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py @@ -172,6 +172,17 @@ def __init__(self, shm_handler: SharedMemoryHandler) -> None: self.shm_handler = shm_handler self.metadata: Dict[str, Any] = {} + # Implement the abstract function in StorageWriter + def reset( + self, checkpoint_id: Union[str, os.PathLike, None] = None + ) -> None: + pass + + def validate_checkpoint_id( + cls, checkpoint_id: Union[str, os.PathLike] + ) -> bool: + return True + def set_up_storage_writer(self, is_coordinator: bool) -> None: pass @@ -296,7 +307,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: fut.set_result(None) return fut - # Implementating the abstract function in StorageReader + # Implement the abstract function in StorageReader def read_metadata(self) -> Metadata: cached_meta = self.shm_handler.metadata.get() dcp_metadata = cached_meta["dcp_metadata"] @@ -318,6 +329,16 @@ def prepare_global_plan( ) -> List[LoadPlan]: return global_plan + def reset( + self, checkpoint_id: Union[str, os.PathLike, None] = None + ) -> None: + pass + + def validate_checkpoint_id( + cls, checkpoint_id: Union[str, os.PathLike] + ) -> bool: + return True + class SlicedBufferedReader(io.BufferedReader): def __init__(self, base_stream: io.RawIOBase, offset: int, len: int): @@ -392,7 +413,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: fut.set_result(None) return fut - # Implementating the abstract function in StorageReader + # Implement the abstract function in StorageReader def read_metadata(self) -> Metadata: with (self.path / ".metadata").open("rb") as metadata_file: return pickle.load(metadata_file) @@ -412,6 +433,16 @@ def prepare_global_plan( ) -> List[LoadPlan]: return global_plan + def reset( + self, checkpoint_id: Union[str, os.PathLike, None] = None + ) -> None: + pass + + def validate_checkpoint_id( + cls, checkpoint_id: Union[str, os.PathLike] + ) -> bool: + return True + class FsdpCheckpointEngine(CheckpointEngine): """ From 379c51a69edc6b6586bd10b025c3e00aa90ca819 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 18 Oct 2024 11:12:59 +0800 Subject: [PATCH 3/3] optimize ut --- dlrover/trainer/tests/torch/fsdp_ckpt_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dlrover/trainer/tests/torch/fsdp_ckpt_test.py b/dlrover/trainer/tests/torch/fsdp_ckpt_test.py index 613072cd8..33aa900b1 100644 --- a/dlrover/trainer/tests/torch/fsdp_ckpt_test.py +++ b/dlrover/trainer/tests/torch/fsdp_ckpt_test.py @@ -242,11 +242,15 @@ def test_shared_memory_writer(self): for _, item in files: write_items.append(item) writer = _write_state_dict_to_shm(self.shm, files, state_dict) + writer.reset() + self.assertTrue(writer.validate_checkpoint_id(None)) self.assertTrue("dcp_metadata" in writer.metadata) self.assertTrue("no_shard_data" in writer.metadata) writer.shm_handler.metadata.set(writer.metadata) reader = SharedMemoryReader(writer.shm_handler) + reader.reset() + self.assertTrue(reader.validate_checkpoint_id(None)) dcp_metadata = reader.read_metadata() self.assertTrue(_OPTIMIZER_KEY in dcp_metadata.state_dict_metadata) self.assertTrue(_OPTIMIZER_KEY in reader.no_shard_data) @@ -295,6 +299,8 @@ def test_file_reader(self): f.write(writer.shm_handler.shared_memory.buf) reader = FileReader(tmpdir) + reader.reset() + self.assertTrue(reader.validate_checkpoint_id(None)) metadata = reader.read_metadata() reader.set_up_storage_reader(metadata, True)