remove llm-awq dependency (conflict with autoawq) (#2540)
* remove llm-awq dependency (conflict with autoawq)
* set cuda 12.1 for docker
* awq mixtral conversion from HF hub
vince62s authored Dec 29, 2023
1 parent 05cd7cd commit acb76c3
Showing 8 changed files with 155 additions and 48 deletions.
8 changes: 1 addition & 7 deletions README.md
@@ -134,13 +134,7 @@ flash attention and `F.scaled_dot_product_attention` are a bit faster and saves

AWQ:

If you want to run inference or quantize an AWQ model you will need llm-awq and/or AutoAWQ.

For [llm-awq](https://github.com/mit-han-lab/llm-awq):
git clone https://github.com/mit-han-lab/llm-awq
cd llm-awq
pip install -e .
cd ..
If you want to run inference or quantize an AWQ model you will need AutoAWQ.

For [AutoAWQ](https://github.com/casper-hansen/AutoAWQ):
pip install autoawq
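
Once AutoAWQ is installed, quantizing a Hugging Face model looks roughly like the sketch below. This is a hedged example, not OpenNMT-py code: the model path, output directory and quant settings are illustrative assumptions.

# Sketch only: AWQ-quantize an HF model with AutoAWQ (paths/settings are assumptions).
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"  # hypothetical source model
quant_path = "mistral-7b-awq"             # hypothetical output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the fp16 model and tokenizer, run AWQ calibration, then save the quantized weights.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

The converter changes below pick up such checkpoints through the quantization_config entry in config.json.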
10 changes: 2 additions & 8 deletions docker/Dockerfile
@@ -1,4 +1,4 @@
ARG CUDA_VERSION=11.8.0
ARG CUDA_VERSION=12.1.0
FROM nvidia/cuda:$CUDA_VERSION-devel-ubuntu22.04

RUN apt-get update && apt-get install -y locales gcc g++ python3-dev
@@ -33,12 +33,6 @@ RUN pip3 install -v --no-build-isolation \
# Install flash-attention
RUN pip install flash-attn --no-build-isolation

# Install llm-awq
RUN git clone https://github.com/mit-han-lab/llm-awq && \
cd llm-awq && \
pip install -e . && \
cd ..

# Install AutoAWQ
RUN pip install autoawq

@@ -49,4 +43,4 @@ RUN pip install -e .

WORKDIR /

ENTRYPOINT /bin/bash
ENTRYPOINT /bin/bash
4 changes: 2 additions & 2 deletions docker/build.sh
@@ -17,12 +17,12 @@ cd $ROOT_DIR

ONMT_VERSION="$1"
CUDA_VERSION="$2"
[ -z "$CUDA_VERSION" ] && CUDA_VERSION="11.8.0"
[ -z "$CUDA_VERSION" ] && CUDA_VERSION="12.1.0"

IMAGE="ghcr.io/opennmt/opennmt-py"
TAG="$ONMT_VERSION-ubuntu22.04-cuda${CUDA_VERSION%.*}"

echo "Building $IMAGE:$TAG with CUDA_VERSION=$CUDA_VERSION"

docker build -t $IMAGE:$TAG --progress=plain -f docker/Dockerfile --build-arg CUDA_VERSION=$CUDA_VERSION .
docker push $IMAGE:$TAG
docker push $IMAGE:$TAG
2 changes: 2 additions & 0 deletions eval_llm/MMLU/run_mmlu_opennmt.py
@@ -167,10 +167,12 @@ def evaluate(opt):
prompt_end = format_example(test_df, i, include_answer=False)
train_prompt = gen_prompt(dev_df, task, k)
prompt = train_prompt + prompt_end
"""
while len(prompt.split()) > 768:
prompt_split = prompt.split("\n\n")
prompt_split.pop(1)
prompt = "\n\n".join(prompt_split)
"""
label = test_df.iloc[i, test_df.shape[1] - 1]
records.append({"prompt": prompt, "answer": label})
src.append(prompt.replace("\n", "⦅newline⦆"))
14 changes: 6 additions & 8 deletions onmt/model_builder.py
@@ -96,9 +96,8 @@ def load_test_model(opt, device_id=0, model_path=None):
model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])

if hasattr(model_opt, "quant_type") and model_opt.quant_type in [
"llm_awq",
"aawq_gemm",
"aawq_gemv",
"awq_gemm",
"awq_gemv",
]: # if the loaded model is an awq quantized one, inference config cannot overwrite this
if (
hasattr(opt, "quant_type")
@@ -110,9 +109,8 @@
)

elif hasattr(opt, "quant_type") and opt.quant_type not in [
"llm_awq",
"aawq_gemm",
"aawq_gemv",
"awq_gemm",
"awq_gemv",
]: # we still want to be able to load fp16/32 models with bnb 4bit to minimize ram footprint
model_opt.quant_layers = opt.quant_layers
model_opt.quant_type = opt.quant_type
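
In plainer terms: an AWQ-quantized checkpoint keeps its stored quant settings, while a non-AWQ checkpoint may still be re-quantized at load time (e.g. bnb 4bit). A standalone, hedged mirror of that precedence rule, reduced to the quant_type field only:

# Illustrative mirror of the precedence logic above; not the actual loader code.
AWQ_TYPES = {"awq_gemm", "awq_gemv"}

def effective_quant_type(ckpt_quant_type, opt_quant_type):
    if ckpt_quant_type in AWQ_TYPES:
        # Checkpoint was AWQ-quantized: inference options cannot override it.
        return ckpt_quant_type
    if opt_quant_type and opt_quant_type not in AWQ_TYPES:
        # e.g. load an fp16/fp32 checkpoint with bnb 4bit to minimize RAM footprint.
        return opt_quant_type
    return ckpt_quant_type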
@@ -327,14 +325,14 @@ def build_base_model(model_opt, vocabs):
model = replace_bnb_linear(
model, module_to_convert=nonlora_to_quant, q_type=model_opt.quant_type
)
elif model_opt.quant_type in ["llm_awq", "aawq_gemm", "aawq_gemv"]:
elif model_opt.quant_type in ["awq_gemm", "awq_gemv"]:
logger.info(
"%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
)
try:
from onmt.modules.awq_linear import replace_awq_linear
except ImportError:
raise ImportError("Install llm-awq/AutoAWQ to use awq quantized model")
raise ImportError("Install AutoAWQ to use awq quantized model")
model = replace_awq_linear(
model,
module_to_convert=nonlora_to_quant,
10 changes: 2 additions & 8 deletions onmt/modules/awq_linear.py
@@ -4,18 +4,12 @@
def replace_awq_linear(
model, module_to_convert=[], w_bit=4, group_size=128, q_type="llm_awq"
):
if q_type == "llm_awq":
try:
from awq.quantize.qmodule import WQLinear
except ImportError:
raise ImportError("Install llm-awq to use awq")
AWQLin = WQLinear
elif q_type in ["aawq_gemm", "aawq_gemv"]:
if q_type in ["awq_gemm", "awq_gemv"]:
try:
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
except ImportError:
raise ImportError("Install AutoAWQ to use awq")
if q_type == "aawq_gemm":
if q_type == "awq_gemm":
AWQLin = WQLinear_GEMM
else:
AWQLin = WQLinear_GEMV
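
A hedged usage sketch of the updated helper follows; the model object and module names are placeholders (they mirror the nonlora_to_quant list seen in model_builder.py), and AutoAWQ must be installed for the import inside the function to succeed.

# Sketch only: replace selected linear layers of a built ONMT model with AutoAWQ kernels.
from onmt.modules.awq_linear import replace_awq_linear

model = replace_awq_linear(
    model,  # an already-built OpenNMT-py model (placeholder)
    module_to_convert=["linear_query", "linear_keys", "linear_values", "final_linear"],
    w_bit=4,
    group_size=128,
    q_type="awq_gemm",  # or "awq_gemv"
)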
5 changes: 2 additions & 3 deletions onmt/opts.py
@@ -1598,9 +1598,8 @@ def _add_quant_opts(parser):
"bnb_8bit",
"bnb_FP4",
"bnb_NF4",
"llm_awq",
"aawq_gemm",
"aawq_gemv",
"awq_gemm",
"awq_gemv",
],
type=str,
help="Type of compression.",
150 changes: 138 additions & 12 deletions tools/convert_HF_llamalike.py
@@ -29,9 +29,65 @@
".feed_forward.layer_norm.weight": ".post_attention_layernorm.weight",
}
key_maps["MistralForCausalLM"] = key_maps["LlamaForCausalLM"]
ln_table = {"LlamaForCausalLM": "rms", "MistralForCausalLM": "rms"}
act_table = {"LlamaForCausalLM": "silu", "MistralForCausalLM": "silu"}
decoder_start_table = {"LlamaForCausalLM": "<s>", "MistralForCausalLM": "<s>"}
key_maps["MixtralForCausalLM"] = {
"layer_prefix": "model.layers.",
"decoder.embeddings.make_embedding.emb_luts.0.weight": "model.embed_tokens.weight",
"decoder.layer_norm.weight": "model.norm.weight",
"generator.weight": "lm_head.weight",
".self_attn.linear_query.": ".self_attn.q_proj.",
".self_attn.linear_keys.": ".self_attn.k_proj.",
".self_attn.linear_values.": ".self_attn.v_proj.",
".self_attn.final_linear.": ".self_attn.o_proj.",
".layer_norm_1.weight": ".input_layernorm.weight",
".feed_forward.gate.weight": ".block_sparse_moe.gate.weight",
".feed_forward.experts.0.w_1.": ".block_sparse_moe.experts.0.w1.",
".feed_forward.experts.0.w_2.": ".block_sparse_moe.experts.0.w2.",
".feed_forward.experts.0.w_3.": ".block_sparse_moe.experts.0.w3.",
".feed_forward.experts.0.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.1.w_1.": ".block_sparse_moe.experts.1.w1.",
".feed_forward.experts.1.w_2.": ".block_sparse_moe.experts.1.w2.",
".feed_forward.experts.1.w_3.": ".block_sparse_moe.experts.1.w3.",
".feed_forward.experts.1.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.2.w_1.": ".block_sparse_moe.experts.2.w1.",
".feed_forward.experts.2.w_2.": ".block_sparse_moe.experts.2.w2.",
".feed_forward.experts.2.w_3.": ".block_sparse_moe.experts.2.w3.",
".feed_forward.experts.2.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.3.w_1.": ".block_sparse_moe.experts.3.w1.",
".feed_forward.experts.3.w_2.": ".block_sparse_moe.experts.3.w2.",
".feed_forward.experts.3.w_3.": ".block_sparse_moe.experts.3.w3.",
".feed_forward.experts.3.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.4.w_1.": ".block_sparse_moe.experts.4.w1.",
".feed_forward.experts.4.w_2.": ".block_sparse_moe.experts.4.w2.",
".feed_forward.experts.4.w_3.": ".block_sparse_moe.experts.4.w3.",
".feed_forward.experts.4.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.5.w_1.": ".block_sparse_moe.experts.5.w1.",
".feed_forward.experts.5.w_2.": ".block_sparse_moe.experts.5.w2.",
".feed_forward.experts.5.w_3.": ".block_sparse_moe.experts.5.w3.",
".feed_forward.experts.5.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.6.w_1.": ".block_sparse_moe.experts.6.w1.",
".feed_forward.experts.6.w_2.": ".block_sparse_moe.experts.6.w2.",
".feed_forward.experts.6.w_3.": ".block_sparse_moe.experts.6.w3.",
".feed_forward.experts.6.layer_norm.weight": ".post_attention_layernorm.weight",
".feed_forward.experts.7.w_1.": ".block_sparse_moe.experts.7.w1.",
".feed_forward.experts.7.w_2.": ".block_sparse_moe.experts.7.w2.",
".feed_forward.experts.7.w_3.": ".block_sparse_moe.experts.7.w3.",
".feed_forward.experts.7.layer_norm.weight": ".post_attention_layernorm.weight",
}
ln_table = {
"LlamaForCausalLM": "rms",
"MistralForCausalLM": "rms",
"MixtralForCausalLM": "rms",
}
act_table = {
"LlamaForCausalLM": "silu",
"MistralForCausalLM": "silu",
"MixtralForCausalLM": "silu",
}
decoder_start_table = {
"LlamaForCausalLM": "<s>",
"MistralForCausalLM": "<s>",
"MixtralForCausalLM": "<s>",
}
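
To make the mapping concrete, a small sketch of how the converter rebuilds checkpoint keys from these tables (it reuses the key_maps dict defined above; the layer index and suffix are arbitrary examples):

# Sketch: how an ONMT parameter name is paired with its HF Mixtral counterpart.
arch = "MixtralForCausalLM"
i = 0  # decoder layer index (example)
target = ".self_attn.linear_query."
hf_key = key_maps[arch]["layer_prefix"] + str(i) + key_maps[arch][target] + "weight"
onmt_key = "decoder.transformer_layers." + str(i) + target + "weight"
# hf_key   == "model.layers.0.self_attn.q_proj.weight"
# onmt_key == "decoder.transformer_layers.0.self_attn.linear_query.weight"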


class Tokenizer:
@@ -233,7 +289,14 @@ def __init__(self, model_path: str):
sliding_window = 4096
else:
sliding_window = 0

if "num_local_experts" in config.keys():
num_experts = config["num_local_experts"]
else:
num_experts = 0
if "num_experts_per_tok" in config.keys():
num_experts_per_tok = config["num_experts_per_tok"]
else:
num_experts_per_tok = 0
if "quantization_config" in config.keys():
if (
"quant_method" in config["quantization_config"].keys()
@@ -242,22 +305,22 @@ def __init__(self, model_path: str):
if "backend" in config["quantization_config"].keys():
backend = config["quantization_config"]["backend"]
if backend == "llm-awq":
quant_type = "llm_awq"
quant_type = "awq_gemv"
elif backend == "autoawq":
if config["quantization_config"]["version"].lower() == "gemm":
quant_type = "aawq_gemm"
quant_type = "awq_gemm"
elif config["quantization_config"]["version"].lower() == "gemv":
quant_type = "aawq_gemv"
quant_type = "awq_gemv"
else:
raise ValueError("Unknown quantization config")
else:
raise ValueError("Unknown backend config")
else:
print("Backend not specified in config, using Autoawq")
if config["quantization_config"]["version"].lower() == "gemm":
quant_type = "aawq_gemm"
quant_type = "awq_gemm"
elif config["quantization_config"]["version"].lower() == "gemv":
quant_type = "aawq_gemv"
quant_type = "awq_gemv"
else:
raise ValueError("Unknown quantization config")
else:
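
The renaming means the converter now always emits AutoAWQ-style types. A hedged, standalone mirror of the branch above (the config dict passed in is illustrative):

# Illustrative mirror of the quant_type selection above; not the converter itself.
def awq_quant_type(quant_cfg):
    backend = quant_cfg.get("backend", "autoawq")  # missing backend falls back to AutoAWQ
    if backend == "llm-awq":
        return "awq_gemv"  # llm-awq checkpoints are served through the GEMV kernel
    if backend == "autoawq":
        version = quant_cfg["version"].lower()
        if version == "gemm":
            return "awq_gemm"
        if version == "gemv":
            return "awq_gemv"
        raise ValueError("Unknown quantization config")
    raise ValueError("Unknown backend config")

print(awq_quant_type({"quant_method": "awq", "version": "GEMM"}))  # -> awq_gemm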
@@ -403,6 +466,30 @@ def get_weight(checkpoint, tensor_name):
".feed_forward.w_1.",
".feed_forward.w_2.",
".feed_forward.w_3.",
".feed_forward.experts.0.w_1.",
".feed_forward.experts.0.w_2.",
".feed_forward.experts.0.w_3.",
".feed_forward.experts.1.w_1.",
".feed_forward.experts.1.w_2.",
".feed_forward.experts.1.w_3.",
".feed_forward.experts.2.w_1.",
".feed_forward.experts.2.w_2.",
".feed_forward.experts.2.w_3.",
".feed_forward.experts.3.w_1.",
".feed_forward.experts.3.w_2.",
".feed_forward.experts.3.w_3.",
".feed_forward.experts.4.w_1.",
".feed_forward.experts.4.w_2.",
".feed_forward.experts.4.w_3.",
".feed_forward.experts.5.w_1.",
".feed_forward.experts.5.w_2.",
".feed_forward.experts.5.w_3.",
".feed_forward.experts.6.w_1.",
".feed_forward.experts.6.w_2.",
".feed_forward.experts.6.w_3.",
".feed_forward.experts.7.w_1.",
".feed_forward.experts.7.w_2.",
".feed_forward.experts.7.w_3.",
]
for target in targetlist:
if target in key_maps[arch].keys():
@@ -471,20 +558,57 @@ def get_weight(checkpoint, tensor_name):
+ ".layer_norm_res."
+ p
] = w
if ".feed_forward.layer_norm.weight" in key_maps[arch].keys():
if ".feed_forward.layer_norm." + p in key_maps[arch].keys():
w = get_weight(
checkpoint,
key_maps[arch]["layer_prefix"]
+ str(i)
+ key_maps[arch][".feed_forward.layer_norm." + p],
)
if w is not None:
onmt_safetensor[
"decoder.transformer_layers."
+ str(i)
+ ".feed_forward.layer_norm."
+ p
] = w

if ".feed_forward.gate." + p in key_maps[arch].keys():
w = get_weight(
checkpoint,
key_maps[arch]["layer_prefix"]
+ str(i)
+ key_maps[arch][".feed_forward.layer_norm.weight"],
+ key_maps[arch][".feed_forward.gate." + p],
)
if w is not None:
onmt_safetensor[
"decoder.transformer_layers."
+ str(i)
+ ".feed_forward.layer_norm.weight"
+ ".feed_forward.gate."
+ p
] = w

for j in range(num_experts):
if (
f".feed_forward.experts.{j}.layer_norm." + p
in key_maps[arch].keys()
):
w = get_weight(
checkpoint,
key_maps[arch]["layer_prefix"]
+ str(i)
+ key_maps[arch][
f".feed_forward.experts.{j}.layer_norm." + p
],
)
if w is not None:
onmt_safetensor[
"decoder.transformer_layers."
+ str(i)
+ f".feed_forward.experts.{j}.layer_norm."
+ p
] = w

if shard == 0:
vocab_size = onmt_safetensor["generator.weight"].size(0)
if opt.format == "safetensors":
@@ -741,6 +865,8 @@ def get_weight(checkpoint, tensor_name):
quant_type=quant_type,
w_bit=w_bit,
group_size=group_size,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
print("Saving the pytorch file")
if opt.output[-3:] == ".pt":
