Commit a147137
fix bnb loading (#2529)
vince62s authored Nov 29, 2023
1 parent 78c8908 commit a147137
Showing 1 changed file with 13 additions and 2 deletions.
onmt/model_builder.py
@@ -95,13 +95,24 @@ def load_test_model(opt, device_id=0, model_path=None):

     model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])

-    if hasattr(model_opt, "quant_type") and model_opt.quant_type not in [
+    if hasattr(model_opt, "quant_type") and model_opt.quant_type in [
         "llm_awq",
         "aawq_gemm",
         "aawq_gemv",
-    ]:
+    ]:  # if the loaded model is a awq quantized one, inference config cannot overwrite this
         if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
             raise ValueError(
                 "Model is a awq quantized model, cannot overwrite with another quant method"
             )
+
+    elif hasattr(opt, "quant_type") and opt.quant_type not in [
+        "llm_awq",
+        "aawq_gemm",
+        "aawq_gemv",
+    ]:  # we still want to be able to load fp16/32 models with bnb 4bit to minimize ram footprint
+        model_opt.quant_layers = opt.quant_layers
+        model_opt.quant_type = opt.quant_type
+        model_opt.lora_layers = []

     if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel":
         model_opt.world_size = opt.world_size
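
For readers skimming the hunk, the sketch below restates its branch logic as a standalone function. The option attributes and the awq type list come straight from the diff; the resolve_quant harness, the SimpleNamespace stand-ins, and the example values (bnb_NF4, the layer names) are illustrative assumptions, not the actual onmt API.

# Sketch of the hunk's quantization-override rule (hypothetical harness,
# not the real load_test_model).
from types import SimpleNamespace

AWQ_TYPES = ["llm_awq", "aawq_gemm", "aawq_gemv"]

def resolve_quant(model_opt, opt):
    # Case 1: the checkpoint is awq-quantized; the inference config
    # cannot switch it to a different quantization method.
    if hasattr(model_opt, "quant_type") and model_opt.quant_type in AWQ_TYPES:
        if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
            raise ValueError("awq model: cannot overwrite with another quant method")
    # Case 2: a non-awq inference config (e.g. bnb 4bit) may be applied on
    # top of an fp16/32 checkpoint to shrink the RAM footprint at load time.
    elif hasattr(opt, "quant_type") and opt.quant_type not in AWQ_TYPES:
        model_opt.quant_layers = opt.quant_layers
        model_opt.quant_type = opt.quant_type
        model_opt.lora_layers = []
    return model_opt

# Example: an fp16 checkpoint (no quant_type attribute) loaded with a
# bnb 4bit inference config; the values below are made up for illustration.
ckpt = SimpleNamespace(quant_layers=[], lora_layers=[])
infer = SimpleNamespace(quant_type="bnb_NF4", quant_layers=["w_1", "w_2"])
print(resolve_quant(ckpt, infer).quant_type)  # -> bnb_NF4

Note that before this commit the first condition read "not in", so any non-awq checkpoint, including a bnb-quantized one, fell into the awq error path; flipping it to "in" and adding the elif branch is what makes bnb loading work again.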
