Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qwen devices #553

Merged
merged 18 commits into from
Sep 20, 2024
1 change: 0 additions & 1 deletion libai/models/utils/model_loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import oneflow as flow
from safetensors import safe_open
from termcolor import colored
from safetensors import safe_open

import libai.utils.distributed as dist
from libai.config import LazyCall
Expand Down
6 changes: 5 additions & 1 deletion libai/tokenizer/tokenization_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,11 @@ def encode(self, text, return_tensors=None, is_global=False, device="cuda", **kw
self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
]
token_ids_list = self.convert_to_tensors(
token_ids_list, return_tensors=return_tensors, is_global=is_global, **kwargs
token_ids_list,
return_tensors=return_tensors,
is_global=is_global,
device=device,
**kwargs,
)
return token_ids_list
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
Expand Down
67 changes: 67 additions & 0 deletions projects/Qwen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@

### 推理

- cuda PASS

```bash
python projects/Qwen/pipeline.py --model_path=/root/models/Qwen1.5-7B-Chat --mode=huggingface
```

- npu PASS

```bash
python projects/Qwen/pipeline.py --model_path=/data0/hf_models/qwen2/Qwen1.5-7B-Chat --mode=huggingface --device=npu
```

- xpu PASS

```bash
python projects/Qwen/pipeline.py --model_path=/root/models/Qwen1.5-7B-Chat --mode=huggingface --device=xpu
```

### 训练

- data preparation

```bash
python projects/Qwen/utils/data_prepare.py
```

- cuda PASS

```bash
export NUM_GPUS=8
python3 -m oneflow.distributed.launch \
--nproc_per_node ${NUM_GPUS} \
--nnodes 1 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 12345 \
tools/train_net.py --config-file=projects/Qwen/configs/qwen_sft.py \
graph.enabled=True \
train.input_placement_device="cuda" \
train.dist.device_type="cuda" \
train.dist.pipeline_parallel_size=${NUM_GPUS}
```
A100-PCIE-40GB x 4 OOM

- xpu OOM

```bash
export NUM_GPUS=1
python3 -m oneflow.distributed.launch \
--nproc_per_node ${NUM_GPUS} \
--nnodes 1 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 12345 \
tools/train_net.py --config-file=projects/Qwen/configs/qwen_sft.py \
graph.enabled=False \
train.input_placement_device="xpu" \
train.dist.device_type="xpu" \
train.dist.pipeline_parallel_size=${NUM_GPUS}
```

- npu 没有测,应该不行


Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from omegaconf import DictConfig, OmegaConf

from configs.common.train import train
from libai.config import LazyCall
from projects.Qwen.qwen2 import Qwen2ForCausalLM
from projects.Qwen.tokenizer import Qwen2Tokenizer
from configs.common.train import train


cfg = dict(
# Model
Expand Down Expand Up @@ -49,7 +48,7 @@
eos_token_id=151645,
pad_token_id=151643,
# train
pretrained_model_path="/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B",
pretrained_model_path="/root/models/Qwen1.5-7B-Chat",
)

cfg = DictConfig(cfg)
Expand All @@ -58,6 +57,6 @@
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(Qwen2Tokenizer)(
vocab_file="/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B/vocab.json",
merges_file="/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B/merges.txt",
# vocab_file="/root/models/Qwen1.5-7B/vocab.json",
# merges_file="/root/models/Qwen/Qwen1.5-7B/merges.txt",
)
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
import os

from omegaconf import OmegaConf

from configs.common.models.graph import graph
from configs.common.optim import optim
from configs.common.train import train
from libai.config import LazyCall
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader
from libai.evaluation import PPLEvaluator
from libai.scheduler import WarmupExponentialLR
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader

from configs.common.train import train
from configs.common.models.graph import graph
from configs.common.optim import optim

from projects.Qwen.config.qwen_config import cfg
from projects.Qwen.utils.qwen_dataset import QwenDataset
from projects.Qwen.tokenizer import Qwen2Tokenizer
from projects.Qwen.configs.qwen_config import cfg
from projects.Qwen.qwen2 import Qwen2ForCausalLM

from projects.Qwen.tokenizer import Qwen2Tokenizer
from projects.Qwen.qwen_dataset import QwenDataset

# Hyperparameters
weight_decay = 0.1
learning_rate = 5e-5
dataset_path = "/data/home/xiezipeng/libai/projects/Qwen/train_set"
pretrained_model_path = "/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B"
dataset_path = "./alpaca_data"
pretrained_model_path = "/root/models/Qwen1.5-7B-Chat"

