Merge pull request #1299 from BalaBalaYi/compatible_torch_2.4_storage_writer

Make the FSDP storage writer and reader compatible with torch 2.4 (and 2.3).
samplise authored Oct 18, 2024
2 parents 8dd10ab + 379c51a commit e88d7eb
Showing 2 changed files with 39 additions and 2 deletions.
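
Why the overrides are needed (background, an assumption based on the torch.distributed.checkpoint API in torch >= 2.3, not shown in this diff): StorageWriter and StorageReader declare reset() and validate_checkpoint_id() as abstract methods, so the writer and reader classes patched here (the shared-memory writer, SharedMemoryReader, and FileReader) must define them to remain instantiable. Because these backends do not key anything off a checkpoint_id, no-op bodies suffice. A minimal sketch of the constraint with a stand-in ABC (not the real torch base class; exact decorators and signatures may differ):

    # Minimal sketch, not the real torch classes: a subclass that omits the
    # newly abstract methods cannot be instantiated, which is why the commit
    # adds no-op reset()/validate_checkpoint_id() overrides.
    import os
    from abc import ABC, abstractmethod
    from typing import Union


    class StorageWriterLike(ABC):
        @abstractmethod
        def reset(
            self, checkpoint_id: Union[str, os.PathLike, None] = None
        ) -> None:
            ...

        @abstractmethod
        def validate_checkpoint_id(
            self, checkpoint_id: Union[str, os.PathLike, None]
        ) -> bool:
            ...


    class OldWriter(StorageWriterLike):
        """Pre-patch style: neither new method is implemented."""


    class PatchedWriter(StorageWriterLike):
        """Mirrors the no-op overrides added by this commit."""

        def reset(
            self, checkpoint_id: Union[str, os.PathLike, None] = None
        ) -> None:
            pass

        def validate_checkpoint_id(
            self, checkpoint_id: Union[str, os.PathLike, None]
        ) -> bool:
            return True


    PatchedWriter()   # instantiates fine
    try:
        OldWriter()   # TypeError: Can't instantiate abstract class OldWriter ...
    except TypeError as err:
        print(err)
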
6 changes: 6 additions & 0 deletions dlrover/trainer/tests/torch/fsdp_ckpt_test.py
@@ -242,11 +242,15 @@ def test_shared_memory_writer(self):
         for _, item in files:
             write_items.append(item)
         writer = _write_state_dict_to_shm(self.shm, files, state_dict)
+        writer.reset()
+        self.assertTrue(writer.validate_checkpoint_id(None))
         self.assertTrue("dcp_metadata" in writer.metadata)
         self.assertTrue("no_shard_data" in writer.metadata)

         writer.shm_handler.metadata.set(writer.metadata)
         reader = SharedMemoryReader(writer.shm_handler)
+        reader.reset()
+        self.assertTrue(reader.validate_checkpoint_id(None))
         dcp_metadata = reader.read_metadata()
         self.assertTrue(_OPTIMIZER_KEY in dcp_metadata.state_dict_metadata)
         self.assertTrue(_OPTIMIZER_KEY in reader.no_shard_data)
@@ -295,6 +299,8 @@ def test_file_reader(self):
                 f.write(writer.shm_handler.shared_memory.buf)

             reader = FileReader(tmpdir)
+            reader.reset()
+            self.assertTrue(reader.validate_checkpoint_id(None))
             metadata = reader.read_metadata()
             reader.set_up_storage_reader(metadata, True)

35 changes: 33 additions & 2 deletions dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py
@@ -172,6 +172,17 @@ def __init__(self, shm_handler: SharedMemoryHandler) -> None:
         self.shm_handler = shm_handler
         self.metadata: Dict[str, Any] = {}

+    # Implement the abstract function in StorageWriter
+    def reset(
+        self, checkpoint_id: Union[str, os.PathLike, None] = None
+    ) -> None:
+        pass
+
+    def validate_checkpoint_id(
+        cls, checkpoint_id: Union[str, os.PathLike]
+    ) -> bool:
+        return True
+
     def set_up_storage_writer(self, is_coordinator: bool) -> None:
         pass

@@ -296,7 +307,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         fut.set_result(None)
         return fut

-    # Implementating the abstract function in StorageReader
+    # Implement the abstract function in StorageReader
     def read_metadata(self) -> Metadata:
         cached_meta = self.shm_handler.metadata.get()
         dcp_metadata = cached_meta["dcp_metadata"]
@@ -318,6 +329,16 @@ def prepare_global_plan(
     ) -> List[LoadPlan]:
         return global_plan

+    def reset(
+        self, checkpoint_id: Union[str, os.PathLike, None] = None
+    ) -> None:
+        pass
+
+    def validate_checkpoint_id(
+        cls, checkpoint_id: Union[str, os.PathLike]
+    ) -> bool:
+        return True
+

 class SlicedBufferedReader(io.BufferedReader):
     def __init__(self, base_stream: io.RawIOBase, offset: int, len: int):
@@ -392,7 +413,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         fut.set_result(None)
         return fut

-    # Implementating the abstract function in StorageReader
+    # Implement the abstract function in StorageReader
     def read_metadata(self) -> Metadata:
         with (self.path / ".metadata").open("rb") as metadata_file:
             return pickle.load(metadata_file)
@@ -412,6 +433,16 @@ def prepare_global_plan(
     ) -> List[LoadPlan]:
         return global_plan

+    def reset(
+        self, checkpoint_id: Union[str, os.PathLike, None] = None
+    ) -> None:
+        pass
+
+    def validate_checkpoint_id(
+        cls, checkpoint_id: Union[str, os.PathLike]
+    ) -> bool:
+        return True
+

 class FsdpCheckpointEngine(CheckpointEngine):
     """
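
For context on how the no-ops get exercised (an assumption about the caller, not shown in this diff): when one of these writers or readers is handed to torch's distributed checkpoint save/load path, the framework is expected to call reset() at the start of a checkpoint and may consult validate_checkpoint_id() before using the instance. Because the shared-memory and file backends here ignore checkpoint_id, a bare pass / return True is enough. A small stand-in driver illustrating that assumed call order (drive_save and NoOpWriter are illustrative names, not torch or DLRover APIs):

    # Illustrative stand-in for the save-path call order assumed above.
    from typing import Any, Optional


    class NoOpWriter:
        """Same shape as the no-op overrides added in fsdp_engine.py."""

        def reset(self, checkpoint_id: Optional[str] = None) -> None:
            # Nothing to re-initialize: shared memory is a fixed target.
            pass

        def validate_checkpoint_id(self, checkpoint_id: Any) -> bool:
            # Any (or no) checkpoint_id is acceptable for this backend.
            return True


    def drive_save(writer: NoOpWriter, checkpoint_id: Optional[str] = None) -> None:
        # Assumed caller contract: validate the target, reset the writer for
        # a new checkpoint, then continue with planning and writing (elided).
        if not writer.validate_checkpoint_id(checkpoint_id):
            raise ValueError(f"unsupported checkpoint_id: {checkpoint_id!r}")
        writer.reset(checkpoint_id)
        # ... prepare_local_plan / write_data / finish would follow ...


    drive_save(NoOpWriter())               # no checkpoint_id
    drive_save(NoOpWriter(), "step-100")   # a checkpoint_id is simply ignored
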
