-
Notifications
You must be signed in to change notification settings - Fork 52
/
deepseek_moe_w4a16.py
120 lines (95 loc) · 3.23 KB
/
deepseek_moe_w4a16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map
# Mixture-of-Experts checkpoint to quantize.
MODEL_ID = "deepseek-ai/DeepSeek-V2.5"

# Calibration dataset and sampling limits.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Adjust num_gpus to the number of GPUs you want to use. If not enough memory
# is available, some layers will automatically be offloaded to CPU.
device_map = calculate_offload_device_map(
    MODEL_ID,
    num_gpus=2,
    reserve_for_hessians=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Load the model in bfloat16 using the placement computed above.
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the calibration split, then keep a fixed-seed shuffled subset of
# NUM_CALIBRATION_SAMPLES rows.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)
ds = ds.select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
    """Render one conversation with the tokenizer's chat template.

    Args:
        example: A dataset row; its "messages" entry is passed to the
            module-level tokenizer's ``apply_chat_template``.

    Returns:
        A dict with a single "text" key holding the templated result
        (``tokenize=False`` is passed, so no tokenization is requested here).
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}
# Run preprocess over the calibration dataset.
ds = ds.map(preprocess)
# Tokenize inputs.
def tokenize(sample):
    """Tokenize the "text" field of one sample with the module-level tokenizer.

    Args:
        sample: A dataset row containing a "text" entry.

    Returns:
        Whatever the tokenizer call returns for that text, truncated to
        MAX_SEQUENCE_LENGTH and without padding.
    """
    encode_options = {
        "padding": False,
        "max_length": MAX_SEQUENCE_LENGTH,
        "truncation": True,
        # NOTE(review): presumably disabled because the chat template already
        # inserted the special tokens -- confirm.
        "add_special_tokens": False,
    }
    return tokenizer(sample["text"], **encode_options)
# Tokenize every sample and drop the columns that existed before tokenization.
ds = ds.map(tokenize, remove_columns=ds.column_names)

# Path to the llmcompressor recipe for W4A16 quantization.
# Since the MoE gate layers are sensitive to quantization, the recipe is
# expected to list them under "ignore" so they remain at full precision.
recipe = "deepseek_recipe_w4a16.yaml"

# Output directory: the model name (the part after the org prefix) + "-W4A16".
SAVE_DIR = f'{MODEL_ID.split("/")[1]}-W4A16'

# Apply the recipe in one shot over the calibration data and save the result
# in compressed form.
oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    output_dir=SAVE_DIR,
    save_compressed=True,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
input_ids = input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")
# Run the model on vLLM (optional: skipped when vLLM is not importable).
try:
    from vllm import LLM, SamplingParams
except ImportError:
    vllm_installed = False
else:
    vllm_installed = True

if vllm_installed:
    print("vLLM installed, running using vLLM")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Load the checkpoint that was just written to SAVE_DIR.
    llm = LLM(
        model=SAVE_DIR,
        tensor_parallel_size=2,
        trust_remote_code=True,
        # NOTE(review): 1042 looks like a typo for 1024 -- confirm before changing.
        max_model_len=1042,
        dtype=torch.half,
    )

    prompts = [
        "The capital of France is",
        "The president of the US is",
        "My name is",
    ]
    outputs = llm.generate(prompts, sampling_params)

    print("================= vLLM GENERATION ======================")
    for output in outputs:
        assert output
        print("PROMPT", output.prompt)
        print("GENERATED TEXT", output.outputs[0].text)