Migrate state dict API to DSD #1930

Draft · wants to merge 2 commits into main
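For context on the title: DSD refers to torch.distributed.checkpoint.state_dict, PyTorch's distributed state dict API. Below is a minimal sketch (not part of the diff) of the model state dict calls and StateDictOptions that the torchtune helpers in torchtune/training/_distributed.py are migrated onto further down; the process group setup and the `model` / `full_sd` objects are assumed.

# Sketch only: mirrors the options used in torchtune/training/_distributed.py in this PR.
# Assumes torch.distributed is initialized and `model` is a sharded (e.g. FSDP) nn.Module.
import torch.distributed.checkpoint.state_dict as dsd

def gather_full_model_sd(model):
    options = dsd.StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    # Returns a full, unsharded copy of the model state dict.
    return dsd.get_model_state_dict(model=model, options=options)

def load_full_model_sd(model, full_sd, strict=True):
    options = dsd.StateDictOptions(
        full_state_dict=True, broadcast_from_rank0=True, strict=strict
    )
    # Loads a full state dict into the sharded model; with broadcast_from_rank0=True,
    # rank 0's copy is broadcast to the other ranks.
    dsd.set_model_state_dict(model=model, model_state_dict=full_sd, options=options)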
5 changes: 4 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -478,6 +478,7 @@ def _setup_optimizer(
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
self._model,
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
@@ -494,6 +495,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -602,6 +604,7 @@ def save_checkpoint(
log.info("Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
@@ -610,7 +613,7 @@
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
self._model, opt, self._is_rank_zero, device=self._device
)
if self._is_rank_zero:
log.info(
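The recipe changes above (and the analogous ones in the files below) just thread the wrapped model through as a new leading argument to the optimizer state dict helpers. A hedged usage sketch of the updated call pattern, with illustrative local names:

from typing import Any, Dict
import torch
from torch import nn
from torch.optim import Optimizer
from torchtune import training

def save_optimizer_state(model: nn.Module, optim: Optimizer, is_rank_zero: bool,
                         device: torch.device) -> Dict[str, Any]:
    # After this PR, the sharded model is the first argument to both helpers.
    return training.get_full_optimizer_state_dict(model, optim, is_rank_zero, device=device)

def restore_optimizer_state(model: nn.Module, optim: Optimizer,
                            full_osd: Dict[str, Any], device: torch.device) -> None:
    training.load_from_full_optimizer_state_dict(model, optim, full_osd, device)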
8 changes: 7 additions & 1 deletion recipes/lora_dpo_distributed.py
@@ -244,6 +244,7 @@ def setup(self, cfg: DictConfig) -> None:
self._tokenizer = config.instantiate(cfg.tokenizer)

self._optimizer = self._setup_optimizer(
model=self._model,
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
@@ -409,11 +410,15 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
self,
model: nn.Module,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
model,
optimizer,
opt_state_dict,
self._device,
@@ -511,6 +516,7 @@ def save_checkpoint(
)
if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
8 changes: 7 additions & 1 deletion recipes/lora_finetune_distributed.py
@@ -275,6 +275,7 @@ def setup(self, cfg: DictConfig) -> None:
self._tokenizer = config.instantiate(cfg.tokenizer)

self._optimizer = self._setup_optimizer(
model=self._model,
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
@@ -549,6 +550,7 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
self,
model: nn.Module,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
model,
optimizer,
opt_state_dict,
self._device,
@@ -679,6 +684,7 @@ def save_checkpoint(
if self._is_rank_zero:
log.info("Retrieving optimizer state dict...")
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
8 changes: 7 additions & 1 deletion recipes/qat_distributed.py
@@ -238,6 +238,7 @@ def setup(self, cfg: DictConfig) -> None:
self._tokenizer = config.instantiate(cfg.tokenizer)

self._optimizer = self._setup_optimizer(
model=self._model,
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
@@ -470,11 +471,15 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
self,
model: nn.Module,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
model,
optimizer,
opt_state_dict,
self._device,
@@ -562,6 +567,7 @@ def save_checkpoint(

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
)
2 changes: 2 additions & 0 deletions tests/torchtune/training/test_distributed.py
@@ -316,6 +316,7 @@ def test_lora_state_dict(self):
fsdp_model_to_save, is_rank_zero
)
optim_full_sd = training.get_full_optimizer_state_dict(
fsdp_model_to_save,
fsdp_optim_to_save,
is_rank_zero,
)
@@ -371,6 +372,7 @@ def test_lora_state_dict(self):
fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
)
training.load_from_full_optimizer_state_dict(
fsdp_model_to_load,
fsdp_optim_to_load,
# mimic mmap=True where every rank sees the full SD
copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),
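The test round-trips optimizer state across two independently sharded models: gather the full state dict from the saving model's optimizer, then load it into a fresh optimizer built on the loading model. A condensed sketch of that pattern (the optimizer class and helper name are illustrative, and the test additionally broadcasts the gathered dict so every rank sees the full copy):

import copy
import torch
from torchtune import training

def roundtrip_optim_state(fsdp_model_to_save, fsdp_optim_to_save,
                          fsdp_model_to_load, is_rank_zero, device):
    # Gather the full optimizer state dict from the saving side.
    optim_full_sd = training.get_full_optimizer_state_dict(
        fsdp_model_to_save, fsdp_optim_to_save, is_rank_zero,
    )
    # Restore it into a fresh optimizer built on the loading side.
    fsdp_optim_to_load = torch.optim.AdamW(
        fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
    )
    training.load_from_full_optimizer_state_dict(
        fsdp_model_to_load, fsdp_optim_to_load, copy.deepcopy(optim_full_sd), device,
    )
    return fsdp_optim_to_load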
60 changes: 30 additions & 30 deletions torchtune/training/_distributed.py
@@ -297,10 +297,17 @@ def load_from_full_model_state_dict(
"""
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
has_nf4 = any(
isinstance(param._local_tensor, NF4Tensor) for param in model.parameters()
)
for param_name in full_sd.keys():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
full_sd[param_name] = (
full_sd[param_name].to(sharded_meta_param.dtype).to(device)
)

if has_nf4:
for param_name, full_tensor in full_sd.items():
full_tensor = to_nf4(full_tensor)
# replicating logic from `_fsdp_param.py` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
@@ -332,18 +339,18 @@
),
requires_grad=sharded_meta_param.requires_grad,
)

else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
else:
options = torch.distributed.checkpoint.state_dict.StateDictOptions(
full_state_dict=True, broadcast_from_rank0=True, strict=strict
)
torch.distributed.checkpoint.state_dict.set_model_state_dict(
model=model, model_state_dict=full_sd, options=options
)


def get_full_model_state_dict(
@@ -413,25 +420,17 @@ def get_full_model_state_dict(
cpu_state_dict[full_fqn] = param.cpu()
module.reshard()
else:
for param_name, sharded_param in sharded_sd.items():
# without this, it may hang forever for +70B models.
torch.distributed.barrier()
if sharded_param.is_cpu:
assert device is not None and device.type == "cuda", (
f"Expect cuda but got device={device}. "
"Please call get_full_model_state_dict(..., device=self._device),"
" so DTensor can communicate over NCCL."
)
sharded_param = sharded_param.to(device)
full_param = sharded_param.full_tensor()
if is_rank_zero:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
options = torch.distributed.checkpoint.state_dict.StateDictOptions(
full_state_dict=True, broadcast_from_rank0=True
)
cpu_state_dict = torch.distributed.checkpoint.state_dict.get_model_state_dict(
model=model, options=options
)
return cpu_state_dict


def get_full_optimizer_state_dict(
model: "FSDPModule", # noqa
opt: Optimizer,
is_rank_zero: bool,
device: Optional[torch.device] = None,
@@ -481,6 +480,7 @@ def get_full_optimizer_state_dict(


def load_from_full_optimizer_state_dict(
model: "FSDPModule", # noqa
opt: Optimizer,
full_sd: Dict[str, Any],
device: torch.device,
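The bodies of get_full_optimizer_state_dict and load_from_full_optimizer_state_dict are collapsed in this view; only the new leading model parameter is visible. Assuming they are migrated the same way as the model helpers above, the DSD optimizer analogs would look roughly like this (a sketch under that assumption, not the PR's actual implementation):

import torch.distributed.checkpoint.state_dict as dsd

def gather_full_optim_sd(model, optim):
    # Optimizer counterpart of get_model_state_dict: returns a full, unsharded state dict.
    options = dsd.StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    return dsd.get_optimizer_state_dict(model, optim, options=options)

def load_full_optim_sd(model, optim, full_osd):
    # Optimizer counterpart of set_model_state_dict: rank 0's full copy is broadcast to all ranks.
    options = dsd.StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    dsd.set_optimizer_state_dict(model, optim, optim_state_dict=full_osd, options=options)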