Load 1 file per node during simulated checkpoint #160

Draft · wants to merge 1 commit into base: main
dataflux_pytorch/lightning/gcs_filesystem.py (4 changes: 4 additions & 0 deletions)
@@ -4,6 +4,7 @@
from pathlib import Path
from typing import Generator, Optional, Union

import torch.distributed
from dataflux_core import user_agent
from google.cloud import storage
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
@@ -30,6 +31,9 @@ def create_stream(self, path: Union[str, os.PathLike],
mode: str) -> Generator[io.IOBase, None, None]:
bucket, path = parse_gcs_path(path)
blob = self.storage_client.bucket(bucket).blob(path)
print(
f"Rank {torch.distributed.get_rank()} opened path {path} in mode {mode}"
)
if mode == "wb": # write mode.
with DatafluxCheckpointBuffer(blob) as stream:
yield stream
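A note on the logging line added above: `torch.distributed.get_rank()` is only meaningful once the default process group has been initialized in each worker, which the demo does in `setup(rank, world_size)` before any checkpoint I/O. Below is a minimal sketch (not part of this change) of the same rank-tagged logging pattern; the two-process `gloo` group and the path string are illustrative assumptions, and no GCS calls are made.

```python
# Sketch only: mirrors the print added to GCSDistributedWriter.create_stream(),
# showing how torch.distributed.get_rank() labels per-process output once a
# process group exists. The path below is a made-up example.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    path, mode = f"ckpt_0.ckpt/__{rank}_0.distcp", "wb"
    # Same pattern as the line added in create_stream().
    print(f"Rank {dist.get_rank()} opened path {path} in mode {mode}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

With two spawned workers this prints one line per rank, which is how per-rank file access can be spotted in the benchmark output.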
demo/lightning/checkpoint/simulated/multiprocessing_train.py (61 changes: 35 additions & 26 deletions)
@@ -24,9 +24,10 @@
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy

from dataflux_pytorch.lightning.gcs_filesystem import (GCSDistributedReader,
GCSDistributedWriter)
from lightning.pytorch.strategies import FSDPStrategy

# Constants for distributed setup
MASTER_ADDR = 'localhost'
@@ -222,13 +223,10 @@ class SimpleModel(nn.Module):
padding.
"""

def __init__(self, size: int, padding_size: int):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(size, size)
self.fc2 = nn.Linear(size, size)
self.dummy_tensors = [
torch.randn(size, size) for _ in range(padding_size)
]
self.fc1 = nn.Linear(1, 1)
self.fc2 = nn.Linear(1, 1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fc2(torch.relu(self.fc1(x)))
@@ -248,7 +246,8 @@ def cleanup() -> None:
def time_checkpoint_operation(benchmark_strategy: BenchmarkStrategy,
distributed_state_dict: Dict[str, torch.Tensor],
filepath: str, sample_count: int, operation: str,
model: nn.Module) -> list:
model: nn.Module, rank: int, world_size: int,
tensor_count: int, tensor_size: int) -> list:
"""
Times the save or load operations for checkpoints.

@@ -268,12 +267,11 @@ def time_checkpoint_operation(benchmark_strategy: BenchmarkStrategy,
saving/loading under distributed settings.
"""
times = []
template_state_dict = model.state_dict()
for key, tensor in template_state_dict.items():
template_state_dict[key] = torch.empty_like(tensor)
if hasattr(model, 'dummy_tensors'):
for i, tensor in enumerate(model.dummy_tensors):
template_state_dict[f'dummy_tensor_{i}'] = torch.empty_like(tensor)
template_state_dict = dict()
for i in range(tensor_count):
if i % world_size == rank:
template_state_dict[f'dummy_tensor_{i}'] = torch.empty(
tensor_size, 1000)
for i in range(sample_count):
checkpoint_path = os.path.join(filepath, f'checkpoints/ckpt_{i}.ckpt')
dist.barrier()
@@ -297,23 +295,24 @@ def run_benchmark(rank, world_size: int, layer_size: int, project: str,
debug: bool) -> None:
setup(rank, world_size)

model = SimpleModel(layer_size, padding_size)
model = SimpleModel()

if rank == 0 and debug:
print("Writing initial model structure and parameters to file...")
write_full_model(model, "initial_model_state.txt")
# if rank == 0 and debug:
# print("Writing initial model structure and parameters to file...")
# write_full_model(model, "initial_model_state.txt")
benchmark_strategy = BenchmarkStrategy(project=project,
path=filepath,
model=model)

state_dict = model.state_dict()
for i, tensor in enumerate(model.dummy_tensors):
state_dict[f'dummy_tensor_{i}'] = tensor
state_dict = dict()
for i in range(padding_size):
if i % world_size == rank:
state_dict[f'dummy_tensor_{i}'] = torch.randn(layer_size, 1000)

if rank == 0 and debug:
print("Writing state dict before saving to file...")
if debug:
print(f"Writing state dict before saving to file... {rank}")
write_state_dict_to_file(state_dict, "state_dict_before_save.txt")
print("Shapes before saving:", {
print(f"Shapes before saving: {rank} ", {
k: v.shape
for k, v in state_dict.items()
})
@@ -322,12 +321,22 @@
save_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
state_dict, filepath,
sample_count, 'save',
model)
model, rank, world_size,
padding_size, layer_size)

if debug:
print(f"Writing state dict before loading... {rank}")
write_state_dict_to_file(state_dict, "state_dict_before_load.txt")
print(f"Shapes before loading: {rank} ", {
k: v.shape
for k, v in state_dict.items()
})

load_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
state_dict, filepath,
sample_count, 'load',
model)
model, rank, world_size,
padding_size, layer_size)

if rank == 0:
print(f"Time taken to save checkpoint:\
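Taken together, the change stops replicating the padding tensors on every process: each rank builds and saves only the `dummy_tensor_{i}` entries where `i % world_size == rank`, and loads them back into an empty template holding that same subset. The sketch below mimics that round trip with PyTorch's stock `FileSystemWriter`/`FileSystemReader` in place of the GCS reader/writer; it assumes a recent PyTorch where `dist_cp.save`/`dist_cp.load` are available (older releases expose `save_state_dict`/`load_state_dict`), and the sizes, local path, and two-process `gloo` setup are illustrative only.

```python
# Sketch of the per-rank sharding used in run_benchmark() and
# time_checkpoint_operation(): tensor i belongs to rank i % world_size, so each
# process saves and reloads only its own slice of the checkpoint.
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter


def worker(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    tensor_count, tensor_size = 4, 16  # stand-ins for padding_size / layer_size

    # Each rank builds only its own shard of dummy tensors.
    state_dict = {
        f"dummy_tensor_{i}": torch.randn(tensor_size, 1000)
        for i in range(tensor_count) if i % world_size == rank
    }
    ckpt_dir = "/tmp/ckpt_demo"  # hypothetical local path standing in for a GCS URI
    dist_cp.save(state_dict, storage_writer=FileSystemWriter(ckpt_dir))
    dist.barrier()

    # Load back into empty tensors of matching shapes, as the benchmark's
    # template_state_dict does; each rank requests only the keys it owns.
    template = {k: torch.empty_like(v) for k, v in state_dict.items()}
    dist_cp.load(template, storage_reader=FileSystemReader(ckpt_dir))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Under the writer's default one-file-per-rank layout each process writes a single `.distcp` shard and reads back only its own keys, which appears to be the per-node file access the PR title is simulating.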