diff --git a/dataflux_pytorch/lightning/gcs_filesystem.py b/dataflux_pytorch/lightning/gcs_filesystem.py
index 6c0f59e..362b942 100644
--- a/dataflux_pytorch/lightning/gcs_filesystem.py
+++ b/dataflux_pytorch/lightning/gcs_filesystem.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import Generator, Optional, Union
 
+import torch.distributed
 from dataflux_core import user_agent
 from google.cloud import storage
 from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
@@ -30,6 +31,9 @@ def create_stream(self, path: Union[str, os.PathLike],
                       mode: str) -> Generator[io.IOBase, None, None]:
         bucket, path = parse_gcs_path(path)
         blob = self.storage_client.bucket(bucket).blob(path)
+        print(
+            f"Rank {torch.distributed.get_rank()} opened path {path} in mode {mode}"
+        )
         if mode == "wb":  # write mode.
             with DatafluxCheckpointBuffer(blob) as stream:
                 yield stream
diff --git a/demo/lightning/checkpoint/simulated/multiprocessing_train.py b/demo/lightning/checkpoint/simulated/multiprocessing_train.py
index 604e384..110e123 100644
--- a/demo/lightning/checkpoint/simulated/multiprocessing_train.py
+++ b/demo/lightning/checkpoint/simulated/multiprocessing_train.py
@@ -24,9 +24,10 @@
 import torch.distributed.checkpoint as dist_cp
 import torch.multiprocessing as mp
 import torch.nn as nn
+from lightning.pytorch.strategies import FSDPStrategy
+
 from dataflux_pytorch.lightning.gcs_filesystem import (GCSDistributedReader,
                                                         GCSDistributedWriter)
-from lightning.pytorch.strategies import FSDPStrategy
 
 # Constants for distributed setup
 MASTER_ADDR = 'localhost'
@@ -222,13 +223,10 @@ class SimpleModel(nn.Module):
     padding.
     """
 
-    def __init__(self, size: int, padding_size: int):
+    def __init__(self):
         super(SimpleModel, self).__init__()
-        self.fc1 = nn.Linear(size, size)
-        self.fc2 = nn.Linear(size, size)
-        self.dummy_tensors = [
-            torch.randn(size, size) for _ in range(padding_size)
-        ]
+        self.fc1 = nn.Linear(1, 1)
+        self.fc2 = nn.Linear(1, 1)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fc2(torch.relu(self.fc1(x)))
@@ -248,7 +246,8 @@ def cleanup() -> None:
 def time_checkpoint_operation(benchmark_strategy: BenchmarkStrategy,
                               distributed_state_dict: Dict[str, torch.Tensor],
                               filepath: str, sample_count: int, operation: str,
-                              model: nn.Module) -> list:
+                              model: nn.Module, rank: int, world_size: int,
+                              tensor_count: int, tensor_size: int) -> list:
     """
     Times the save or load operations for checkpoints.
 
@@ -268,12 +267,11 @@
         saving/loading under distributed settings.
     """
     times = []
-    template_state_dict = model.state_dict()
-    for key, tensor in template_state_dict.items():
-        template_state_dict[key] = torch.empty_like(tensor)
-    if hasattr(model, 'dummy_tensors'):
-        for i, tensor in enumerate(model.dummy_tensors):
-            template_state_dict[f'dummy_tensor_{i}'] = torch.empty_like(tensor)
+    template_state_dict = dict()
+    for i in range(tensor_count):
+        if i % world_size == rank:
+            template_state_dict[f'dummy_tensor_{i}'] = torch.empty(
+                tensor_size, 1000)
     for i in range(sample_count):
         checkpoint_path = os.path.join(filepath, f'checkpoints/ckpt_{i}.ckpt')
         dist.barrier()
@@ -297,23 +295,24 @@ def run_benchmark(rank, world_size: int, layer_size: int, project: str,
                   debug: bool) -> None:
     setup(rank, world_size)
 
-    model = SimpleModel(layer_size, padding_size)
+    model = SimpleModel()
 
-    if rank == 0 and debug:
-        print("Writing initial model structure and parameters to file...")
-        write_full_model(model, "initial_model_state.txt")
+    # if rank == 0 and debug:
+    #     print("Writing initial model structure and parameters to file...")
+    #     write_full_model(model, "initial_model_state.txt")
 
     benchmark_strategy = BenchmarkStrategy(project=project,
                                            path=filepath,
                                            model=model)
-    state_dict = model.state_dict()
-    for i, tensor in enumerate(model.dummy_tensors):
-        state_dict[f'dummy_tensor_{i}'] = tensor
+    state_dict = dict()
+    for i in range(padding_size):
+        if i % world_size == rank:
+            state_dict[f'dummy_tensor_{i}'] = torch.randn(layer_size, 1000)
 
-    if rank == 0 and debug:
-        print("Writing state dict before saving to file...")
+    if debug:
+        print(f"Writing state dict before saving to file... {rank}")
         write_state_dict_to_file(state_dict, "state_dict_before_save.txt")
-        print("Shapes before saving:", {
+        print(f"Shapes before saving: {rank} ", {
             k: v.shape
             for k, v in state_dict.items()
         })
@@ -322,12 +321,22 @@ def run_benchmark(rank, world_size: int, layer_size: int, project: str,
     save_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
                                                       state_dict, filepath,
                                                       sample_count, 'save',
-                                                      model)
+                                                      model, rank, world_size,
+                                                      padding_size, layer_size)
+
+    if debug:
+        print(f"Writing state dict before loading... {rank}")
+        write_state_dict_to_file(state_dict, "state_dict_before_load.txt")
+        print(f"Shapes before loading: {rank} ", {
+            k: v.shape
+            for k, v in state_dict.items()
+        })
 
     load_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
                                                       state_dict, filepath,
                                                       sample_count, 'load',
-                                                      model)
+                                                      model, rank, world_size,
+                                                      padding_size, layer_size)
 
     if rank == 0:
         print(f"Time taken to save checkpoint:\