# graph & optim
graph["enabled"] = False
Expand All @@ -35,34 +33,42 @@
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(Qwen2Tokenizer)(
vocab_file="/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B/vocab.json",
merges_file="/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B/merges.txt",
vocab_file=pretrained_model_path + "/vocab.json",
merges_file=pretrained_model_path + "/merges.txt",
)


# model
cfg.pretrained_model_path = pretrained_model_path
model = LazyCall(Qwen2ForCausalLM)(cfg=cfg)

# datasets
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(QwenDataset)(
path=dataset_path, tokenizer=tokenization.tokenizer
path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer
)
],
)
dataloader.test = [
LazyCall(build_nlp_test_loader)(
dataset=LazyCall(QwenDataset)(
path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer
),
),
]

train.update(
dict(
output_dir="./sft_result",
train_micro_batch_size=1,
test_micro_batch_size=1,
train_epoch=3,
train_epoch=1,
train_iter=1,
log_period=10,
log_period=1,
warmup_ratio=1 / 3,
num_accumulation_steps=8,
num_accumulation_steps=1,
rdma_enabled=False,
amp=dict(enabled=True),
activation_checkpoint=dict(enabled=True),
Expand Down
60 changes: 52 additions & 8 deletions projects/Qwen/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import click

from libai.config import try_get_key
from libai.engine import DefaultTrainer
from libai.inference.basic import BasePipeline
from libai.utils import distributed as dist

Expand Down Expand Up @@ -67,15 +73,19 @@ def _parse_parameters(self, **pipeline_parameters):

def preprocess(self, inputs, **kwargs) -> dict:
# tokenizer encoderW
inputs = self.tokenizer.encode(inputs, return_tensors='of', is_global=True)
import oneflow as flow

inputs = flow.tensor(self.tokenizer.encode(inputs, add_bos=True, padding=True))

inputs = {
"input_ids": inputs,
}

return inputs

def forward(self, inputs, **kwargs) -> dict:
outputs = self.model.generate(inputs["input_ids"], max_length=100, **kwargs)
inputs = dist.convert_to_distributed_default_setting(inputs["input_ids"])
outputs = self.model.generate(inputs, max_length=50, **kwargs)
return {"return_ids": outputs}

def postprocess(self, model_output_dict, **kwargs) -> dict:
Expand All @@ -86,21 +96,55 @@ def postprocess(self, model_output_dict, **kwargs) -> dict:
]
return records


if __name__ == "__main__":
# ----- load huggingface checkpoint -----
def build_tokenizer(self, cfg):
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer_cfg = cfg.tokenization.tokenizer
if "vocab_file" not in tokenizer_cfg:
# If "vocab_file" does not exist in the tokenizer's config,
# set it to default as f"{model_path}/vocab.json"
tokenizer_cfg.vocab_file = str(Path(self.model_path).joinpath("vocab.json"))
if "merges_file" not in tokenizer_cfg:
# If "merges_file" does not exist in the tokenizer's config,
# set it to default as f"{model_path}/merges.txt"
tokenizer_cfg.merges_file = str(Path(self.model_path).joinpath("merges.txt"))
tokenizer = DefaultTrainer.build_tokenizer(cfg)
return tokenizer


@click.command()
@click.option(
"--config_file",
default="projects/Qwen/configs/qwen_config.py",
help="Path to the configuration file.",
)
@click.option("--model_path", default=None, help="Path to the model checkpoint.")
@click.option(
"--mode",
default="libai",
help="Mode for the dataloader pipeline, e.g., 'libai' or 'huggingface'.",
)
@click.option(
"--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
)
def main(config_file, model_path, mode, device):
pipeline = TextGenerationPipeline(
"projects/Qwen/config/qwen_config.py",
config_file,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_num_layers=32,
model_path="/data/home/xiezipeng/hf_models/Qwen/Qwen1.5-7B",
mode="huggingface",
model_path=model_path,
mode=mode,
device=device,
)

text = ["给出3点关于保持身体健康的意见。"]

output = pipeline(inputs=text)
if dist.is_main_process():
print(output)


if __name__ == "__main__":
main()
19 changes: 19 additions & 0 deletions projects/Qwen/qwen_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import oneflow as flow
from oneflow.utils.data import Dataset

from libai.data.structures import DistTensorData, Instance


class QwenDataset(Dataset):
def __init__(self, path, tokenizer):
self.data = flow.load(path)
self.tokenizer = tokenizer

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return Instance(
input_ids=DistTensorData(self.data[index]["input_ids"]),
labels=DistTensorData(self.data[index]["labels"]),
)
37 changes: 0 additions & 37 deletions projects/Qwen/test.py

This file was deleted.

Loading
Loading