Dev mock pipeline inference #474

Open · wants to merge 3 commits into main
11 changes: 8 additions & 3 deletions projects/mock_transformers/dist_infer_bloom.py
@@ -82,9 +82,10 @@ def __init__(self, config):
parallel_config = DictConfig(
dict(
data_parallel_size=1,
tensor_parallel_size=2,
pipeline_parallel_size=1, # set to 1, unsupport pipeline parallel now
pipeline_num_layers=None,
        tensor_parallel_size=1,
        pipeline_parallel_size=4,  # split the model into 4 pipeline stages
        pipeline_num_layers=24,  # bloom-560m has 24 transformer layers
        custom_pipeline_stage_id=[0] * 6 + [1] * 6 + [2] * 6 + [3] * 6,  # 6 layers per stage
device_type="cpu",
)
)
@@ -95,6 +96,10 @@ def __init__(self, config):
# set model to cuda
dist.set_device_type("cuda")
model._apply(dist.convert_to_distributed_default_setting)

model = init_env.auto_set_pipeline_stage_id(model, pipeline_parallel_size=parallel_config.pipeline_parallel_size)
# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m", use_fast=False)

41 changes: 25 additions & 16 deletions projects/mock_transformers/dist_infer_opt.py
@@ -73,37 +73,46 @@ def __init__(self, *args, **kwargs):
parallel_config = DictConfig(
dict(
data_parallel_size=1,
tensor_parallel_size=2,
tensor_parallel_size=1,
        pipeline_parallel_size=1,  # set to 1; pipeline parallelism not supported here yet
pipeline_num_layers=None,
# pipeline_num_layers=12,
# custom_pipeline_stage_id= [0]*3 + [1]*3 + [2]*3 + [3]*3,
device_type="cpu",
)
)
dist.setup_dist_util(parallel_config)

placement_sbp_dict = dict(
placement=flow.env.all_device_placement("cuda"),
sbp=flow.sbp.broadcast,
)

# initialize and load the model
model = AutoModelForCausalLM.from_pretrained("facebook/opt-2.7b", torch_dtype=flow.float16)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=flow.float16)
# set model to cuda
dist.set_device_type("cuda")
model._apply(dist.convert_to_distributed_default_setting)

model = init_env.auto_set_pipeline_stage_id(model, pipeline_parallel_size=parallel_config.pipeline_parallel_size)

# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b", use_fast=False)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=False)

# get input_ids
prompt = "Hello, I'm am conscious and"
input_ids = tokenizer(prompt, return_tensors="np").input_ids
input_ids = flow.from_numpy(input_ids)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# input_ids = flow.from_numpy(input_ids)
input_ids = input_ids.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)

# generate id
placement_sbp_dict = dict(
placement=flow.env.all_device_placement("cuda"),
sbp=flow.sbp.broadcast,
)
with global_mode(True, **placement_sbp_dict):
generated_ids = model.generate(input_ids, max_length=30)
out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(out_put_ids)
for i in range(100):
with global_mode(True, **placement_sbp_dict):
model = init_env.compile_auto_placement(
model,
input_ids
)
generated_ids = model.generate(input_ids, max_length=30)
out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(out_put_ids)
186 changes: 186 additions & 0 deletions projects/mock_transformers/init_env.py
@@ -18,9 +18,15 @@

flow.mock_torch.enable()


import copy # noqa
import onefx as fx # noqa
from typing import List, Dict, Any # noqa
from oneflow import Tensor, nn # noqa
from transformers import modeling_utils # noqa
from transformers.modeling_utils import _load_state_dict_into_model # noqa
from libai.utils import distributed as dist #noqa



# ---------------- mock _load_state_dict_into_model ------------------
@@ -111,3 +117,183 @@ def flow_softmax(*args, **kwargs):


nn.functional.softmax = flow_softmax

# =============================================
# ---------------- helper function definitions ----------------
# =============================================

def set_pipeline_stage_id(self, placement):
for param in self.parameters():
param.data = param.data.to_global(placement=placement)

nn.Module.set_pipeline_stage_id = set_pipeline_stage_id
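# Usage sketch (hypothetical module path, not called by the scripts above): once the
# method is patched onto nn.Module, any submodule can be pinned to a pipeline stage in
# one call, e.g.
#   model.transformer.h[0].set_pipeline_stage_id(dist.get_layer_placement(0))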


def sizeof_fmt(num, suffix='B'):
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
if abs(num) < 1024.0:
return f"{num:.2f} {unit}{suffix}"
num /= 1024.0
return f"{num:.2f} Yi{suffix}"

def print_model(model, depth=0, max_depth=2, last_child=False, prefix=''):
indent = " "
stage_str = ""
if hasattr(model, "layer_idx"):
layer_idx = getattr(model, "layer_idx")
stage_idx = getattr(model, "stage_idx")
same_placement = True
for path, module in model.named_modules():
if getattr(module, "layer_idx") != layer_idx:
same_placement = False
if same_placement:
stage_str = f" stage{stage_idx}_ranks{dist.get_layer_placement(layer_idx).ranks} "

if depth > max_depth:
return
    if isinstance(model, nn.Sequential):
        # nn.Sequential is a subclass of nn.Module, so it must be checked first or this branch never runs.
        print(indent * depth + ("└─" if last_child else "├─") + prefix + str(model.__class__.__name__) + ": " + str(len(list(model.named_children()))) + " modules")
    elif isinstance(model, nn.Module):
        params = sum(p.numel() for p in model.parameters())
        print(indent * depth + ("└─" if last_child else "├─") + prefix + str(model.__class__.__name__) + ": " + stage_str + sizeof_fmt(params) + " params")
    else:
        print(indent * depth + ("└─" if last_child else "├─") + prefix + str(type(model).__name__))
    children = list(model.named_children())
    for i, (name, child) in enumerate(children):
        print_model(child, depth=depth + 1, max_depth=max_depth, last_child=(i == len(children) - 1), prefix=f"[{name}] ")


def auto_set_pipeline_stage_id(model, pipeline_parallel_size=1):
# Define a local variable to record the number of repeated and integer layers encountered
count = 0
max_depth=1
name_stage_dict = {}
# Iterate over all submodules and paths of the model
for path, module in model.named_modules():
# Get the name and class of the module
name = path.split(".")[-1]
prefix_path = ".".join(path.split(".")[:-1])
module_cls = type(module)

        # Check whether the module name is a digit, i.e. whether it could be one of the repeated, integer-named layers
if name.isdigit():
            # Check whether this layer is repeated, i.e. whether another module with the same parent path and class exists in named_modules
repeated = False
for n, m in model.named_modules():
prefix_n = ".".join(n.split(".")[:-1])
if m is not module and prefix_n == prefix_path and type(m) == module_cls:
max_depth = max(len(n.split(".")), max_depth)
repeated = True
if repeated:
count += 1
# print(f"Layer {name} with path {path} is repeated. {count}")

name_stage_dict[path] = max(count-1, 0)


length = (count + pipeline_parallel_size - 1) // pipeline_parallel_size
param_id_set = set() # skip shared weight param

for path, module in model.named_modules():
# Add to_global to the parameter
layer_idx = name_stage_dict[path]
stage_idx = layer_idx // length
setattr(module, "stage_idx", stage_idx)
setattr(module, "layer_idx", layer_idx)
if len(path.split(".")) >= max_depth or len(list(module.named_children())) == 0:
for param in module.parameters():
if id(param) not in param_id_set:
param.data = param.data.to_global(placement=dist.get_layer_placement(layer_idx))
param_id_set.add(id(param))

if dist.is_main_process():
print_model(model, depth=0, max_depth=100 if max_depth==1 else max_depth)
# Return the modified model
return model
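
# Worked example of the split above (mirrors the bloom script's manual config): with 24
# repeated decoder layers and pipeline_parallel_size=4, length = (24 + 4 - 1) // 4 = 6,
# so layer_idx 0-5 map to stage 0, 6-11 to stage 1, 12-17 to stage 2 and 18-23 to
# stage 3, i.e. the same split as custom_pipeline_stage_id = [0]*6 + [1]*6 + [2]*6 + [3]*6.
#   model = auto_set_pipeline_stage_id(model, pipeline_parallel_size=4)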

# --------------- fx utilities for automatically changing placement ----------------------


class AutoPlacementInterpreter(fx.Interpreter):
def __init__(self, mod : flow.nn.Module):
gm = fx.symbolic_trace(mod)
super().__init__(gm)

self.global_infos : Dict[int, Dict[int, Any]] = {}
self.node_id = 0

def run(self, *args) -> Any:
return_val = super().run(*args)
return return_val

def run_node(self, n : fx.Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)
global_info_to_replace = None
max_rank_sum = -1
for arg in args:
if not isinstance(arg, flow.Tensor):
continue
if arg.is_local or len(arg.placement.ranks) == 0:
continue
placement = arg.placement
sbp = arg.sbp
# print(sum(placement.ranks))
if max_rank_sum < sum(placement.ranks):
max_rank_sum = sum(placement.ranks)
global_info_to_replace = (placement, sbp)
# elif max_rank_sum == sum(placement.ranks) and zip(placement_to_replace.ranks, placement.ranks).all(lambda x, y: x == y):
# raise ValueError("There is two different placements with same rank sum. "
# + f"They are {placement_to_replace} and {placement}.")

if max_rank_sum == -1:
self.node_id += 1
return_val = super().run_node(n)
return return_val

        for arg_id, arg in enumerate(args):
            # Skip non-tensor and local-tensor args; only global tensors carry a placement.
            if not isinstance(arg, flow.Tensor) or arg.is_local:
                continue
            if sum(arg.placement.ranks) < max_rank_sum:
                self.global_infos.setdefault(self.node_id, {})[arg_id] = global_info_to_replace
                n.update_arg(arg_id, arg.to_global(global_info_to_replace[0], global_info_to_replace[1]))

        self.node_id += 1
        return_val = super().run_node(n)
        return return_val
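
# Illustration of the heuristic above (hypothetical ranks): if one tensor input of a
# node lives on placement ranks [0, 1] (rank sum 1) and another on ranks [2, 3]
# (rank sum 5), the first is recorded in global_infos and moved to the second input's
# placement/sbp before the node executes.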


def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[int, Any]]) -> flow.nn.Module:
model = copy.deepcopy(model)
fx_model: fx.GraphModule = fx.symbolic_trace(model)

    for node_id, node in enumerate(fx_model.graph.nodes):
        if node_id not in global_info_dict:
continue

for idx, arg in enumerate(node.args):
            if idx not in global_info_dict[node_id]:
continue
global_info = global_info_dict[node_id][idx]
new_node = fx.Node(fx_model.graph, f"auto_placement_{node_id}_{idx}", "call_function", flow.to_global, (arg, global_info[0], global_info[1]), {})
node.prepend(new_node)
node.update_arg(idx, new_node)

fx_model.graph.lint()
fx_model.recompile()
return fx_model

def compile_auto_placement(model: flow.nn.Module, input_x: flow.Tensor):
assert input_x.is_global
interpret = AutoPlacementInterpreter(model)
interpret.run(input_x)
model = add_auto_placement(model, interpret.global_infos)
return model

# b = flow.ones(
# (2,2),
# sbp=[flow.sbp.broadcast, flow.sbp.broadcast],
# placement=flow.placement("cuda", ranks=[[2], [3]])
# )
# demo_module = demoModule()
# interpret = AutoPlacementInterpreter(demo_module)
# c = interpret.run(b)
# model = add_auto_placement(demo_module, interpret.global_infos)
# print(model.code)
# print(model(b))