tensor parallel training (#2465)
* tensor parallel training
vince62s authored Sep 6, 2023
1 parent 2b13ed1 commit 16aaba8
Showing 7 changed files with 173 additions and 35 deletions.
4 changes: 3 additions & 1 deletion onmt/model_builder.py
@@ -342,7 +342,7 @@ def build_base_model(model_opt, vocabs):
return model


def build_model(model_opt, opt, vocabs, checkpoint):
def build_model(model_opt, opt, vocabs, checkpoint, device_id):
logger.info("Building model...")

model = build_base_model(model_opt, vocabs)
@@ -414,6 +414,7 @@ def build_model(model_opt, opt, vocabs, checkpoint):
precision=precision,
device=device,
strict=strict,
device_id=device_id,
)
else:
# weights are not in the .pt checkpoint but stored in the safetensors file
@@ -425,6 +426,7 @@ def build_model(model_opt, opt, vocabs, checkpoint):
precision=precision,
device=device,
strict=strict,
device_id=device_id,
)
else:
model.to(precision)
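The change here is plumbing: build_model now takes a device_id and forwards it to the two state-dict loading calls (load_state_dict / load_safe_state_dict in onmt/models/model.py), so that under tensor parallelism each spawned process loads its weight shard onto its own GPU. A minimal sketch of what that index maps to; resolve_device is a hypothetical helper, not part of the commit:

import torch

def resolve_device(device_id: int) -> torch.device:
    # Hypothetical helper: map the per-process device_id (the rank's GPU
    # index) to a torch.device; fall back to CPU when no GPU is usable.
    if torch.cuda.is_available() and device_id >= 0:
        return torch.device("cuda", device_id)
    return torch.device("cpu")

# e.g. rank 1 of a 2-GPU tensor-parallel run loads onto cuda:1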
4 changes: 2 additions & 2 deletions onmt/models/model.py
@@ -65,6 +65,8 @@ def load_state_dict(
# bitsandbytes quantize weights when .cuda() is called
# for huge models we need to save Ram
# so we load the weights module by module and transfer them to GPU for quantization
if device == torch.device("cpu"):
device_id = 0
buf_list = []
for name, module in self.named_modules():
for buf_name, buf in module.named_buffers():
@@ -220,7 +222,6 @@ def load_safe_state_dict(
else:
row_slice_start = 0
row_slice_end = param.data.size(1)

assert (
param.data.size()
== ckpt_t[
@@ -234,7 +235,6 @@
row_slice_start:row_slice_end,
]
else:

assert (
param.data.size()
== ckpt_t[col_slice_start:col_slice_end].size()
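In load_safe_state_dict each rank compares its local parameter against a column/row slice of the full checkpoint tensor, i.e. it loads only its own shard (and load_state_dict now maps device_id to 0 when loading on CPU). A hypothetical helper sketching that slicing, not the commit's code:

import torch

def shard_of(full: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    # Keep this rank's contiguous chunk of the full checkpoint tensor along
    # the sharded dimension: dim 0 for linear_keys/linear_values/linear_query/
    # w_1/w_3, dim 1 for final_linear/w_2 (matching the saver logic below).
    size = full.size(dim)
    assert size % world_size == 0, "weight must divide evenly across ranks"
    chunk = size // world_size
    return full.narrow(dim, rank * chunk, chunk)

# example: a (4096, 1024) w_1 weight on 2 ranks becomes (2048, 1024) per rank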
179 changes: 156 additions & 23 deletions onmt/models/model_saver.py
@@ -7,7 +7,7 @@
from onmt.modules.lora import lora_state_dict


def build_model_saver(model_opt, opt, model, vocabs, optim):
def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
# _check_save_model_path
save_model_path = os.path.abspath(opt.save_model)
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
@@ -20,6 +20,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim):
optim,
opt.keep_checkpoint,
opt.save_format,
device_id,
)
return model_saver

@@ -97,6 +98,7 @@ def __init__(
optim,
keep_checkpoint=-1,
save_format="pytorch",
device_id=0,
):
self.base_path = base_path
self.model = model
@@ -106,6 +108,8 @@ def __init__(
self.last_saved_step = None
self.keep_checkpoint = keep_checkpoint
self.save_format = save_format
self.device_id = device_id

if keep_checkpoint > 0:
self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
if save_format == "safetensors":
@@ -135,20 +139,24 @@ def save(self, step, moving_average=None):

self.last_saved_step = step

if moving_average:
for param_data, param in zip(model_params_data, save_model.parameters()):
param.data = param_data
if ckpt_path: # not None when process id 0

if self.keep_checkpoint > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
self._rm_checkpoint(todel)
if self.save_format == "safetensors":
todel = self.model_queue.popleft()
if moving_average:
for param_data, param in zip(
model_params_data, save_model.parameters()
):
param.data = param_data

if self.keep_checkpoint > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
self._rm_checkpoint(todel)
self.checkpoint_queue.append(ckpt_path)
if self.save_format == "safetensors":
self.model_queue.append(model_path)
if self.save_format == "safetensors":
todel = self.model_queue.popleft()
self._rm_checkpoint(todel)
self.checkpoint_queue.append(ckpt_path)
if self.save_format == "safetensors":
self.model_queue.append(model_path)

def _save(self, step, model):
"""Save a resumable checkpoint.
@@ -196,17 +204,78 @@ def _save(self, step, model):
}
generator_state_dict = model.generator.state_dict()

if torch.distributed.is_initialized():
ws = torch.distributed.get_world_size()
else:
ws = 1
if ws > 1:
full_model = [None for _ in range(ws)]
for key, value in model_state_dict.items():
model_state_dict[key] = value.cpu()
torch.distributed.all_gather_object(full_model, model_state_dict)
fm_sd = {}
for key in full_model[0].keys():
if key.split(".")[-1] == "lora_A":
if key.split(".")[-2] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
fm_sd[key] = (
sum([full_model[i][key].cpu() for i in range(ws)]) / ws
)
elif key.split(".")[-2] in ["final_linear", "w_2"]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 1
)
elif key.split(".")[-1] == "lora_B":
if key.split(".")[-2] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 0
)
elif key.split(".")[-2] in ["final_linear", "w_2"]:
fm_sd[key] = (
sum([full_model[i][key].cpu() for i in range(ws)]) / ws
)
elif key.split(".")[-1] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 0
)
elif key.split(".")[-1] in ["final_linear", "w_2"]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 1
)
model_state_dict = fm_sd

checkpoint = {
"model": model_state_dict,
"generator": generator_state_dict,
"vocab": vocabs_to_dict(self.vocabs),
"opt": self.model_opt,
"optim": self.optim.state_dict(),
}

logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
torch.save(checkpoint, ckpt_path)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
torch.save(checkpoint, ckpt_path)
else:
ckpt_path = None
if torch.distributed.is_initialized():
torch.distributed.barrier()
return ckpt_path, None
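Under tensor parallelism every rank reaches _save, but only one process should touch the filesystem: rank 0 writes the .pt file, the other ranks get ckpt_path = None, and a barrier keeps everyone in step before returning. A generic sketch of that pattern; save_on_rank0 is a hypothetical name and assumes the process group is initialized when running distributed:

import torch
import torch.distributed as dist

def save_on_rank0(obj, path):
    # Only rank 0 writes; every rank waits at the barrier so no process
    # races ahead (e.g. to rotate or reload the checkpoint) mid-write.
    written = None
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(obj, path)
        written = path
    if dist.is_initialized():
        dist.barrier()
    return written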

def _st_save(self, step, model):
@@ -224,18 +293,82 @@ def _st_save(self, step, model):
else:
model_state_dict = model.state_dict()

if torch.distributed.is_initialized():
ws = torch.distributed.get_world_size()
else:
ws = 1
if ws > 1:
full_model = [None for _ in range(ws)]
for key, value in model_state_dict.items():
model_state_dict[key] = value.cpu()
torch.distributed.all_gather_object(full_model, model_state_dict)
fm_sd = {}
for key in full_model[0].keys():
if key.split(".")[-1] == "lora_A":
if key.split(".")[-2] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
fm_sd[key] = (
sum([full_model[i][key].cpu() for i in range(ws)]) / ws
)
elif key.split(".")[-2] in ["final_linear", "w_2"]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 1
)
elif key.split(".")[-1] == "lora_B":
if key.split(".")[-2] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 0
)
elif key.split(".")[-2] in ["final_linear", "w_2"]:
fm_sd[key] = (
sum([full_model[i][key].cpu() for i in range(ws)]) / ws
)
elif key.split(".")[-1] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 0
)
elif key.split(".")[-1] in ["final_linear", "w_2"]:
fm_sd[key] = torch.cat(
[full_model[i][key].cpu() for i in range(ws)], 1
)
model_state_dict = fm_sd

checkpoint = {
"vocab": vocabs_to_dict(self.vocabs),
"opt": self.model_opt,
"optim": self.optim.state_dict(),
}

logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
torch.save(checkpoint, ckpt_path)
logger.info("Saving safetensors %s_step_%d.pt" % (self.base_path, step))
model_path = "%s_step_%d.safetensors" % (self.base_path, step)
save_file(model_state_dict, model_path)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
torch.save(checkpoint, ckpt_path)
logger.info("Saving safetensors %s_step_%d.pt" % (self.base_path, step))
model_path = "%s_step_%d.safetensors" % (self.base_path, step)
save_file(model_state_dict, model_path)
else:
ckpt_path = None
model_path = None
if torch.distributed.is_initialized():
torch.distributed.barrier()

return ckpt_path, model_path

def _rm_checkpoint(self, name):
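The bulk of this file is the same gather logic written twice, once for the .pt path (_save) and once for safetensors (_st_save): each rank moves its shard to CPU, torch.distributed.all_gather_object collects every rank's state dict, and full weights are rebuilt by concatenating shards along dim 0 (linear_keys, linear_values, linear_query, w_1, w_3) or dim 1 (final_linear, w_2); LoRA factors are averaged on the side where they are replicated and concatenated on the side where they are sharded. A stripped-down sketch covering only the weight case; gather_tp_state_dict is hypothetical, and biases plus LoRA factors need the extra branches shown above:

import torch
import torch.distributed as dist

# Module names taken from the diff: sharded along the output dim (cat on 0)
# or along the input dim (cat on 1).
DIM0_SHARDED = ("linear_keys", "linear_values", "linear_query", "w_1", "w_3")
DIM1_SHARDED = ("final_linear", "w_2")

def gather_tp_state_dict(local_sd: dict, world_size: int) -> dict:
    # Assumes an initialized process group; must be called by every rank.
    local_sd = {k: v.cpu() for k, v in local_sd.items()}  # keep gathered copies off-GPU
    shards = [None] * world_size
    dist.all_gather_object(shards, local_sd)
    full = {}
    for key, value in shards[0].items():
        parts = key.split(".")
        if value.dim() > 1 and any(m in parts for m in DIM0_SHARDED):
            full[key] = torch.cat([s[key] for s in shards], 0)
        elif value.dim() > 1 and any(m in parts for m in DIM1_SHARDED):
            full[key] = torch.cat([s[key] for s in shards], 1)
        else:
            full[key] = value  # replicated tensors: any rank's copy will do
    return full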
4 changes: 2 additions & 2 deletions onmt/train_single.py
@@ -162,7 +162,7 @@ def main(opt, device_id):
model_opt = _get_model_opts(opt, checkpoint=checkpoint)

# Build model.
model = build_model(model_opt, opt, vocabs, checkpoint)
model = build_model(model_opt, opt, vocabs, checkpoint, device_id)

model.count_parameters(log=logger.info)
trainable = {
@@ -196,7 +196,7 @@ def main(opt, device_id):
del checkpoint

# Build model saver
model_saver = build_model_saver(model_opt, opt, model, vocabs, optim)
model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)

trainer = build_trainer(
opt, device_id, model, vocabs, optim, model_saver=model_saver
14 changes: 7 additions & 7 deletions onmt/trainer.py
@@ -10,6 +10,7 @@
"""

import time
import sys
import torch
import traceback
import onmt.utils
@@ -86,7 +87,7 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
parallel_mode,
report_manager,
with_align=True if opt.lambda_align > 0 else False,
model_saver=model_saver if gpu_rank <= 0 else None,
model_saver=model_saver,
average_decay=average_decay,
average_every=average_every,
model_dtype=opt.model_dtype,
@@ -319,11 +320,7 @@ def train(
step, train_steps, self.optim.learning_rate(), report_stats
)

if (
valid_iter is not None
and step % valid_steps == 0
and self.gpu_rank <= 0
):
if valid_iter is not None and step % valid_steps == 0:
valid_stats = self.validate(
valid_iter, moving_average=self.moving_average
)
@@ -519,6 +516,9 @@ def _gradient_accumulation(
self.optim.training_step,
)
torch.cuda.empty_cache()
if self.n_gpu > 1 and self.parallel_mode == "tensor_parallel":
torch.distributed.destroy_process_group()
sys.exit()
else:
traceback.print_exc()
raise exc
@@ -563,7 +563,7 @@ def _maybe_report_training(self, step, num_steps, learning_rate, report_stats):
if self.earlystopper is None
else self.earlystopper.current_tolerance,
report_stats,
multigpu=self.n_gpu > 1,
multigpu=self.n_gpu > 1 and self.parallel_mode == "data_parallel",
)

def _report_step(self, learning_rate, step, valid_stats=None, train_stats=None):
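The trainer changes follow from the collective save: model_saver is now handed to every rank (all ranks must reach the all_gather_object in the saver), validation is no longer restricted to rank 0, and report statistics are only aggregated across GPUs in data_parallel mode. The new lines in _gradient_accumulation also make a rank that hits CUDA OOM tear down the process group and exit rather than leave its peers blocked. A rough sketch of that last pattern; safe_step and step_fn are hypothetical, and it assumes a PyTorch recent enough to expose torch.cuda.OutOfMemoryError:

import sys
import torch
import torch.distributed

def safe_step(step_fn, n_gpu: int, parallel_mode: str):
    # Under tensor parallelism a single rank cannot recover from OOM on its
    # own, so it leaves the process group and exits instead of retrying.
    try:
        step_fn()
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        if n_gpu > 1 and parallel_mode == "tensor_parallel":
            torch.distributed.destroy_process_group()
            sys.exit(1)
        raise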
2 changes: 2 additions & 0 deletions onmt/utils/distributed.py
@@ -7,6 +7,7 @@
import math
import pickle
import torch.distributed
from datetime import timedelta
from onmt.translate.translator import build_translator
from onmt.transforms import get_transforms_cls
from onmt.constants import CorpusTask
@@ -29,6 +30,7 @@ def multi_init(opt, device_id):
init_method=dist_init_method,
world_size=dist_world_size,
rank=opt.gpu_ranks[device_id],
timeout=timedelta(seconds=30),
)
gpu_rank = torch.distributed.get_rank()
if not is_master(opt, device_id):
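multi_init now passes a 30-second timeout to init_process_group, so collectives stuck waiting on a rank that exited (for example through the OOM path above) can fail after roughly 30 s instead of hanging indefinitely; the exact behaviour depends on the backend's async error handling. The same call in isolation, with placeholder backend, address and ranks:

from datetime import timedelta
import torch.distributed

def init_group(world_size: int, rank: int,
               init_method: str = "tcp://localhost:29500"):
    # Placeholder values throughout; the timeout argument is the point here.
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=30),
    )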
1 change: 1 addition & 0 deletions tools/lora_weights.py
@@ -62,6 +62,7 @@
lora_opt = lora_checkpoint["opt"]

lora_opt.quant_layers = [] # we need to remove any quantization to merge weights
lora_opt.parallel_mode = "data_parallel"

model = build_base_model(lora_opt, vocabs)

