diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index ff47ab171..089160e66 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -20,6 +20,7 @@ This folder contains the model builder for quickly creating optimized and quanti
   - [Use 8 Bits Quantization in QMoE](#use-8-bits-quantization-in-qmoe)
   - [Hugging Face Authentication](#hugging-face-authentication)
   - [Use QDQ Pattern for Quantization](#use-qdq-pattern-for-quantization)
+  - [LoRA Models](#lora-models)
 - [Unit Testing Models](#unit-testing-models)
   - [Option 1: Use the model builder directly](#option-1-use-the-model-builder-directly)
   - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder)
@@ -217,6 +218,20 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p
 python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=1
 ```
 
+#### LoRA Models
+
+This scenario is where you have a fine-tuned model with LoRA adapters that can be loaded in the Hugging Face style via [PEFT](https://github.com/huggingface/peft).
+
+- path_to_local_folder_on_disk = location where the base model's weights are stored
+
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p fp16 -e execution_provider -c cache_dir_to_store_temp_files --extra_options adapter_path=path_to_adapter_files
+
+# From source:
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p fp16 -e execution_provider -c cache_dir_to_store_temp_files --extra_options adapter_path=path_to_adapter_files
+```
+
 ### Unit Testing Models
 
 This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it.
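The `adapter_path` flow above assumes the adapter folder loads cleanly with PEFT. As a minimal sketch (the paths below are placeholders), this is the load step that `make_model` in `builder.py` performs before mapping modules to ONNX ops:

```python
# Minimal sketch, paths are placeholders; verify the adapter loads before exporting.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("path_to_local_folder_on_disk")
model = PeftModel.from_pretrained(base, "path_to_adapter_files")  # attaches the LoRA modules

# Each LoRA-wrapped Linear now carries base_layer, lora_A, lora_B, and scaling,
# which are the attributes the builder reads when emitting the MatMul-LoRA subgraph.
```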
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 8f6117b94..7951de8e2 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -39,8 +39,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.model_type = config.architectures[0]
         self.io_dtype = io_dtype      # {'fp16', 'fp32'}
         self.onnx_dtype = onnx_dtype  # {"int4", "fp16", "fp32"}
-        self.use_qdq = extra_options.get("use_qdq", False)
         self.quant_type = config.quantization_config["quant_method"] if hasattr(config, "quantization_config") else None
+        self.adapter_path = extra_options["adapter_path"] if "adapter_path" in extra_options else None
 
         self.cache_dir = cache_dir
         self.filename = extra_options["filename"] if "filename" in extra_options else "model.onnx"
@@ -54,7 +54,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.nodes = []
 
         # EP-specific variables
-        enable_cuda_graph = "1" if "enable_cuda_graph" in extra_options and extra_options["enable_cuda_graph"] == "1" else "0"
+        enable_cuda_graph = "1" if extra_options.get("enable_cuda_graph", "0") == "1" else "0"
         self.ep = ep
         self.ep_attrs = {
             "cpu": {},
@@ -158,6 +158,12 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "add_offset": 0,  # Offset value for LayerNorm weight
         }
 
+        # MatMul-specific variables
+        is_lora = hasattr(config, "peft_type") and config.peft_type == "LORA"
+        self.matmul_attrs = {
+            "use_lora": is_lora,  # Use LoRA/QLoRA format
+        }
+
         # RotaryEmbedding-specific variables
         position_scale = config.rope_position_scale if hasattr(config, "rope_position_scale") else 1
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
@@ -213,11 +219,14 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         vert_block_stride = config.blocksparse_vert_stride if hasattr(config, "blocksparse_vert_stride") else 0
         homo_head = config.blocksparse_homo_head_pattern if hasattr(config, "blocksparse_homo_head_pattern") else False
         self.attention_attrs = {
+            "q_path": "",                          # Q path to attention
+            "k_path": "",                          # K path to attention
+            "v_path": "",                          # V path to attention
             "op_type": "MultiHeadAttention",       # Attention op to use
             "scale": 1 / np.sqrt(self.head_size),  # Scale value after calculating Q x K' in attention
             "use_rotemb_in_attn": False,           # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op)
             "use_packed_matmul": False,            # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
-            "block_sparse": {
+            "block_sparse": {                      # Block-sparse attention-specific variables
                 "sparse_block_size": sparse_block_size,  # Sparse block size for SparseAttention op
                 "kernel_block_size": kernel_block_size,  # Kernel block size for sparse attention
                 "local_blocks": local_blocks,            # Number of local blocks for sparse attention
@@ -237,7 +246,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             print("GroupQueryAttention (GQA) is used in this model.")
 
             # DML doesn't support packed Q/K/V for GQA yet
-            self.attention_attrs["use_packed_matmul"] = self.ep != "dml"
+            # Packed MatMul with LoRA/QLoRA is not currently supported
+            self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and not self.matmul_attrs["use_lora"]
 
             # GQA + Rot.Emb. does not require `position ids` as input
             if self.ep != "dml":
@@ -280,7 +290,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "int4": {
                 "block_size": int(extra_options["int4_block_size"]) if "int4_block_size" in extra_options else 32,
                 "accuracy_level": int(extra_options["int4_accuracy_level"]) if "int4_accuracy_level" in extra_options else 0,  # Default is 0 for non-QDQ formats, default is 4 for QDQ formats
-            }
+            },
+            "use_qdq": False,  # Use QDQ format
         }
 
         if self.quant_type is not None:
             # Create quantized attributes from quantization config
@@ -362,7 +373,7 @@ def save_model(self, out_dir):
 
         # Create ONNX model
         model = helper.make_model(
-            opset_imports=[self.clear_field(helper.make_operatorsetid('', 21 if self.use_qdq else 14), 'domain'), helper.make_operatorsetid('com.microsoft', 1)],
+            opset_imports=[self.clear_field(helper.make_operatorsetid('', 21 if self.quant_attrs["use_qdq"] else 14), 'domain'), helper.make_operatorsetid('com.microsoft', 1)],
             ir_version=7,
             producer_name="onnxruntime-genai",
             producer_version="0.0.0",
@@ -420,7 +431,7 @@ def to_int4(self, model):
             is_symmetric=True,
             accuracy_level=self.quant_attrs["int4"]["accuracy_level"],
             nodes_to_exclude=[],
-            quant_format=QuantFormat.QDQ if self.use_qdq else QuantFormat.QOperator,
+            quant_format=QuantFormat.QDQ if self.quant_attrs["use_qdq"] else QuantFormat.QOperator,
         )
         quant.process()
         return quant.model.model
@@ -666,10 +677,18 @@ def make_tanh(self, name, root_input, dtype, shape):
         self.make_value_info(output, dtype, shape=shape)
 
     def make_matmul(self, matmul, basename, root_input, **kwargs):
+        if hasattr(matmul, "base_layer"):
+            # For LoRA `MatMul`
+            return self.make_matmul_lora(matmul, basename, root_input, **kwargs)
+        else:
+            # For regular `MatMul`
+            return self.make_matmul_op(matmul, basename, root_input, **kwargs)
+
+    def make_matmul_op(self, matmul, basename, root_input, **kwargs):
         if self.onnx_dtype in {"fp16", "fp32"}:
             return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs)
         elif self.onnx_dtype == "int4":
-            if self.use_qdq:
+            if self.quant_attrs["use_qdq"]:
                 return self.make_matmul_int4_qdq(matmul, basename, root_input, **kwargs)
             else:
                 return self.make_matmul_int4(matmul, basename, root_input, **kwargs)
@@ -775,6 +794,45 @@ def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
 
         return matmul_name
 
+    def make_matmul_lora(self, matmul, basename, root_input, **kwargs):
+        # Make nodes for the MatMul-LoRA subgraph
+        #
+        #          root_input
+        #              |
+        #       +------+------+
+        #       |             |
+        #  MatMul_LoRA_A    MatMul
+        #       |             |
+        #  MatMul_LoRA_B      |
+        #       |             |
+        #       +------+------+
+        #              |
+        #         Add_LoRA_Add
+
+        basename_parts = basename.split("/")
+
+        # Make LoRA MatMul path
+        matmul_A_basename = "/".join(basename_parts[:-1] + ["lora_A"] + basename_parts[-1:])
+        matmul_A_name = self.make_matmul_op(matmul.lora_A.default, matmul_A_basename, root_input=root_input)
+        lora_A = f"{matmul_A_name}/output_0"
+
+        matmul.lora_B.default.weight *= matmul.scaling["default"]  # Fold the LoRA scaling factor into the lora_B weights
+        matmul_B_basename = "/".join(basename_parts[:-1] + ["lora_B"] + basename_parts[-1:])
+        matmul_B_name = self.make_matmul_op(matmul.lora_B.default, matmul_B_basename, root_input=lora_A)
+        lora_B = f"{matmul_B_name}/output_0"
+
+        # Make regular MatMul path
+        last_dim = matmul.base_layer.weight.shape[0]
+        matmul_name = self.make_matmul_op(matmul.base_layer, basename, root_input, **kwargs)
+
+        # Make LoRA Add node
+        add_name = "/".join(basename_parts[:-1] + ["lora", "Add"])
+        add_inputs = [f"{matmul_name}/output_0", lora_B]
+        add_shape = ["batch_size", "sequence_length", last_dim]
+        self.make_add(add_name, add_inputs, dtype=self.io_dtype, shape=add_shape)
+
+        return add_name
+
     def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs):
         if self.onnx_dtype in {"fp16", "fp32"}:
             return self.make_packed_matmul_fp16_or_fp32(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)
@@ -1378,26 +1436,22 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         #                     |
         #                   O_Add
 
-        q_input_to_attention = ""
-        k_input_to_attention = ""
-        v_input_to_attention = ""
-
         # Make MatMul nodes
         if self.attention_attrs["use_packed_matmul"]:
             # Combine 3 MatMuls into 1 packed MatMul
             qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul"
             qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input)
-            q_input_to_attention = f"{qkv_matmul_name}/output_0"
+            self.attention_attrs["q_path"] = f"{qkv_matmul_name}/output_0"
         else:
             q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul"
             q_matmul_name = self.make_matmul(attention.q_proj, q_matmul_basename, root_input)
-            q_input_to_attention = f"{q_matmul_name}/output_0"
+            self.attention_attrs["q_path"] = f"{q_matmul_name}/output_0"
             k_matmul_basename = f"/model/layers.{layer_id}/attn/k_proj/MatMul"
             k_matmul_name = self.make_matmul(attention.k_proj, k_matmul_basename, root_input)
-            k_input_to_attention = f"{k_matmul_name}/output_0"
+            self.attention_attrs["k_path"] = f"{k_matmul_name}/output_0"
             v_matmul_basename = f"/model/layers.{layer_id}/attn/v_proj/MatMul"
            v_matmul_name = self.make_matmul(attention.v_proj, v_matmul_basename, root_input)
-            v_input_to_attention = f"{v_matmul_name}/output_0"
+            self.attention_attrs["v_path"] = f"{v_matmul_name}/output_0"
 
         # Make Add nodes (if bias exists)
         q_bias_exists = attention.q_proj.bias is not None and torch.count_nonzero(attention.q_proj.bias) > 0
@@ -1408,21 +1462,21 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         if all_bias_exists and self.attention_attrs["use_packed_matmul"]:
             # Combine 3 Adds into 1 packed Add
             qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add"
-            self.make_packed_add(attention.q_proj.bias.detach().numpy(), attention.k_proj.bias.detach().numpy(), attention.v_proj.bias.detach().numpy(), qkv_add_name, root_input=q_input_to_attention)
-            q_input_to_attention = f"{qkv_add_name}/output_0"
+            self.make_packed_add(attention.q_proj.bias.detach().numpy(), attention.k_proj.bias.detach().numpy(), attention.v_proj.bias.detach().numpy(), qkv_add_name, root_input=self.attention_attrs["q_path"])
+            self.attention_attrs["q_path"] = f"{qkv_add_name}/output_0"
         else:
             if q_bias_exists:
                 q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add"
-                self.make_add_bias(attention.q_proj.bias.detach().numpy(), q_add_name, root_input=q_input_to_attention)
-                q_input_to_attention = f"{q_add_name}/output_0"
+                self.make_add_bias(attention.q_proj.bias.detach().numpy(), q_add_name, root_input=self.attention_attrs["q_path"])
+                self.attention_attrs["q_path"] = f"{q_add_name}/output_0"
             if k_bias_exists:
                 k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add"
-                self.make_add_bias(attention.k_proj.bias.detach().numpy(), k_add_name, root_input=k_input_to_attention)
-                k_input_to_attention = f"{k_add_name}/output_0"
+                self.make_add_bias(attention.k_proj.bias.detach().numpy(), k_add_name, root_input=self.attention_attrs["k_path"])
+                self.attention_attrs["k_path"] = f"{k_add_name}/output_0"
             if v_bias_exists:
                 v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add"
-                self.make_add_bias(attention.v_proj.bias.detach().numpy(), v_add_name, root_input=v_input_to_attention)
-                v_input_to_attention = f"{v_add_name}/output_0"
+                self.make_add_bias(attention.v_proj.bias.detach().numpy(), v_add_name, root_input=self.attention_attrs["v_path"])
+                self.attention_attrs["v_path"] = f"{v_add_name}/output_0"
 
         # Make RotaryEmbedding nodes
         cos_cache_name, sin_cache_name = "", ""
@@ -1430,11 +1484,11 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
             cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches(attention.rotary_emb)
         else:
             q_rotary_name = f"/model/layers.{layer_id}/attn/q_rotary/RotaryEmbedding"
-            self.make_rotary_embedding(attention.rotary_emb, q_rotary_name, root_input=q_input_to_attention, position_ids=kwargs.get("position_ids", "position_ids"))
-            q_input_to_attention = f"{q_rotary_name}/output_0"
+            self.make_rotary_embedding(attention.rotary_emb, q_rotary_name, root_input=self.attention_attrs["q_path"], position_ids=kwargs.get("position_ids", "position_ids"))
+            self.attention_attrs["q_path"] = f"{q_rotary_name}/output_0"
             k_rotary_name = f"/model/layers.{layer_id}/attn/k_rotary/RotaryEmbedding"
-            self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=k_input_to_attention, position_ids=kwargs.get("position_ids", "position_ids"))
-            k_input_to_attention = f"{k_rotary_name}/output_0"
+            self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=self.attention_attrs["k_path"], position_ids=kwargs.get("position_ids", "position_ids"))
+            self.attention_attrs["k_path"] = f"{k_rotary_name}/output_0"
 
         # Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA)
         past_k = f"past_key_values.{layer_id}.key"
@@ -1442,14 +1496,14 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         present_k = f"present.{layer_id}.key"
         present_v = f"present.{layer_id}.value"
         if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention":
-            k_input_to_attention = self.make_repeat_kv(layer_id, root_input=k_input_to_attention, past_kv=past_k, present_kv=present_k)
-            v_input_to_attention = self.make_repeat_kv(layer_id, root_input=v_input_to_attention, past_kv=past_v, present_kv=present_v)
+            self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k)
+            self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v)
            past_k, past_v, present_k, present_v = "", "", "", ""
 
         # Make attention node (e.g. MultiHeadAttention, GroupQueryAttention, etc.)
attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" self.make_attention_op( - attn_name, q_path=q_input_to_attention, k_path=k_input_to_attention, v_path=v_input_to_attention, + attn_name, q_path=self.attention_attrs["q_path"], k_path=self.attention_attrs["k_path"], v_path=self.attention_attrs["v_path"], past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, cos_cache=cos_cache_name, sin_cache=sin_cache_name, **kwargs, ) @@ -1471,23 +1525,16 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): self.layernorm_attrs["skip_input"] = f"{o_matmul_name if not o_bias_exists else o_add_name}/output_0" def make_attention_unpacked(self, layer_id, attention, root_input, **kwargs): - q_size = self.num_attn_heads * self.head_size - kv_size = self.num_kv_heads * self.head_size qkv_proj = 'qkv_proj' if hasattr(attention, 'qkv_proj') else 'query_key_value' qkv_linear = eval(f"attention.{qkv_proj}") - attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) - attention.q_proj.weight = torch.nn.Parameter(qkv_linear.weight[: q_size, :]) - attention.q_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[: q_size]) - - attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.k_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size : q_size + kv_size, :]) - attention.k_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size]) - - attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.v_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size + kv_size :, :]) - attention.v_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :]) + if hasattr(qkv_linear, "base_layer"): + # For LoRA packed `MatMul` + return self.make_attention_unpacked_lora(layer_id, attention, qkv_linear, root_input, **kwargs) + else: + # For regular packed `MatMul` + return self.make_attention_unpacked_regular(layer_id, attention, qkv_linear, root_input, **kwargs) # Delete original packed weights and any references to them (e.g. 
`del qkv_linear` isn't sufficient) del qkv_linear @@ -1496,6 +1543,72 @@ def make_attention_unpacked(self, layer_id, attention, root_input, **kwargs): else: del attention.query_key_value + def make_attention_unpacked_lora(self, layer_id, attention, qkv_linear, root_input, **kwargs): + from peft.tuners.lora.layer import LoraLayer + + q_size = self.num_attn_heads * self.head_size + kv_size = self.num_kv_heads * self.head_size + + # Create Q/K/V base layers + q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) + q_proj.weight = torch.nn.Parameter(qkv_linear.weight[: q_size, :], requires_grad=False) + q_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[: q_size], requires_grad=False) + + k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + k_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size : q_size + kv_size, :], requires_grad=False) + k_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size], requires_grad=False) + + v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + v_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size + kv_size :, :], requires_grad=False) + v_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :], requires_grad=False) + + # Create Q/K/V lora_B layers + lora_B = qkv_linear.lora_B.default + + q_lora_B = torch.nn.Linear(in_features=q_size, out_features=q_size) + q_lora_B.weight = torch.nn.Parameter(lora_B.weight[: q_size, :], requires_grad=False) + q_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[: q_size], requires_grad=False) + + k_lora_B = torch.nn.Linear(in_features=q_size, out_features=kv_size) + k_lora_B.weight = torch.nn.Parameter(lora_B.weight[q_size : q_size + kv_size, :], requires_grad=False) + k_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[q_size : q_size + kv_size], requires_grad=False) + + v_lora_B = torch.nn.Linear(in_features=q_size, out_features=kv_size) + v_lora_B.weight = torch.nn.Parameter(lora_B.weight[q_size + kv_size :, :], requires_grad=False) + v_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[q_size + kv_size :], requires_grad=False) + + # Create Q/K/V LoRA layers + attention.q_proj = LoraLayer(q_proj) + attention.q_proj.lora_A = qkv_linear.lora_A + attention.q_proj.lora_B.default = q_lora_B + attention.q_proj.scaling = qkv_linear.scaling + + attention.k_proj = LoraLayer(k_proj) + attention.k_proj.lora_A = qkv_linear.lora_A + attention.k_proj.lora_B.default = k_lora_B + attention.k_proj.scaling = qkv_linear.scaling + + attention.v_proj = LoraLayer(v_proj) + attention.v_proj.lora_A = qkv_linear.lora_A + attention.v_proj.lora_B.default = v_lora_B + attention.v_proj.scaling = qkv_linear.scaling + + def make_attention_unpacked_regular(self, layer_id, attention, qkv_linear, root_input, **kwargs): + q_size = self.num_attn_heads * self.head_size + kv_size = self.num_kv_heads * self.head_size + + attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) + attention.q_proj.weight = torch.nn.Parameter(qkv_linear.weight[: q_size, :], requires_grad=False) + attention.q_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[: q_size], requires_grad=False) + + attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + attention.k_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size : q_size + kv_size, :], 
requires_grad=False) + attention.k_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size], requires_grad=False) + + attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + attention.v_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size + kv_size :, :], requires_grad=False) + attention.v_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :], requires_grad=False) + def make_mlp(self, layer_id, mlp, root_input): if self.mlp_attrs["use_proj"]: self.make_mlp_proj(layer_id, mlp, root_input) @@ -1505,13 +1618,50 @@ def make_mlp(self, layer_id, mlp, root_input): raise NotImplementedError(f"The MLP layer type is not set.") def make_mlp_unpacked(self, layer_id, mlp, root_input): + if hasattr(mlp, "base_layer"): + # For LoRA packed `MatMul` + return self.make_mlp_unpacked_lora(layer_id, mlp, root_input) + else: + # For regular packed `MatMul` + return self.make_mlp_unpacked_regular(layer_id, mlp, root_input) + + def make_mlp_unpacked_lora(self, layer_id, mlp, root_input): + from peft.tuners.lora.layer import LoraLayer + + # Create GateProj/UpProj base layers + gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) + gate_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[ : self.intermediate_size, :], requires_grad=False) + + up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) + up_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[self.intermediate_size :, :], requires_grad=False) + + # Create GateProj/UpProj lora_B layers + lora_B = mlp.lora_B.default + + gate_proj_lora_B = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) + gate_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[ : self.intermediate_size, :], requires_grad=False) + + up_proj_lora_B = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) + up_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[self.intermediate_size :, :], requires_grad=False) + + # Create GateProj/UpProj LoRA layers + mlp.gate_proj = LoraLayer(q_proj) + mlp.gate_proj.lora_A = mlp.gate_up_proj.lora_A + mlp.gate_proj.lora_B.default = gate_proj_lora_B + mlp.gate_proj.scaling = mlp.gate_up_proj.scaling + + mlp.up_proj = LoraLayer(k_proj) + mlp.up_proj.lora_A = mlp.gate_up_proj.lora_A + mlp.up_proj.lora_B.default = up_proj_lora_B + mlp.up_proj.scaling = mlp.gate_up_proj.scaling + + def make_mlp_unpacked_regular(self, layer_id, mlp, root_input): packed_proj = getattr(mlp, "gate_up_proj", None) or getattr(mlp, "dense_h_to_4h", None) mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) mlp.gate_proj.weight = torch.nn.Parameter(packed_proj.weight[: self.intermediate_size, :]) mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) mlp.up_proj.weight = torch.nn.Parameter(packed_proj.weight[self.intermediate_size :, :]) - # Delete original packed weights del packed_proj @@ -1855,6 +2005,10 @@ def make_model(self, input_path): extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, cache_dir=self.cache_dir, token=self.hf_token, trust_remote_code=True, **extra_kwargs) + if "adapter_path" in self.extra_options: + from peft import PeftModel + model = PeftModel.from_pretrained(model, 
self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token) + # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 for module in model.modules(): @@ -2674,16 +2828,16 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): qkv_bias = attention.query_key_value.bias.view(self.num_kv_heads, (self.num_attn_heads // self.num_kv_heads) + 2, self.head_size) attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) - attention.q_proj.weight = torch.nn.Parameter(qkv_weight[:, :, :-2].reshape(q_size, q_size).T) - attention.q_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, :-2].flatten()) + attention.q_proj.weight = torch.nn.Parameter(qkv_weight[:, :, :-2].reshape(q_size, q_size).T, requires_grad=False) + attention.q_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, :-2].flatten(), requires_grad=False) attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.k_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-2]].reshape(q_size, kv_size).T) - attention.k_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-2]].flatten()) + attention.k_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-2]].reshape(q_size, kv_size).T, requires_grad=False) + attention.k_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-2]].flatten(), requires_grad=False) attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.v_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-1]].reshape(q_size, kv_size).T) - attention.v_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-1]].flatten()) + attention.v_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-1]].reshape(q_size, kv_size).T, requires_grad=False) + attention.v_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-1]].flatten(), requires_grad=False) del qkv_weight del qkv_bias @@ -2899,7 +3053,12 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid extra_kwargs = {} if os.path.isdir(input_path) else {"cache_dir": cache_dir} hf_name = input_path if os.path.isdir(input_path) else model_name hf_token = parse_hf_token(extra_options.get("hf_token", "true")) + config = AutoConfig.from_pretrained(hf_name, token=hf_token, trust_remote_code=True, **extra_kwargs) + if "adapter_path" in extra_options: + from peft import PeftConfig + peft_config = PeftConfig.from_pretrained(extra_options["adapter_path"], token=hf_token, trust_remote_code=True, **extra_kwargs) + config.update(peft_config.__dict__) # Set input/output precision of ONNX model io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16 @@ -3047,6 +3206,7 @@ def get_args(): hf_token = false/token: Use this to disable authentication with Hugging Face or provide a custom authentication token that differs from the one stored in your environment. Default behavior is to use the authentication token stored by `huggingface-cli login`. If you have already authenticated via `huggingface-cli login`, you do not need to use this flag because Hugging Face has already stored your authentication token for you. use_qdq = 1 : Use the QDQ decomposition for quantized MatMul instead of the MatMulNBits operator. 
+            adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights).
         """),
     )
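The `make_matmul_lora` subgraph is the unmerged form of a LoRA update: `MatMul(x, W_base) + MatMul(MatMul(x, A), s * B)`. A small self-contained check of that identity (shapes and values here are arbitrary, not from this diff), with the scaling factor folded into `lora_B` exactly as the diff does in place:

```python
import torch

hidden, rank, out_dim = 8, 2, 8
x = torch.randn(1, 3, hidden)      # (batch_size, sequence_length, hidden)
W = torch.randn(out_dim, hidden)   # base_layer weight
A = torch.randn(rank, hidden)      # lora_A weight
B = torch.randn(out_dim, rank)     # lora_B weight
s = 0.5                            # scaling["default"] = lora_alpha / r

# Two-MatMul LoRA path with scaling folded into B, plus the base path...
unmerged = x @ W.T + (x @ A.T) @ (s * B).T
# ...matches a single MatMul against the merged weight W + s * (B @ A).
merged = x @ (W + s * (B @ A)).T
assert torch.allclose(unmerged, merged, atol=1e-4)
```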