From 1d872d740379fcbde537a644bd49959573325093 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Thu, 5 Sep 2024 15:21:31 -0700 Subject: [PATCH] build: Upgrade to 24.08, TRT-LLM 0.12.0, and Triton CLI v0.0.11 (#83) Co-authored-by: David Yastremsky <58150256+dyastremsky@users.noreply.github.com> --- README.md | 27 +- pyproject.toml | 6 +- src/triton_cli/__init__.py | 2 +- src/triton_cli/docker/Dockerfile | 8 +- src/triton_cli/profile.py | 8 +- src/triton_cli/repository.py | 6 + .../trt_llm/postprocessing/1/model.py | 21 +- .../trt_llm/postprocessing/config.pbtxt | 11 + .../trt_llm/preprocessing/1/model.py | 116 +- .../trt_llm/preprocessing/config.pbtxt | 28 +- .../templates/trt_llm/tensorrt_llm/1/model.py | 397 +++- .../trt_llm/tensorrt_llm/config.pbtxt | 21 + .../trt_llm/tensorrt_llm_bls/1/lib/decode.py | 131 +- .../tensorrt_llm_bls/1/lib/triton_decoder.py | 171 +- .../trt_llm/tensorrt_llm_bls/1/model.py | 42 +- .../trt_llm/tensorrt_llm_bls/config.pbtxt | 23 +- src/triton_cli/trt_llm/builder.py | 2 + .../gpt2/convert_checkpoint.py | 2046 ++--------------- .../llama/convert_checkpoint.py | 122 +- tests/test_cli.py | 4 +- 20 files changed, 956 insertions(+), 2236 deletions(-) diff --git a/README.md b/README.md index 93772d5..cc96537 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@ and running the CLI from within the latest corresponding `tritonserver` container image, which should have all necessary system dependencies installed. For vLLM and TRT-LLM, you can use their respective images: -- `nvcr.io/nvidia/tritonserver:24.07-vllm-python-py3` -- `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` +- `nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3` +- `nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3` If you decide to run the CLI on the host or in a custom image, please see this list of [additional dependencies](#additional-dependencies-for-custom-environments) @@ -38,6 +38,7 @@ matrix below: | Triton CLI Version | TRT-LLM Version | Triton Container Tag | |:------------------:|:---------------:|:--------------------:| +| 0.0.11 | v0.12.0 | 24.08 | | 0.0.10 | v0.11.0 | 24.07 | | 0.0.9 | v0.10.0 | 24.06 | | 0.0.8 | v0.9.0 | 24.05 | @@ -57,7 +58,7 @@ It is also possible to install from a specific branch name, a commit hash or a tag name. 
For example to install `triton_cli` with a specific tag: ```bash -GIT_REF="0.0.10" +GIT_REF="0.0.11" pip install git+https://github.com/triton-inference-server/triton_cli.git@${GIT_REF} ``` @@ -92,7 +93,7 @@ triton -h triton import -m gpt2 # Start server pointing at the default model repository -triton start --image nvcr.io/nvidia/tritonserver:24.07-vllm-python-py3 +triton start --image nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3 # Infer with CLI triton infer -m gpt2 --prompt "machine learning is" @@ -135,6 +136,8 @@ The following models have currently been tested for vLLM through the CLI: - `llama-2-7b-chat` - `llama-3-8b` - `llama-3-8b-instruct` +- `llama-3.1-8b` +- `llama-3.1-8b-instruct` #### Example @@ -146,10 +149,10 @@ docker run -ti \ --shm-size=1g --ulimit memlock=-1 \ -v ${HOME}/models:/root/models \ -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ - nvcr.io/nvidia/tritonserver:24.07-vllm-python-py3 + nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3 # Install the Triton CLI -pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.10 +pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.11 # Authenticate with huggingface for restricted models like Llama-2 and Llama-3 huggingface-cli login @@ -193,10 +196,14 @@ engine builds through the CLI: - `llama-2-7b-chat` - `llama-3-8b` - `llama-3-8b-instruct` +- `llama-3.1-8b` +- `llama-3.1-8b-instruct` > [!NOTE] -> Building a TRT-LLM engine for Llama-2-7B or Llama-3-8B models -> may require system RAM of at least 48GB of RAM. +> 1. Building a TRT-LLM engine for Llama-2-7B, Llama-3-8B, or Llama-3.1-8B +> models may require system RAM of at least 48GB of RAM. +> +> 2. Llama 3.1 may require `pip install transformers>=4.43.1` #### Example @@ -215,10 +222,10 @@ docker run -ti \ -v /tmp:/tmp \ -v ${HOME}/models:/root/models \ -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ - nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 + nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3 # Install the Triton CLI -pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.10 +pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.11 # Authenticate with huggingface for restricted models like Llama-2 and Llama-3 huggingface-cli login diff --git a/pyproject.toml b/pyproject.toml index 27f42b2..080aa17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,10 +47,10 @@ keywords = [] requires-python = ">=3.10,<4" # TODO: Add [gpu] set of dependencies for trtllm once it's available on pypi dependencies = [ - "grpcio>=1.64.0", + "grpcio>=1.65.5", "directory-tree == 0.0.4", # may remove in future "docker == 6.1.3", - "genai-perf @ git+https://github.com/triton-inference-server/client.git@r24.07#subdirectory=src/c++/perf_analyzer/genai-perf", + "genai-perf @ git+https://github.com/triton-inference-server/perf_analyzer.git@r24.08#subdirectory=genai-perf", # TODO: rely on tritonclient to pull in protobuf and numpy dependencies? 
"numpy >=1.21,<2", "protobuf>=3.7.0", @@ -59,7 +59,7 @@ dependencies = [ "rich == 13.5.2", # TODO: Test on cpu-only machine if [cuda] dependency is an issue, # Use explicit client version matching genai-perf version for tagged release - "tritonclient[all] == 2.48", + "tritonclient[all] == 2.49", "huggingface-hub >= 0.19.4", # Testing "pytest >= 8.1.1", # may remove later diff --git a/src/triton_cli/__init__.py b/src/triton_cli/__init__.py index 643ba86..1adf4f9 100644 --- a/src/triton_cli/__init__.py +++ b/src/triton_cli/__init__.py @@ -24,4 +24,4 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -__version__ = "0.0.10" +__version__ = "0.0.11" diff --git a/src/triton_cli/docker/Dockerfile b/src/triton_cli/docker/Dockerfile index 407813a..7bc4c87 100644 --- a/src/triton_cli/docker/Dockerfile +++ b/src/triton_cli/docker/Dockerfile @@ -1,9 +1,11 @@ # TRT-LLM image contains engine building and runtime dependencies -FROM nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 +FROM nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3 # Setup vLLM Triton backend RUN mkdir -p /opt/tritonserver/backends/vllm && \ - wget -P /opt/tritonserver/backends/vllm https://raw.githubusercontent.com/triton-inference-server/vllm_backend/r24.07/src/model.py + git clone -b r24.08 https://github.com/triton-inference-server/vllm_backend.git /tmp/vllm_backend && \ + cp -r /tmp/vllm_backend/src/* /opt/tritonserver/backends/vllm && \ + rm -r /tmp/vllm_backend # vLLM runtime dependencies -RUN pip install "vllm==0.5.0.post1" +RUN pip install "vllm==0.5.3.post1" "setuptools==74.0.0" diff --git a/src/triton_cli/profile.py b/src/triton_cli/profile.py index 60cb575..ad97971 100755 --- a/src/triton_cli/profile.py +++ b/src/triton_cli/profile.py @@ -34,7 +34,7 @@ # ================================================ def build_command(args: argparse.Namespace, executable: str): skip_args = ["func"] - cmd = [executable] + cmd = [executable, "profile"] for arg, value in vars(args).items(): if arg in skip_args: pass @@ -45,12 +45,6 @@ def build_command(args: argparse.Namespace, executable: str): cmd += [f"-{arg}"] else: cmd += [f"--{arg}"] - # [DLIS-6656] - Remove backend renaming. - # This allows "tensorrtllm" to be used as the backend for consistency. - # Once GenAI-Perf releases 24.05, "tensorrtllm" as the backend value - # will be supported by default. 
- elif arg == "backend" and value in ["tensorrtllm", "trtllm"]: - cmd += ["--backend", "tensorrtllm"] else: if len(arg) == 1: cmd += [f"-{arg}", f"{value}"] diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py index c06ef0a..6a853ef 100644 --- a/src/triton_cli/repository.py +++ b/src/triton_cli/repository.py @@ -92,6 +92,12 @@ "meta-llama/Meta-Llama-3-8B-Instruct": { "hf_allow_patterns": ["*.safetensors", "*.json"], }, + "meta-llama/Meta-Llama-3.1-8B": { + "hf_allow_patterns": ["*.safetensors", "*.json"], + }, + "meta-llama/Meta-Llama-3.1-8B-Instruct": { + "hf_allow_patterns": ["*.safetensors", "*.json"], + }, "gpt2": { "hf_allow_patterns": ["*.safetensors", "*.json"], "hf_ignore_patterns": ["onnx/*"], diff --git a/src/triton_cli/templates/trt_llm/postprocessing/1/model.py b/src/triton_cli/templates/trt_llm/postprocessing/1/model.py index 4ab14fb..b3415fd 100644 --- a/src/triton_cli/templates/trt_llm/postprocessing/1/model.py +++ b/src/triton_cli/templates/trt_llm/postprocessing/1/model.py @@ -142,6 +142,10 @@ def execute(self, requests): generation_logits = pb_utils.get_input_tensor_by_name( request, 'GENERATION_LOGITS') + # Get the batch index + batch_index = pb_utils.get_input_tensor_by_name( + request, 'BATCH_INDEX') + # Reshape Input # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) # tokens_batch = tokens_batch.T @@ -197,6 +201,15 @@ def execute(self, requests): np.array([[[[0.0]]]], dtype=np.float32)) outputs.append(out_generation_logits) + if batch_index: + out_batch_index = pb_utils.Tensor('OUT_BATCH_INDEX', + batch_index.as_numpy()) + outputs.append(out_batch_index) + else: + out_batch_index = pb_utils.Tensor( + 'OUT_BATCH_INDEX', np.array([[0]], dtype=np.int32)) + outputs.append(out_batch_index) + # Create InferenceResponse. You can set an error here in case # there was a problem with handling this inference request. 
# Below is an example of how you can set errors in inference @@ -224,8 +237,14 @@ def _postprocessing(self, tokens_batch, sequence_lengths): for batch_idx, beam_tokens in enumerate(tokens_batch): for beam_idx, tokens in enumerate(beam_tokens): seq_len = sequence_lengths[batch_idx][beam_idx] + # Exclude fake ids in multimodal models + fake_id_len = 0 + for i in range(seq_len): + if tokens[i] < len(self.tokenizer.vocab): + fake_id_len = i + break output = self.tokenizer.decode( - tokens[:seq_len], + tokens[fake_id_len:seq_len], skip_special_tokens=self.skip_special_tokens) outputs.append(output.encode('utf8')) return outputs diff --git a/src/triton_cli/templates/trt_llm/postprocessing/config.pbtxt b/src/triton_cli/templates/trt_llm/postprocessing/config.pbtxt index aaecb13..a1c2eb2 100644 --- a/src/triton_cli/templates/trt_llm/postprocessing/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/postprocessing/config.pbtxt @@ -61,6 +61,12 @@ input [ data_type: TYPE_FP32 dims: [ -1, -1, -1 ] optional: true + }, + { + name: "BATCH_INDEX" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true } ] output [ @@ -88,6 +94,11 @@ output [ name: "OUT_GENERATION_LOGITS" data_type: TYPE_FP32 dims: [ -1, -1, -1 ] + }, + { + name: "OUT_BATCH_INDEX" + data_type: TYPE_INT32 + dims: [ 1 ] } ] diff --git a/src/triton_cli/templates/trt_llm/preprocessing/1/model.py b/src/triton_cli/templates/trt_llm/preprocessing/1/model.py index ed09cd4..76ef8dc 100644 --- a/src/triton_cli/templates/trt_llm/preprocessing/1/model.py +++ b/src/triton_cli/templates/trt_llm/preprocessing/1/model.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json +import os from typing import List import numpy as np @@ -59,6 +60,11 @@ def initialize(self, args): add_special_tokens = model_config['parameters'].get( 'add_special_tokens') + visual_model_path = model_config['parameters']['visual_model_path'][ + 'string_value'] + if visual_model_path == "${visual_model_path}" or visual_model_path == "": + visual_model_path = None + if add_special_tokens is not None: add_special_tokens_str = add_special_tokens['string_value'].lower() if add_special_tokens_str in [ @@ -93,6 +99,28 @@ def initialize(self, args): self.tokenizer_pad_id = self.tokenizer.encode( self.tokenizer.pad_token, add_special_tokens=False)[0] + self.is_multimodal = False + if visual_model_path is not None: + self.is_multimodal = True + visual_model_path = os.path.join(visual_model_path, 'config.json') + with open(visual_model_path) as f: + visual_model_config = json.load(f) + self.model_type = visual_model_config['builder_config'][ + 'model_type'] + + assert self.model_type in [ + 'llava', 'blip2-opt' + ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava and blip2-opt" + + llm_model_path = model_config['parameters']['gpt_model_path'][ + 'string_value'] + llm_model_path = os.path.join(llm_model_path, 'config.json') + with open(llm_model_path) as f: + llm_model_config = json.load(f) + self.vocab_size = int( + llm_model_config["pretrained_config"]["vocab_size"]) + self._setup_ptable_shape(llm_model_config) + # Parse model output configs and convert Triton types to numpy types output_names = [ "INPUT_ID", "DECODER_INPUT_ID", "REQUEST_INPUT_LEN", @@ -116,6 +144,16 @@ def initialize(self, args): pb_utils.get_output_config_by_name( model_config, output_name)['data_type'])) + def _setup_ptable_shape(self, llm_model_config): + max_prompt_embedding_table_size = llm_model_config['build_config'][ + 
'max_prompt_embedding_table_size'] + max_batch_size = llm_model_config['build_config']['max_batch_size'] + + num_visual_features = max_prompt_embedding_table_size // max_batch_size + hidden_size = llm_model_config['pretrained_config']['hidden_size'] + + self.ptable_shape = (-1, num_visual_features, hidden_size) + def execute(self, requests): """`execute` must be implemented in every Python model. `execute` function receives a list of pb_utils.InferenceRequest as the only @@ -140,27 +178,17 @@ def execute(self, requests): # Every Python backend must iterate over everyone of the requests # and create a pb_utils.InferenceResponse for each of them. - logger = pb_utils.Logger for idx, request in enumerate(requests): # Get input tensors query = pb_utils.get_input_tensor_by_name(request, 'QUERY').as_numpy() + batch_size = query.shape[0] + decoder_query = pb_utils.get_input_tensor_by_name( request, 'DECODER_QUERY') if decoder_query is not None: decoder_query = decoder_query.as_numpy() - batch_dim = query.shape[0] - if batch_dim != 1: - - err_str = "Inflight batching backend expects requests with batch size of 1." - logger.log_error(err_str) - responses.append( - pb_utils.InferenceResponse( - output_tensors=[], - error=pb_utils.TritonError(err_str))) - continue - request_output_len = pb_utils.get_input_tensor_by_name( request, 'REQUEST_OUTPUT_LEN').as_numpy() @@ -190,7 +218,7 @@ def execute(self, requests): if end_id is not None: end_id = end_id.as_numpy() else: - end_id = [[self.tokenizer_end_id]] + end_id = [[self.tokenizer_end_id]] * batch_size # Take the pad_id from the input tensors # If not specified, use tokenizer to get pad_id @@ -198,7 +226,7 @@ def execute(self, requests): if pad_id is not None: pad_id = pad_id.as_numpy() else: - pad_id = [[self.tokenizer_pad_id]] + pad_id = [[self.tokenizer_pad_id]] * batch_size # Preprocessing input data. input_id, request_input_len = self._create_request(query) @@ -206,15 +234,16 @@ def execute(self, requests): decoder_input_id, request_decoder_input_len = self._create_request( decoder_query) else: - decoder_input_id = pad_id * np.ones((1, 1), np.int32) - request_decoder_input_len = 1 * np.ones((1, 1), np.int32) + decoder_input_id = pad_id * np.ones((batch_size, 1), np.int32) + request_decoder_input_len = 1 * np.ones( + (batch_size, 1), np.int32) - bad_words = self._to_word_list_format(bad_words_dict) - stop_words = self._to_word_list_format(stop_words_dict) + bad_words = self._to_word_list_format(bad_words_dict, batch_size) + stop_words = self._to_word_list_format(stop_words_dict, batch_size) embedding_bias = self._get_embedding_bias( embedding_bias_words, embedding_bias_weights, - self.embedding_bias_weights_dtype) + self.embedding_bias_weights_dtype, batch_size) # Create output tensors. You need pb_utils.Tensor # objects to create pb_utils.InferenceResponse. 
@@ -279,6 +308,43 @@ def _create_request(self, query): add_special_tokens=self.add_special_tokens)).astype( int) for s in query ] + + if self.is_multimodal: + if 'blip2' in self.model_type: + pre_prompt = None + post_prompt = None + elif 'llava' == self.model_type: + pre_prompt = "USER:\n" + post_prompt = " ASSISTANT:" + + fake_prompt_id = np.arange(self.vocab_size, + self.vocab_size + self.ptable_shape[1]) + + if pre_prompt is not None: + pre_prompt_id = np.array( + self.tokenizer.encode( + pre_prompt, + add_special_tokens=self.add_special_tokens, + padding=True)) + + if post_prompt is not None: + post_prompt_id = np.array( + self.tokenizer.encode( + post_prompt, + add_special_tokens=self.add_special_tokens, + padding=True)) + + if post_prompt is None: + start_ids = [ + np.concatenate((fake_prompt_id, ids), axis=0) + for ids in start_ids + ] + else: + start_ids = [ + np.concatenate( + (pre_prompt_id, fake_prompt_id, ids, post_prompt_id), + axis=0) for ids in start_ids + ] start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) max_len = 0 @@ -293,7 +359,8 @@ def _create_request(self, query): return start_ids, start_lengths - def _to_word_list_format(self, word_lists: List[List[str | bytes]]): + def _to_word_list_format(self, word_lists: List[List[str | bytes]], + batch_size): ''' word_lists format: len(word_lists) == batch_size @@ -303,7 +370,7 @@ def _to_word_list_format(self, word_lists: List[List[str | bytes]]): if word_lists is None: # Return an empty array of shape (1,2,0) - return np.empty([1, 2, 0], dtype="int32") + return np.empty([batch_size, 2, 0], dtype="int32") flat_ids = [] offsets = [] @@ -337,18 +404,19 @@ def _to_word_list_format(self, word_lists: List[List[str | bytes]]): (1, 0, 2)) def _get_embedding_bias(self, embedding_bias_words, embedding_bias_weights, - bias_dtype): + bias_dtype, batch_size): assert self.tokenizer != None, "need to set tokenizer" if embedding_bias_words is None or embedding_bias_weights is None: - return np.empty([1, 0], dtype=self.embedding_bias_weights_dtype) + return np.empty([batch_size, 0], + dtype=self.embedding_bias_weights_dtype) batch_embedding_bias = [] for words, weights in zip(embedding_bias_words, embedding_bias_weights): - vocab_size = self.tokenizer.vocab_size + vocab_size = len(self.tokenizer.vocab) embedding_bias = [0.] 
* vocab_size assert len(words) == len( diff --git a/src/triton_cli/templates/trt_llm/preprocessing/config.pbtxt b/src/triton_cli/templates/trt_llm/preprocessing/config.pbtxt index 165134c..18d2551 100644 --- a/src/triton_cli/templates/trt_llm/preprocessing/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/preprocessing/config.pbtxt @@ -31,18 +31,18 @@ input [ { name: "QUERY" data_type: TYPE_STRING - dims: [ -1 ] + dims: [ 1 ] }, { name: "DECODER_QUERY" data_type: TYPE_STRING - dims: [ -1 ] + dims: [ 1 ] optional: true }, { name: "REQUEST_OUTPUT_LEN" data_type: TYPE_INT32 - dims: [ -1 ] + dims: [ 1 ] }, { name: "BAD_WORDS_DICT" @@ -71,13 +71,13 @@ input [ { name: "END_ID" data_type: TYPE_INT32 - dims: [ -1 ] + dims: [ 1 ] optional: true }, { name: "PAD_ID" data_type: TYPE_INT32 - dims: [ -1 ] + dims: [ 1 ] optional: true } ] @@ -125,12 +125,12 @@ output [ { name: "OUT_END_ID" data_type: TYPE_INT32 - dims: [ -1 ] + dims: [ 1 ] }, { name: "OUT_PAD_ID" data_type: TYPE_INT32 - dims: [ -1 ] + dims: [ 1 ] } ] @@ -148,6 +148,20 @@ parameters { } } +parameters { + key: "visual_model_path" + value: { + string_value: "${visual_model_path}" + } +} + +parameters: { + key: "gpt_model_path" + value: { + string_value: "${engine_dir}" + } +} + instance_group [ { count: ${preprocessing_instance_count} diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py b/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py index 3bbf86d..51c5bc7 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py @@ -1,31 +1,67 @@ import datetime import json import os +import sys import time +from random import randint from threading import Lock, Thread import numpy as np +import torch import triton_python_backend_utils as pb_utils from torch import from_numpy +from torch.utils.dlpack import from_dlpack import tensorrt_llm.bindings.executor as trtllm -def get_input_tensor_by_name(request, name): +def get_input_tensor_by_name(request, + name, + expected_batch_size=None, + batch_index=None): tensor = pb_utils.get_input_tensor_by_name(request, name) if tensor is None: return None - return tensor.as_numpy() + if tensor.is_cpu(): + tensor = tensor.as_numpy() + else: + tensor = from_dlpack(tensor.to_dlpack()) -def get_input_scalar_by_name(request, name): - tensor = get_input_tensor_by_name(request, name) + if expected_batch_size is not None and tensor.shape[ + 0] != expected_batch_size: + raise pb_utils.TritonModelException( + f"Expected batch size doesn't match batch size for tensor {name}. 
Expected {expected_batch_size} got {tensor.shape[0]}" + ) + + if batch_index is not None and expected_batch_size is not None and batch_index >= expected_batch_size: + raise pb_utils.TritonModelException( + f"Invalid batch index in get_input_tensor_by_name for {name}") + + if batch_index is not None: + # Add leading 1 batch dimension + if isinstance(tensor, np.ndarray): + return np.expand_dims(tensor[batch_index], axis=0) + elif isinstance(tensor, torch.Tensor): + return torch.unsqueeze(tensor[batch_index], dim=0) + else: + return tensor + + +def get_input_scalar_by_name(request, + name, + expected_batch_size=1, + batch_index=0): + tensor = pb_utils.get_input_tensor_by_name(request, name) if tensor is None: return None - if tensor.size != 1: + tensor = tensor.as_numpy() + + if tensor.size != expected_batch_size: raise pb_utils.TritonModelException( - f"Expected a single value for {name}") - return tensor.item() + f"Expected a scalar tensor for tensor {name}") + + return tensor.item(batch_index) def read_parameter_as_type(value, name, pytype=str): @@ -90,88 +126,114 @@ def parse_medusa_choices(medusa_choices): return result -def get_sampling_config_from_request(request): +def get_sampling_config_from_request(request, batch_size=1, batch_index=0): kwargs = {} - kwargs['beam_width'] = get_input_scalar_by_name(request, 'beam_width') or 1 - kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k') - kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p') + kwargs['beam_width'] = get_input_scalar_by_name( + request, 'beam_width', batch_size, batch_index) or 1 + kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k', + batch_size, batch_index) + kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p', + batch_size, batch_index) kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[ 'top_p'] <= 0 else kwargs['top_p'] - kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed') - kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature') - kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length') + kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed', + batch_size, batch_index) + kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature', + batch_size, batch_index) + kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length', + batch_size, batch_index) kwargs['repetition_penalty'] = get_input_scalar_by_name( - request, 'repetition_penalty') + request, 'repetition_penalty', batch_size, batch_index) kwargs['presence_penalty'] = get_input_scalar_by_name( - request, 'presence_penalty') + request, 'presence_penalty', batch_size, batch_index) kwargs['frequency_penalty'] = get_input_scalar_by_name( - request, 'frequency_penalty') - kwargs['length_penalty'] = get_input_scalar_by_name(request, 'len_penalty') + request, 'frequency_penalty', batch_size, batch_index) + kwargs['length_penalty'] = get_input_scalar_by_name( + request, 'len_penalty', batch_size, batch_index) kwargs['top_p_min'] = get_input_scalar_by_name(request, - 'runtime_top_p_min') + 'runtime_top_p_min', + batch_size, batch_index) kwargs['top_p_reset_ids'] = get_input_scalar_by_name( - request, 'runtime_top_p_reset_ids') + request, 'runtime_top_p_reset_ids', batch_size, batch_index) kwargs['top_p_decay'] = get_input_scalar_by_name(request, - 'runtime_top_p_decay') + 'runtime_top_p_decay', + batch_size, batch_index) kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name( - request, 
'beam_search_diversity_rate') + request, 'beam_search_diversity_rate', batch_size, batch_index) kwargs['early_stopping'] = get_input_scalar_by_name( - request, 'early_stopping') + request, 'early_stopping', batch_size, batch_index) kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.SamplingConfig(**kwargs) -def get_output_config_from_request(request, exclude_input_from_output): +def get_output_config_from_request(request, + exclude_input_from_output, + batch_size=1, + batch_index=0): kwargs = {} kwargs["return_log_probs"] = get_input_scalar_by_name( - request, 'return_log_probs') + request, 'return_log_probs', batch_size, batch_index) kwargs["return_context_logits"] = get_input_scalar_by_name( - request, 'return_context_logits') + request, 'return_context_logits', batch_size, batch_index) kwargs["return_generation_logits"] = get_input_scalar_by_name( - request, 'return_generation_logits') + request, 'return_generation_logits', batch_size, batch_index) kwargs["exclude_input_from_output"] = exclude_input_from_output kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.OutputConfig(**kwargs) -def get_external_draft_tokens_config_from_request(request): +def get_external_draft_tokens_config_from_request(request, + batch_size=1, + batch_index=0): kwargs = {} - draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids') + draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids', + batch_size, batch_index) if draft_input_ids is not None: - kwargs['tokens'] = draft_input_ids.tolist() - draft_logits = get_input_tensor_by_name(request, 'draft_logits') + kwargs['tokens'] = draft_input_ids[0].tolist() + draft_logits = get_input_tensor_by_name(request, 'draft_logits', + batch_size, batch_index) if draft_logits is not None: - kwargs['logits'] = from_numpy(draft_logits) + kwargs['logits'] = from_numpy(draft_logits).squeeze() kwargs['acceptance_threshold'] = get_input_scalar_by_name( - request, 'draft_acceptance_threshold') + request, 'draft_acceptance_threshold', batch_size, batch_index) kwargs = {k: v for k, v in kwargs.items() if v is not None} if len(kwargs) > 0: return trtllm.ExternalDraftTokensConfig(**kwargs) return None -def get_prompt_tuning_config_from_request(request): +def get_prompt_tuning_config_from_request(request, + batch_size=1, + batch_index=0): # prompt_vocab_size is unused by executor. 
kwargs = {} prompt_embedding_table = get_input_tensor_by_name( - request, 'prompt_embedding_table') + request, 'prompt_embedding_table', batch_size, batch_index) if prompt_embedding_table is not None: - kwargs["embedding_table"] = from_numpy(prompt_embedding_table) + if isinstance(prompt_embedding_table, np.ndarray): + kwargs["embedding_table"] = from_numpy( + prompt_embedding_table).squeeze() + elif isinstance(prompt_embedding_table, torch.Tensor): + kwargs["embedding_table"] = from_dlpack( + prompt_embedding_table.to_dlpack()).squeeze(dim=0) kwargs = {k: v for k, v in kwargs.items() if v is not None} if len(kwargs) > 0: return trtllm.PromptTuningConfig(**kwargs) return None -def get_lora_config_from_request(request): +def get_lora_config_from_request(request, batch_size=1, batch_index=0): kwargs = {} - kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id') - lora_weights = get_input_tensor_by_name(request, 'lora_weights') + kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id', + batch_size, batch_index) + lora_weights = get_input_tensor_by_name(request, 'lora_weights', + batch_size, batch_index) if lora_weights is not None: - kwargs["weights"] = from_numpy(lora_weights) - lora_config = get_input_tensor_by_name(request, 'lora_config') + kwargs["weights"] = from_numpy(lora_weights).squeeze() + lora_config = get_input_tensor_by_name(request, 'lora_config', batch_size, + batch_index) if lora_config is not None: - kwargs["config"] = from_numpy(lora_config) + kwargs["config"] = from_numpy(lora_config).squeeze() kwargs = {k: v for k, v in kwargs.items() if v is not None} if len(kwargs) > 0: return trtllm.LoraConfig(**kwargs) @@ -184,49 +246,77 @@ def convert_request(request, exclude_input_from_output, decoupled): if input_token_ids is None: raise pb_utils.TritonModelException( "A value is required for input_ids") - input_token_ids = input_token_ids.tolist() - if len(input_token_ids) == 0: + if len(input_token_ids.shape) != 2: raise pb_utils.TritonModelException(f"Invalid format for input_ids") - inputs['input_token_ids'] = input_token_ids[0] - # input_lengths is not not used by executor. 
- inputs['max_new_tokens'] = get_input_scalar_by_name( - request, 'request_output_len') - if inputs['max_new_tokens'] is None: - raise pb_utils.TritonModelException( - "A value is required for request_output_len") - inputs['streaming'] = get_input_scalar_by_name(request, 'streaming') - if inputs['streaming'] and not decoupled: - raise pb_utils.TritonModelException( - "Streaming is only supported in decoupled mode.") - inputs['end_id'] = get_input_scalar_by_name(request, 'end_id') - inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id') - inputs['stop_words'] = convert_word_list( - get_input_tensor_by_name(request, 'stop_words_list')) - inputs['bad_words'] = convert_word_list( - get_input_tensor_by_name(request, 'bad_words_list')) - embedding_bias = get_input_tensor_by_name(request, 'embedding_bias') - if embedding_bias is not None and embedding_bias.size != 0: - inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze() - - sampling_config = get_sampling_config_from_request(request) - output_config = get_output_config_from_request(request, - exclude_input_from_output) - external_draft_tokens_config = get_external_draft_tokens_config_from_request( - request) - prompt_tuning_config = get_prompt_tuning_config_from_request(request) - lora_config = get_lora_config_from_request(request) - - return trtllm.Request( - **inputs, - sampling_config=sampling_config, - output_config=output_config, - external_draft_tokens_config=external_draft_tokens_config, - prompt_tuning_config=prompt_tuning_config, - lora_config=lora_config, - ) - - -def convert_response(response): + batch_size = input_token_ids.shape[0] + requests = [] + for batch_index in range(0, batch_size): + input_token_ids = get_input_tensor_by_name(request, 'input_ids', + batch_size, batch_index)[0] + if input_token_ids is None: + raise pb_utils.TritonModelException( + "A value is required for input_ids") + input_token_ids = input_token_ids.tolist() + if len(input_token_ids) == 0: + raise pb_utils.TritonModelException( + f"Invalid format for input_ids") + + input_length = get_input_scalar_by_name(request, 'input_lengths', + batch_size, batch_index) + if input_length is None: + input_length = len(input_token_ids) + # Trim input token ids with input_lengths + inputs['input_token_ids'] = input_token_ids[0:input_length] + + inputs['max_new_tokens'] = get_input_scalar_by_name( + request, 'request_output_len', batch_size, batch_index) + if inputs['max_new_tokens'] is None: + raise pb_utils.TritonModelException( + "A value is required for request_output_len") + inputs['streaming'] = get_input_scalar_by_name(request, 'streaming', + batch_size, batch_index) + if inputs['streaming'] and not decoupled: + raise pb_utils.TritonModelException( + "Streaming is only supported in decoupled mode.") + inputs['end_id'] = get_input_scalar_by_name(request, 'end_id', + batch_size, batch_index) + inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id', + batch_size, batch_index) + inputs['stop_words'] = convert_word_list( + get_input_tensor_by_name(request, 'stop_words_list', batch_size, + batch_index)) + inputs['bad_words'] = convert_word_list( + get_input_tensor_by_name(request, 'bad_words_list', batch_size, + batch_index)) + embedding_bias = get_input_tensor_by_name(request, 'embedding_bias', + batch_size, batch_index) + if embedding_bias is not None and embedding_bias.size != 0: + inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze() + + sampling_config = get_sampling_config_from_request( + request, batch_size, batch_index) + 
output_config = get_output_config_from_request( + request, exclude_input_from_output, batch_size, batch_index) + external_draft_tokens_config = get_external_draft_tokens_config_from_request( + request, batch_size, batch_index) + prompt_tuning_config = get_prompt_tuning_config_from_request( + request, batch_size, batch_index) + lora_config = get_lora_config_from_request(request, batch_size, + batch_index) + + requests.append( + trtllm.Request( + **inputs, + sampling_config=sampling_config, + output_config=output_config, + external_draft_tokens_config=external_draft_tokens_config, + prompt_tuning_config=prompt_tuning_config, + lora_config=lora_config, + )) + return requests + + +def convert_response(response, batch_index): if response.has_error(): return pb_utils.InferenceResponse(output_tensors=[], error=pb_utils.TritonError( @@ -266,6 +356,10 @@ def convert_response(response): np.expand_dims(np.array(result.generation_logits, np.float32), 0) if result.generation_logits is not None else np.zeros( (1, 1, 1, 1), np.float32))) + output_tensors.append( + pb_utils.Tensor("batch_index", + np.expand_dims(np.array([batch_index], np.int32), 0))) + return pb_utils.InferenceResponse(output_tensors), result.is_final @@ -313,7 +407,8 @@ def convert_decoding_mode(decoding_mode: str): def convert_timestamp_to_seconds(timestamp: str): return int( - datetime.datetime.strptime(timestamp, "%m-%d-%Y %H:%M:%S").timestamp()) + datetime.datetime.strptime(timestamp, + "%m-%d-%Y %H:%M:%S.%f").timestamp()) class TritonPythonModel: @@ -337,8 +432,6 @@ def get_kv_cache_config(self, model_config): get_parameter(model_config, "max_tokens_in_paged_kv_cache", int), "sink_token_length": get_parameter(model_config, "sink_token_length", int), - "max_attention_window": - get_parameter(model_config, "max_attention_window_size", int), "free_gpu_memory_fraction": get_parameter(model_config, "kv_cache_free_gpu_mem_fraction", float), @@ -347,6 +440,12 @@ def get_kv_cache_config(self, model_config): "onboard_blocks": get_parameter(model_config, "kv_cache_onboard_blocks", bool), } + max_attention_window_size = get_parameter(model_config, + "max_attention_window_size") + if max_attention_window_size: + kwargs["max_attention_window"] = [ + int(x) for x in max_attention_window_size.split(",") + ] kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.KvCacheConfig(**kwargs) @@ -402,6 +501,16 @@ def get_decoding_config(self, model_config): kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.DecodingConfig(**kwargs) + def get_extended_runtime_perf_knob_config(self, model_config): + kwargs = { + "multi_block_mode": + get_parameter(model_config, "multi_block_mode", bool), + "enable_context_fmha_fp32_acc": + get_parameter(model_config, "enable_context_fmha_fp32_acc", bool) + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs) + def get_executor_config(self, model_config): kwargs = { "max_beam_width": @@ -423,6 +532,16 @@ def get_executor_config(self, model_config): self.get_peft_cache_config(model_config), "decoding_config": self.get_decoding_config(model_config), + "max_queue_size": + model_config.get( + "dynamic_batching", + {}, + ).get( + "default_queue_policy", + {}, + ).get("max_queue_size"), + "extended_runtime_perf_knob_config": + self.get_extended_runtime_perf_knob_config(model_config) } kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.ExecutorConfig(**kwargs) @@ -619,8 +738,9 @@ def initialize(self, 
args): args["model_version"], is_v1_model=executor_config.batching_type == trtllm.BatchingType.STATIC) - self.triton_id_to_req_id = {} - self.req_id_to_response_sender = {} + self.triton_user_id_to_req_ids = {} + self.triton_req_id_to_req_ids = {} + self.req_id_to_request_data = {} self.lock = Lock() self.running = False self.awaiter_thread = Thread(target=self.awaiter_loop) @@ -635,17 +755,19 @@ def initialize(self, args): # In leader mode, worker ranks will wait here until leader is done. self.executor.shutdown() - def handle_stop_request(self, triton_id, response_sender): - if triton_id is None or triton_id == "": + def handle_stop_request(self, triton_user_id, response_sender): + if triton_user_id is None or triton_user_id == "": response_sender.send( pb_utils.InferenceResponse(error=pb_utils.TritonError( "A request id must be provided for request cancellation")), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) return - if triton_id in self.triton_id_to_req_id: - req_id = self.triton_id_to_req_id[triton_id] - self.executor.cancel_request(req_id) + with self.lock: + if triton_user_id in self.triton_user_id_to_req_ids: + req_ids = self.triton_user_id_to_req_ids[triton_user_id] + for req_id in req_ids: + self.executor.cancel_request(req_id) response_sender.send( pb_utils.InferenceResponse(), @@ -672,17 +794,33 @@ def execute(self, requests): return # Convert to executor requests. + triton_requests = [] executor_requests = [] + batch_indices = [] + triton_user_ids = [] + triton_req_ids = [] + for request in requests: + + triton_user_id = request.request_id() + response_sender = request.get_response_sender() - if get_input_scalar_by_name(request, 'stop'): - self.handle_stop_request(request.request_id(), response_sender) + stop = get_input_scalar_by_name(request, 'stop') + + if stop: + self.handle_stop_request(triton_user_id, response_sender) else: + #Unique request id used to identify each triton request + triton_req_id = str(randint(0, sys.maxsize)) + self.triton_req_id_to_req_ids[triton_req_id] = set() + if triton_user_id is not None and triton_user_id != "": + self.triton_user_id_to_req_ids[triton_user_id] = set() + try: - converted = convert_request(request, - self.exclude_input_from_output, - self.decoupled) + converted_reqs = convert_request( + request, self.exclude_input_from_output, + self.decoupled) except Exception as e: response_sender.send( pb_utils.InferenceResponse(error=pb_utils.TritonError( @@ -690,16 +828,26 @@ def execute(self, requests): )), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) else: - triton_requests.append(request) - executor_requests.append(converted) + for batch_index, converted_req in enumerate( + converted_reqs): + triton_requests.append(request) + executor_requests.append(converted_req) + triton_user_ids.append(triton_user_id) + triton_req_ids.append(triton_req_id) + batch_indices.append(batch_index) with self.lock: request_ids = self.executor.enqueue_requests(executor_requests) - for req_id, request in zip(request_ids, triton_requests): - triton_id = request.request_id() - self.req_id_to_response_sender[ - req_id] = triton_id, request.get_response_sender() - self.triton_id_to_req_id[triton_id] = req_id + for req_id, triton_req_id, triton_user_id, triton_request, batch_index in zip( + request_ids, triton_req_ids, triton_user_ids, + triton_requests, batch_indices): + self.req_id_to_request_data[ + req_id] = triton_req_id, triton_user_id, batch_index, triton_request.get_response_sender( + ) + 
self.triton_req_id_to_req_ids[triton_req_id].add(req_id) + if triton_user_id is not None and triton_user_id != "": + self.triton_user_id_to_req_ids[triton_user_id].add(req_id) + return None def awaiter_loop(self): @@ -709,21 +857,37 @@ def awaiter_loop(self): timeout=datetime.timedelta(milliseconds=1)): req_id = response.request_id with self.lock: - if req_id not in self.req_id_to_response_sender: + if req_id not in self.req_id_to_request_data: continue - triton_id, response_sender = self.req_id_to_response_sender[ + triton_req_id, triton_user_id, batch_index, response_sender = self.req_id_to_request_data[ req_id] - triton_response, is_final = convert_response(response) + triton_response, is_final = convert_response( + response, batch_index) + + triton_request_final = False + if is_final: + with self.lock: + # Check if all executor requests part of that triton request are finished + self.triton_req_id_to_req_ids[triton_req_id].remove( + req_id) + if len(self.triton_req_id_to_req_ids[triton_req_id] + ) == 0: + pb_utils.Logger.log_info( + f"DELETING Req id {req_id}, triton_req_id {triton_req_id} " + ) + triton_request_final = True + del self.triton_req_id_to_req_ids[triton_req_id] + if triton_user_id is not None and triton_user_id != "": + del self.triton_user_id_to_req_ids[ + triton_user_id] + del self.req_id_to_request_data[req_id] + response_sender.send( triton_response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL - if is_final else 0) + if triton_request_final else 0) - if is_final: - with self.lock: - del self.triton_id_to_req_id[triton_id] - del self.req_id_to_response_sender[req_id] # Remove local reference so response_sender can be cleaned properly. del response_sender @@ -732,8 +896,9 @@ def cancellation_loop(self): while self.running: time.sleep(self.cancellation_check_period_ms / 1000.0) with self.lock: - for req_id, (triton_id, response_sender - ) in self.req_id_to_response_sender.items(): + for req_id, (triton_req_id, triton_user_id, batch_index, + response_sender + ) in self.req_id_to_request_data.items(): if response_sender.is_cancelled(): self.executor.cancel_request(req_id) # Remove local reference so response_sender can be cleaned properly. 
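
The tensorrt_llm `model.py` changes above drop the old one-request-per-batch assumption: `convert_request` now returns one executor request per batch row, `convert_response` tags each response with a `batch_index` output, and the awaiter loop only sends the final-response flag once every per-row executor request belonging to the same Triton request has completed (tracked via `triton_req_id_to_req_ids` and `req_id_to_request_data`). The standalone sketch below is an editorial illustration only, not part of the patch; `fake_enqueue` and the plain dict/set bookkeeping are hypothetical stand-ins for the TRT-LLM executor and the Triton response-sender plumbing, shown just to make the fan-out/fan-in pattern concrete.

```python
# Standalone sketch (plain Python, no Triton/TRT-LLM imports) of the
# fan-out/fan-in pattern: a batched request is split into one executor
# request per batch row, and the request only counts as complete once
# every row has responded.
import sys
from random import randint

triton_req_id_to_req_ids = {}  # Triton request id -> outstanding executor ids
req_id_to_batch_index = {}     # executor request id -> originating batch row

_next_req_id = 0
def fake_enqueue(payload):
    # Hypothetical stand-in for executor.enqueue_requests: return a unique id.
    global _next_req_id
    _next_req_id += 1
    return _next_req_id

def convert_request(input_ids_batch):
    # One executor payload per batch row, tagged with its batch_index.
    return [{"input_token_ids": row, "batch_index": i}
            for i, row in enumerate(input_ids_batch)]

def enqueue(input_ids_batch):
    # Fan-out: register every per-row executor request under one Triton request id.
    triton_req_id = str(randint(0, sys.maxsize))
    triton_req_id_to_req_ids[triton_req_id] = set()
    for payload in convert_request(input_ids_batch):
        req_id = fake_enqueue(payload)
        triton_req_id_to_req_ids[triton_req_id].add(req_id)
        req_id_to_batch_index[req_id] = payload["batch_index"]
    return triton_req_id

def on_response(triton_req_id, req_id):
    # Fan-in: return True (i.e. send the COMPLETE_FINAL flag) only when the
    # whole batch belonging to this Triton request has finished.
    pending = triton_req_id_to_req_ids[triton_req_id]
    pending.discard(req_id)
    return not pending

if __name__ == "__main__":
    tid = enqueue([[1, 2, 3], [4, 5]])
    done = [on_response(tid, rid) for rid in sorted(triton_req_id_to_req_ids[tid])]
    print(done)  # -> [False, True]: final flag only after the last row responds
```
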
diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt b/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt index fd6c6d0..0a7ea6e 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt @@ -35,6 +35,7 @@ model_transaction_policy { dynamic_batching { preferred_batch_size: [ ${triton_max_batch_size} ] max_queue_delay_microseconds: ${max_queue_delay_microseconds} + default_queue_policy: { max_queue_size: ${max_queue_size} } } input [ @@ -54,6 +55,7 @@ input [ name: "request_output_len" data_type: TYPE_INT32 dims: [ 1 ] + reshape: { shape: [ ] } }, { name: "draft_input_ids" @@ -255,12 +257,14 @@ input [ name: "stop" data_type: TYPE_BOOL dims: [ 1 ] + reshape: { shape: [ ] } optional: true }, { name: "streaming" data_type: TYPE_BOOL dims: [ 1 ] + reshape: { shape: [ ] } optional: true }, { @@ -350,6 +354,11 @@ output [ name: "generation_logits" data_type: TYPE_FP32 dims: [ -1, -1, -1 ] + }, + { + name: "batch_index" + data_type: TYPE_INT32 + dims: [ 1 ] } ] instance_group [ @@ -539,3 +548,15 @@ parameters: { string_value: "${gpu_weights_percent}" } } +parameters: { + key: "enable_context_fmha_fp32_acc" + value: { + string_value: "${enable_context_fmha_fp32_acc}" + } +} +parameters: { + key: "multi_block_mode" + value: { + string_value: "${multi_block_mode}" + } +} diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/decode.py b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/decode.py index de9e28b..986faef 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/decode.py +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/decode.py @@ -29,6 +29,7 @@ from typing import Optional import numpy as np +import torch class RequestValidationError(Exception): @@ -41,7 +42,10 @@ def _validate_that(condition: bool, msg: str): def _validate_non_empty(data, msg: str): - _validate_that(data is not None and data.size > 0, msg) + if isinstance(data, torch.Tensor): + _validate_that(data is not None and data.numel() > 0, msg) + else: + _validate_that(data is not None and data.size > 0, msg) def _validate_single_gt_0(data, msg: str): @@ -59,7 +63,8 @@ def _single_value(data: Optional[np.ndarray]): class Request: text_input: np.ndarray = np.array([]) decoder_text_input: np.ndarray = None - max_tokens: np.ndarray = np.array([]) + image_input: Optional[np.ndarray] = None + max_tokens: Optional[np.ndarray] = None bad_words: Optional[np.ndarray] = None stop_words: Optional[np.ndarray] = None end_id: Optional[np.ndarray] = None @@ -91,13 +96,12 @@ def validate(self): "max_tokens must be a single value > 0") num_draft_tokens = _single_value(self.num_draft_tokens) - stream = _single_value(self.stream) _single_value(self.return_generation_logits) context_logits = _single_value(self.return_context_logits) if num_draft_tokens: _validate_that( - not stream, + not self.stream.any(), "streaming is not supported with speculative decoding") _validate_that( not context_logits, @@ -127,18 +131,22 @@ def with_new_inputs(cls, other, input_ids: Optional[np.ndarray] = None, input_lengths: Optional[np.ndarray] = None): - return cls( - input_ids=(input_ids - if input_ids is not None else other.input_ids), - input_lengths=(input_lengths if input_lengths is not None else - other.input_lengths), - decoder_input_ids=other.decoder_input_ids, - decoder_input_lengths=other.decoder_input_lengths, - bad_words_list=other.bad_words_list, - stop_words_list=other.stop_words_list, - 
end_id=other.end_id, - pad_id=other.pad_id, - ) + return cls(input_ids=(input_ids + if input_ids is not None else other.input_ids), + input_lengths=(input_lengths if input_lengths is not None + else other.input_lengths), + decoder_input_ids=other.decoder_input_ids, + decoder_input_lengths=other.decoder_input_lengths, + bad_words_list=other.bad_words_list, + stop_words_list=other.stop_words_list, + end_id=other.end_id, + pad_id=other.pad_id) + + +@dataclass +class MultimodalEncResponse: + prompt_embedding_table: Optional[torch.Tensor] = None + prompt_vocab_size: Optional[np.ndarray] = None @dataclass @@ -149,6 +157,7 @@ class GenerationResponse: output_log_probs: Optional[np.ndarray] = None context_logits: Optional[np.ndarray] = None generation_logits: Optional[np.ndarray] = None + batch_index: Optional[np.ndarray] = None @dataclass @@ -158,6 +167,7 @@ class Response: output_log_probs: Optional[np.ndarray] = None context_logits: Optional[np.ndarray] = None generation_logits: Optional[np.ndarray] = None + batch_index: Optional[np.ndarray] = None def __eq__(self, o) -> bool: """Just for testing""" @@ -166,8 +176,9 @@ def __eq__(self, o) -> bool: return (np.array_equal(self.text_output, o.text_output) and np.array_equal(self.cum_log_probs, o.cum_log_probs) and np.array_equal(self.output_log_probs, o.output_log_probs) - and np.array_equal(self.context_logits, o.context_logits) and - np.array_equal(self.generation_logits, o.generation_logits)) + and np.array_equal(self.context_logits, o.context_logits) + and np.array_equal(self.generation_logits, o.generation_logits) + and np.array_equal(self.batch_index, o.batch_index)) class Decoder: @@ -176,24 +187,41 @@ def __init__(self, streaming=False, accumulate=False): self._streaming = streaming self._accumulate = accumulate - self._accumulated_tokens = None + self._accumulated_tokens = [] def decode(self, request: Request, - speculative_decoding=False) -> Generator[Response, None, None]: + speculative_decoding=False, + is_multimodal=False) -> Generator[Response, None, None]: + + batch_size = request.text_input.shape[0] + self._accumulated_tokens = [None] * batch_size preproc_response = self.preprocess(request) + multimodal_enc_response = None + if is_multimodal: + multimodal_enc_response = self._multimodal_enc_generate(request) + if speculative_decoding: + if batch_size > 1: + raise Exception( + "speculative decoding is not supported with batch size > 1" + ) for gen_response in self._spec_generate(preproc_response, request): - yield self.postprocess(gen_response) + yield self.postprocess(gen_response, batch_size) else: - if not self._streaming: + if not self._streaming and batch_size == 1: gen_response = self._generate_non_streaming( - preproc_response, request) - yield self.postprocess(gen_response) + preproc_response, + request, + multimodal_enc_response=multimodal_enc_response) + yield self.postprocess(gen_response, batch_size) else: - for gen_response in self._generate(preproc_response, request): - yield self.postprocess(gen_response) + for gen_response in self._generate( + preproc_response, + request, + multimodal_enc_response=multimodal_enc_response): + yield self.postprocess(gen_response, batch_size) def encountered_stop_words(self, input_ids, stop_words_ids): for stop_word_ids in stop_words_ids: @@ -205,6 +233,10 @@ def _spec_generate( self, preproc: PreprocResponse, request: Request) -> Generator[GenerationResponse, None, None]: + if preproc.input_ids.shape[0] > 1: + raise Exception( + "Speculative decoding does not support batch size > 1.") 
+ prompt_input_ids: np.ndarray = preproc.input_ids[0] input_ids: np.ndarray = prompt_input_ids output_len: int = request.max_tokens[0][0] @@ -282,23 +314,32 @@ def _draft_generate_non_streaming( num_draft_tokens: int) -> GenerationResponse: raise NotImplementedError() + def _multimodal_enc_generate( + self, + request: Request, + ) -> MultimodalEncResponse: + raise NotImplementedError() + def _generate( self, preproc: PreprocResponse, request: Request, - draft_request: Optional[DraftRequest] = None + draft_request: Optional[DraftRequest] = None, + multimodal_enc_response: Optional[MultimodalEncResponse] = None, ) -> Generator[GenerationResponse, None, None]: raise NotImplementedError() def _generate_non_streaming( - self, - preproc: PreprocResponse, - request: Request, - draft_request: Optional[DraftRequest] = None + self, + preproc: PreprocResponse, + request: Request, + draft_request: Optional[DraftRequest] = None, + multimodal_enc_response: Optional[MultimodalEncResponse] = None, ) -> GenerationResponse: raise NotImplementedError() - def postprocess(self, gen_response: GenerationResponse) -> Response: + def postprocess(self, gen_response: GenerationResponse, + batch_size) -> Response: if self._accumulate and self._streaming: new_tokens: np.ndarray = gen_response.output_ids if new_tokens.ndim != 3: @@ -310,12 +351,24 @@ def postprocess(self, gen_response: GenerationResponse) -> Response: "Accumulation of tokens is only implemented for beam width = 1" ) - self._accumulated_tokens = new_tokens if ( - self._accumulated_tokens is None) else np.concatenate( - (self._accumulated_tokens, new_tokens), axis=2) - sequence_lengths = np.array([[self._accumulated_tokens.shape[2]]], - dtype=np.int32) - return self._postprocess(self._accumulated_tokens, + batch_index = gen_response.batch_index + if batch_index.ndim != 2: + raise Exception("Expected batch_index tensor to have 2 dims.") + if batch_index.shape[0] != 1: + raise Exception("Expected batch size of 1") + if batch_index.shape[1] != 1: + raise Exception("Expected only one batch_index") + + batch_index = batch_index[0][0] + + self._accumulated_tokens[batch_index] = new_tokens if ( + self._accumulated_tokens[batch_index] is None + ) else np.concatenate( + (self._accumulated_tokens[batch_index], new_tokens), axis=2) + sequence_lengths = np.array( + [[self._accumulated_tokens[batch_index].shape[2]]], + dtype=np.int32) + return self._postprocess(self._accumulated_tokens[batch_index], sequence_lengths, gen_response) else: return self._postprocess(gen_response.output_ids, None, @@ -330,4 +383,4 @@ def preprocess(self, request: Request) -> PreprocResponse: raise NotImplementedError() def reset_decoder(self): - self._accumulated_tokens = None + self._accumulated_tokens = [] diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/triton_decoder.py b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/triton_decoder.py index 456ded5..fc9d881 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/triton_decoder.py +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/lib/triton_decoder.py @@ -30,6 +30,7 @@ import numpy as np import triton_python_backend_utils as pb_utils from lib.decode import * +from torch.utils.dlpack import from_dlpack, to_dlpack from typing_extensions import override @@ -41,12 +42,14 @@ def __init__(self, preproc_model_name="preprocessing", postproc_model_name="postprocessing", llm_model_name="tensorrt_llm", - draft_llm_model_name: Optional[str] = None): + draft_llm_model_name: Optional[str] = None, + 
multimodal_encoders_name: Optional[str] = None): super().__init__(streaming=streaming, accumulate=accumulate) self.preproc_model_name = preproc_model_name self.postproc_model_name = postproc_model_name self.llm_model_name = llm_model_name self.draft_llm_model_name = draft_llm_model_name + self.multimodal_encoders_name = multimodal_encoders_name self._preproc_outputs = [ "INPUT_ID", @@ -60,13 +63,14 @@ def __init__(self, "OUT_END_ID", ] + self._multimodal_enc_outputs = [ + "OUT_PROMPT_EMBEDDING_TABLE", "OUT_PROMPT_VOCAB_SIZE" + ] + self._llm_outputs = [ - "output_ids", - "sequence_length", - "cum_log_probs", - "output_log_probs", - "context_logits", - "generation_logits", + "output_ids", "sequence_length", "cum_log_probs", + "output_log_probs", "context_logits", "generation_logits", + "batch_index" ] self._postproc_outputs = [ @@ -76,6 +80,7 @@ def __init__(self, self.input_names = [ "text_input", "decoder_text_input", + "image_input", "max_tokens", "bad_words", "stop_words", @@ -145,7 +150,8 @@ def create_triton_response(self, response: Response): "cum_log_probs": "cum_log_probs", "output_log_probs": "output_log_probs", "context_logits": "context_logits", - "generation_logits": "generation_logits" + "generation_logits": "generation_logits", + "batch_index": "batch_index" } tensors = self.create_triton_tensors(response, name_map) return pb_utils.InferenceResponse(output_tensors=tensors) @@ -173,7 +179,11 @@ def convert_triton_response(self, if tensor is None: continue triton_name = tensor.name() - value = tensor.as_numpy() + if tensor.is_cpu(): + value = tensor.as_numpy() + else: + # If the tensor is in GPU memory make it torch.Tensor type + value = from_dlpack(tensor.to_dlpack()) target_name = triton_name if name_map and triton_name in name_map: target_name = name_map[triton_name] @@ -203,7 +213,12 @@ def create_triton_tensors(self, obj, name_map: dict): value = getattr(obj, name) if value is None: continue - t = pb_utils.Tensor(triton_name, self.__undo_reshape(value, name)) + if isinstance(value, np.ndarray): + t = pb_utils.Tensor(triton_name, + self.__undo_reshape(value, name)) + elif isinstance(value, torch.Tensor): + t = pb_utils.Tensor.from_dlpack( + triton_name, to_dlpack(self.__undo_reshape(value, name))) tensors.append(t) return tensors @@ -246,6 +261,31 @@ def _get_preproc_response(self, triton_output): return self.convert_triton_response(triton_output, PreprocResponse, name_map) + @override + def _multimodal_enc_generate(self, + request: Request) -> MultimodalEncResponse: + input_tensors = self._get_multimodal_enc_tensors(request) + triton_req = pb_utils.InferenceRequest( + model_name=self.multimodal_encoders_name, + inputs=input_tensors, + requested_output_names=self._multimodal_enc_outputs) + triton_output = self._exec_triton_request_single(triton_req) + return self._get_multimodal_enc_response(triton_output) + + def _get_multimodal_enc_tensors(self, preproc: PreprocResponse): + name_map = { + "image_input": "IMAGE", + } + return self.create_triton_tensors(preproc, name_map) + + def _get_multimodal_enc_response(self, triton_output): + name_map = { + "OUT_PROMPT_EMBEDDING_TABLE": "prompt_embedding_table", + "OUT_PROMPT_VOCAB_SIZE": "prompt_vocab_size", + } + return self.convert_triton_response(triton_output, + MultimodalEncResponse, name_map) + @override def _draft_generate_non_streaming( self, preproc: PreprocResponse, request: Request, @@ -265,10 +305,15 @@ def _generate( self, preproc: PreprocResponse, request: Request, - draft_request: Optional[DraftRequest] = None + 
draft_request: Optional[DraftRequest] = None, + multimodal_enc_response: Optional[MultimodalEncResponse] = None ) -> Generator[GenerationResponse, None, None]: - input_tensors = self._get_llm_tensors(preproc, request, None, - draft_request) + input_tensors = self._get_llm_tensors( + preproc, + request, + None, + draft_request, + multimodal_enc_response=multimodal_enc_response) triton_req = pb_utils.InferenceRequest( model_name=self.llm_model_name, inputs=input_tensors, @@ -278,13 +323,18 @@ def _generate( @override def _generate_non_streaming( - self, - preproc: PreprocResponse, - request: Request, - draft_request: Optional[DraftRequest] = None + self, + preproc: PreprocResponse, + request: Request, + draft_request: Optional[DraftRequest] = None, + multimodal_enc_response: Optional[MultimodalEncResponse] = None ) -> GenerationResponse: - input_tensors = self._get_llm_tensors(preproc, request, None, - draft_request) + input_tensors = self._get_llm_tensors( + preproc, + request, + None, + draft_request, + multimodal_enc_response=multimodal_enc_response) triton_req = pb_utils.InferenceRequest( model_name=self.llm_model_name, inputs=input_tensors, @@ -292,14 +342,19 @@ def _generate_non_streaming( r = self._exec_triton_request_single(triton_req) return self._get_llm_response(r) - def _get_llm_tensors(self, - preproc: PreprocResponse, - request: Request, - num_output_tokens: Optional[int] = None, - draft_request: Optional[DraftRequest] = None, - is_draft_model_request: bool = False): + def _get_llm_tensors( + self, + preproc: PreprocResponse, + request: Request, + num_output_tokens: Optional[int] = None, + draft_request: Optional[DraftRequest] = None, + is_draft_model_request: bool = False, + multimodal_enc_response: MultimodalEncResponse = None): tensors = [] tensors.extend(self._get_tensors_from_preproc(preproc)) + if multimodal_enc_response is not None: + tensors.extend( + self._get_tensors_from_multimodal_enc(multimodal_enc_response)) tensors.extend( self._get_llm_tensors_from_request(request, num_output_tokens, draft_request, @@ -319,6 +374,14 @@ def _get_tensors_from_preproc(self, preproc: PreprocResponse): } return self.create_triton_tensors(preproc, name_map) + def _get_tensors_from_multimodal_enc( + self, multimodal_enc_response: MultimodalEncResponse): + name_map = { + "prompt_embedding_table": "prompt_embedding_table", + "prompt_vocab_size": "prompt_vocab_size", + } + return self.create_triton_tensors(multimodal_enc_response, name_map) + def _get_llm_tensors_from_request( self, request: Request, @@ -329,6 +392,7 @@ def _get_llm_tensors_from_request( "beam_width": "beam_width", "top_k": "runtime_top_k", "top_p": "runtime_top_p", + "temperature": "temperature", "length_penalty": "len_penalty", "repetition_penalty": "repetition_penalty", "min_length": "min_length", @@ -340,23 +404,29 @@ def _get_llm_tensors_from_request( "prompt_embedding_table": "prompt_embedding_table", "prompt_vocab_size": "prompt_vocab_size", } + batch_size = request.text_input.shape[0] tensors = self.create_triton_tensors(request, name_map) + out_len_tensor = None + if request.max_tokens is not None: + out_len_tensor = request.max_tokens - out_len = request.max_tokens[0][0] if request.max_tokens else None + out_len = None if num_output_tokens is not None: out_len = num_output_tokens elif draft_request: - if draft_request.draft_input_ids is not None: - out_len = len(draft_request.draft_input_ids[0]) + 1 - else: - out_len = 1 + out_len = len( + draft_request.draft_input_ids[0] + ) + 1 if draft_request.draft_input_ids 
is not None else 1 + + if out_len is not None: + out_len_tensor = [[out_len]] * batch_size - if out_len is None: + if out_len_tensor is None: raise Exception("Could not determine request_output_len") else: tensors.append( pb_utils.Tensor("request_output_len", - np.array([[out_len]], dtype=np.int32))) + np.array(out_len_tensor, dtype=np.int32))) if draft_request: if draft_request.draft_input_ids is not None: @@ -369,24 +439,35 @@ def _get_llm_tensors_from_request( pb_utils.Tensor("draft_logits", draft_request.draft_logits)) - return_context_logits = False - return_generation_logits = False + return_context_logits_data = [False] + return_generation_logits_data = [False] if draft_request is None: if is_draft_model_request: - return_generation_logits = request.use_draft_logits[ - 0] if request.use_draft_logits is not None else False + return_generation_logits_data = request.use_draft_logits if request.use_draft_logits is not None else [ + False + ] else: - return_context_logits = request.return_context_logits[ - 0] if request.return_context_logits is not None else False - return_generation_logits = request.return_generation_logits[ - 0] if request.return_generation_logits is not None else False + return_context_logits_data = request.return_context_logits if request.return_context_logits is not None else [ + False + ] + return_generation_logits_data = request.return_generation_logits if request.return_generation_logits is not None else [ + False + ] + return_context_logits = np.array([return_context_logits_data] * + batch_size, + dtype=bool) + return_generation_logits = np.array([return_generation_logits_data] * + batch_size, + dtype=bool) + + assert len(return_context_logits.shape) == 2 + assert len(return_generation_logits.shape) == 2 tensors.append( - pb_utils.Tensor("return_context_logits", - np.array([[return_context_logits]]))) + pb_utils.Tensor("return_context_logits", return_context_logits)) tensors.append( pb_utils.Tensor("return_generation_logits", - np.array([[return_generation_logits]]))) + return_generation_logits)) return tensors def _get_llm_response(self, triton_output): @@ -397,6 +478,7 @@ def _get_llm_response(self, triton_output): "output_log_probs": "output_log_probs", "context_logits": "context_logits", "generation_logits": "generation_logits", + "batch_index": "batch_index", } return self.convert_triton_response(triton_output, GenerationResponse, name_map) @@ -436,5 +518,6 @@ def _get_response(self, triton_output, gen_res: GenerationResponse): cum_log_probs=gen_res.cum_log_probs, output_log_probs=gen_res.output_log_probs, context_logits=gen_res.context_logits, - generation_logits=gen_res.generation_logits) + generation_logits=gen_res.generation_logits, + batch_index=gen_res.batch_index) return response diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/model.py b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/model.py index 609e323..e0649c5 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/model.py +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/1/model.py @@ -31,6 +31,11 @@ from lib.triton_decoder import TritonDecoder +def get_valid_param_value(param, default_value=''): + value = param.get('string_value', '') + return default_value if value.startswith('${') or value == '' else value + + class TritonPythonModel: def initialize(self, args): @@ -40,10 +45,8 @@ def initialize(self, args): params = model_config['parameters'] - accumulate_tokens_str = '' - if 'accumulate_tokens' in params: - accumulate_tokens_str = 
params['accumulate_tokens']['string_value'] - + accumulate_tokens_str = get_valid_param_value( + params.get('accumulate_tokens', {})) self.accumulate_tokens = accumulate_tokens_str.lower() in [ 'true', 'yes', '1', 't' ] @@ -53,14 +56,16 @@ def initialize(self, args): self.logger = pb_utils.Logger - self.llm_model_name = "tensorrt_llm" - if "tensorrt_llm_model_name" in params: - self.llm_model_name = params["tensorrt_llm_model_name"][ - "string_value"] - self.draft_llm_model_name = None - if "tensorrt_llm_draft_model_name" in params: - self.draft_llm_model_name = params[ - "tensorrt_llm_draft_model_name"]["string_value"] + default_tensorrt_llm_model_name = 'tensorrt_llm' + self.llm_model_name = get_valid_param_value( + params.get('tensorrt_llm_model_name', {}), + default_tensorrt_llm_model_name) + + self.draft_llm_model_name = get_valid_param_value( + params.get('tensorrt_llm_draft_model_name', {}), None) + + self.multimodal_encoders_name = get_valid_param_value( + params.get('multimodal_encoders_name', {}), None) self.decoder = TritonDecoder( streaming=self.decoupled, @@ -68,7 +73,8 @@ def initialize(self, args): preproc_model_name="preprocessing", postproc_model_name="postprocessing", llm_model_name=self.llm_model_name, - draft_llm_model_name=self.draft_llm_model_name) + draft_llm_model_name=self.draft_llm_model_name, + multimodal_encoders_name=self.multimodal_encoders_name) def execute(self, requests): @@ -88,8 +94,16 @@ def execute(self, requests): raise Exception( "cannot perform speculative decoding without draft model" ) + is_multimodal = req.image_input is not None + + if speculative_decode and is_multimodal: + raise Exception( + "Multimodal and speculative decoding is not currently supported" + ) res_gen = self.decoder.decode( - req, speculative_decoding=speculative_decode) + req, + speculative_decoding=speculative_decode, + is_multimodal=is_multimodal) for res in res_gen: triton_response = self.decoder.create_triton_response(res) diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt index 7220246..29112cb 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt @@ -35,18 +35,24 @@ input [ { name: "text_input" data_type: TYPE_STRING - dims: [ -1 ] + dims: [ 1 ] }, { name: "decoder_text_input" data_type: TYPE_STRING - dims: [ -1 ] + dims: [ 1 ] + optional: true + }, + { + name: "image_input" + data_type: TYPE_FP16 + dims: [ 3, -1, -1 ] optional: true }, { name: "max_tokens" data_type: TYPE_INT32 - dims: [ -1 ] + dims: [ 1 ] }, { name: "bad_words" @@ -222,6 +228,11 @@ output [ name: "generation_logits" data_type: TYPE_FP32 dims: [ -1, -1, -1 ] + }, + { + name: "batch_index" + data_type: TYPE_INT32 + dims: [ 1 ] } ] @@ -243,6 +254,12 @@ parameters: { string_value: "${tensorrt_llm_draft_model_name}" } } +parameters: { + key: "multimodal_encoders_name" + value: { + string_value: "${multimodal_encoders_name}" + } +} instance_group [ { diff --git a/src/triton_cli/trt_llm/builder.py b/src/triton_cli/trt_llm/builder.py index 19ec9b7..80d2482 100644 --- a/src/triton_cli/trt_llm/builder.py +++ b/src/triton_cli/trt_llm/builder.py @@ -37,6 +37,8 @@ "meta-llama/Llama-2-7b-chat-hf": "llama", "meta-llama/Meta-Llama-3-8B": "llama", "meta-llama/Meta-Llama-3-8B-Instruct": "llama", + "meta-llama/Meta-Llama-3.1-8B": "llama", + "meta-llama/Meta-Llama-3.1-8B-Instruct": "llama", "facebook/opt-125m": "opt", "gpt2": "gpt2", } diff 
--git a/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py b/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py index 129d170..2038a61 100644 --- a/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py +++ b/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py @@ -1,46 +1,27 @@ import argparse -import functools -import json -import logging import os import shutil -import tarfile import time import traceback -from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Optional, Tuple, Union - -import numpy as np -import safetensors -import torch -import torch.nn as nn -import yaml -from tqdm import tqdm -from transformers import (AutoConfig, AutoModelForCausalLM, - AutoModelForVision2Seq, AutoTokenizer, GPT2Config) -from transformers.models.gpt2.modeling_gpt2 import GPT2Block -from transformers.pytorch_utils import Conv1D import tensorrt_llm -from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch +from tensorrt_llm._utils import release_gc from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.convert_utils import (load_calib_dataset, - retrieved_layer_index_from_name) +from tensorrt_llm.models import GPTForCausalLM +from tensorrt_llm.models.gpt.convert import (UnpackedNemoCheckpointDir, + copy_tokenizer_files, load_hf_gpt, + unpack_nemo_ckpt, + update_tokenizer_paths) +from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization import QuantAlgo -LOGGER = logging.getLogger(__name__) - def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--model_dir', type=str, default=None) parser.add_argument('--nemo_ckpt_path', type=str, default=None) - parser.add_argument('--load_nemo_on_gpu', - default=False, - action="store_true", - help="Whether to load NeMo checkpoint on GPU") parser.add_argument( '--gpt_variant', default=None, @@ -63,6 +44,7 @@ def parse_arguments(): type=str, default='float16', choices=['float32', 'bfloat16', 'float16']) + parser.add_argument("--load_model_on_cpu", action="store_true") parser.add_argument( '--use_parallel_embedding', action="store_true", @@ -173,1727 +155,146 @@ def parse_arguments(): return args -def rename_keys(model_state, layer_rename_config: Dict[str, str]): - if not layer_rename_config: - return model_state - - new_state_dict = {} - for key, value in model_state.items(): - for old, new in layer_rename_config.items(): - key = key.replace(old, new) - assert key not in new_state_dict, f"Key already exists: {key}" - new_state_dict[key] = value - - return new_state_dict - - -def load_gpt_config(model_dir: str, - tp_size: int, - gpt_variant: Optional[str] = None) -> GPT2Config: - config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) - - if gpt_variant is None: - print("Inferring gpt variant from path...") - for v in [ - 'starcoder2', 'starcoder', 'santacoder', 'gpt2', 'persimmon', - 'kosmos-2', 'jais' - ]: - if v in config._name_or_path or ('fuyu' in config._name_or_path - and v == 'persimmon'): - gpt_variant = v - break - assert gpt_variant in [ - 'gpt2', 'santacoder', 'starcoder', 'starcoder2', 'persimmon', - 'kosmos-2', 'jais' - ] - print(f"Gpt variant: {gpt_variant}") - - if gpt_variant in ['starcoder2', 'persimmon']: - config.n_embd = config.hidden_size - config.n_inner = config.intermediate_size - config.n_head = config.num_attention_heads - config.n_kv_head = config.num_key_value_heads if hasattr( - config, 
'num_key_value_heads') else config.n_head - config.n_layer = config.num_hidden_layers - config.n_positions = config.max_position_embeddings - config.activation_function = 'gelu' if gpt_variant == 'starcoder2' else 'squared-relu' - config.layer_norm_epsilon = config.norm_epsilon if gpt_variant == 'starcoder2' else config.layer_norm_eps - config.bias = config.use_bias if gpt_variant == 'starcoder2' else True - config.position_embedding_type = 'rope_gpt_neox' - config.rotary_base = config.rope_theta - config.rotary_pct = getattr(config, 'partial_rotary_factor', 1.0) - elif gpt_variant == "kosmos-2": - config.n_embd = config.text_config.embed_dim - config.n_inner = config.text_config.ffn_dim - config.n_head = config.text_config.attention_heads - config.n_kv_head = config.n_head - config.n_layer = config.text_config.layers - config.n_positions = config.text_config.max_position_embeddings - config.activation_function = config.text_config.activation_function - config.layer_norm_epsilon = config.text_config.layer_norm_eps - config.bias = True - config.vocab_size = config.text_config.vocab_size - else: - if config.n_inner is None: - config.n_inner = config.n_embd * 4 - if gpt_variant in ['santacoder', 'starcoder']: - config.n_kv_head = 1 - else: - config.n_kv_head = config.n_head - if gpt_variant == 'jais': - config.q_scaling = (config.n_embd // config.n_head)**0.5 - if hasattr(config, 'width_scale'): - config.logits_scale = config.width_scale - else: - config.logits_scale = config.mup_output_alpha * config.mup_width_scale - - if hasattr(config, 'mup_embeddings_scale'): - config.embeddings_scale = config.mup_embeddings_scale - else: - assert hasattr(config, 'embeddings_scale') - - config.n_inner += get_needed_padding(config.n_inner, tp_size) - - if gpt_variant == 'kosmos-2': - if config.text_config.scale_embedding: - config.embeddings_scale = config.n_embd**0.5 - - return config, gpt_variant - - -def get_needed_padding(value: int, multiple: int) -> int: - return (multiple - value % multiple) % multiple - - -def pad_array_up_to(v: torch.Tensor, axis: int, multiple: int) -> torch.Tensor: - a = [0 for i in range(len(v.shape) * 2)] - a[axis * 2 - 1] = get_needed_padding(v.shape[axis], multiple) - return torch.nn.functional.pad(v, a) - - -def split(param: torch.Tensor, - tp_rank: int, - tp_size: int, - is_column: bool = True) -> torch.Tensor: - """Split linear layer's weight, bias or scaling factors for tensor parallelism.""" - if param is None: - return None - assert param.ndim in [1, 2] - if tp_size == 1: - return param - if param.numel() == 1: - return param - if param.ndim == 1 and not is_column: - return param - split_dim = 0 if (param.ndim == 1 or is_column) else 1 - return torch.chunk(param, tp_size, dim=split_dim)[tp_rank].contiguous() - - -def split_qkv( - param: torch.Tensor, - tp_rank: int, - tp_size: int, - hidden_size: int, - num_heads: int, - num_kv_heads: Optional[int] = None, -) -> torch.Tensor: - """Split qkv layer's weight, bias or scaling factors for tensor parallelism. 
- - param: (num_heads*head_dim + 2*num_kv_heads*head_dim, in_dim) - """ - if param is None: - return None - assert hidden_size % num_heads == 0 - head_dim = hidden_size // num_heads - num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads - assert num_heads % num_kv_heads == 0 - assert num_heads % tp_size == 0 - - q_param, k_param, v_param = torch.split( - param, [hidden_size, num_kv_heads * head_dim, num_kv_heads * head_dim], - dim=0) - - if num_kv_heads < tp_size: - assert tp_size % num_kv_heads == 0 - num_dups = tp_size // num_kv_heads - remain_shape = k_param.shape[1:] - k_param = k_param.view( - num_kv_heads, head_dim, - *remain_shape).repeat_interleave(num_dups, dim=0).view( - num_kv_heads * head_dim * num_dups, *remain_shape) - v_param = v_param.view( - num_kv_heads, head_dim, - *remain_shape).repeat_interleave(num_dups, dim=0).view( - num_kv_heads * head_dim * num_dups, *remain_shape) - else: - assert num_kv_heads % tp_size == 0 - - q_param = split(q_param, tp_rank, tp_size, is_column=True) - k_param = split(k_param, tp_rank, tp_size, is_column=True) - v_param = split(v_param, tp_rank, tp_size, is_column=True) - return torch.cat([q_param, k_param, v_param], dim=0) - - -def split_embedding( - param: torch.Tensor, - tp_rank: int, - tp_size: int, - use_parallel_embedding: bool = False, - sharding_dim: int = 0, -) -> torch.Tensor: - if param is None: - return None - if not use_parallel_embedding: - return param - - vocab_size, hidden_size = param.size() - if sharding_dim == 0: - if vocab_size % tp_size != 0: - vocab_size_padded = pad_vocab_size(vocab_size, tp_size) - pad_width = vocab_size_padded - vocab_size - param = torch.nn.functional.pad(param, (0, 0, 0, pad_width), - value=0) - else: - assert hidden_size % tp_size == 0 - return split(param, tp_rank, tp_size, is_column=(sharding_dim == 0)) - - -def get_weight(params: Dict[str, torch.Tensor], prefix: str, - dtype: torch.dtype) -> torch.Tensor: - if f'{prefix}.weight' not in params: - return None - return params[f'{prefix}.weight'].to(dtype).detach().cpu() - - -def get_bias(params: Dict[str, torch.Tensor], prefix: str, - dtype: torch.dtype) -> torch.Tensor: - if f'{prefix}.bias' not in params: - return None - return params[f'{prefix}.bias'].to(dtype).detach().cpu() - - -def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str, - dtype: torch.dtype) -> Tuple[torch.Tensor]: - return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype) - - -def get_tllm_linear_weight( - weight: torch.Tensor, - prefix: str, - bias: Optional[torch.Tensor] = None, - use_weight_only: bool = False, - plugin_weight_only_quant_type: torch.dtype = torch.int8 -) -> Dict[str, torch.Tensor]: - results = {} - if use_weight_only: - v = weight.t().contiguous() - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - v, plugin_weight_only_quant_type) - results[f'{prefix}.weight'] = processed_torch_weights - results[f'{prefix}.per_channel_scale'] = torch_weight_scales - else: - results[f'{prefix}.weight'] = weight - - if bias is not None: - results[f'{prefix}.bias'] = bias - - return results - - -def convert_hf_gpt(hf_model: AutoModelForCausalLM, - hf_config: AutoConfig, - gpt_variant: str, - mapping: Mapping, - dtype: str = 'float32', - use_parallel_embedding: bool = False, - sharding_dim: int = 0, - share_embedding_table: bool = False, - use_weight_only: bool = False, - plugin_weight_only_quant_type: torch.dtype = torch.int8): - weights = {} - tik = time.time() - 
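# Illustrative sketch only (assumes TensorRT-LLM v0.12.0 is installed): the
# hand-written convert_hf_gpt() weight loop being removed in this hunk is
# superseded by the high-level API this patch adopts (load_hf_gpt /
# GPTForCausalLM.from_hugging_face in the new convert_and_save_hf). A minimal
# single-rank version of that flow, using only calls that appear in the added
# code; the model ID, dtype, and output directory below are placeholders.
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import GPTForCausalLM
from tensorrt_llm.models.gpt.convert import load_hf_gpt

hf_model = load_hf_gpt("gpt2", True)  # placeholder model; second arg: load on CPU
mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1)  # no TP/PP split
model = GPTForCausalLM.from_hugging_face(hf_model, "float16", mapping=mapping)
model.save_checkpoint("/tmp/gpt2_trtllm_ckpt", save_config=True)  # rank 0 writes config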
- model_params = dict(hf_model.named_parameters()) - dtype = getattr(torch, dtype) - num_attention_heads = hf_config.n_head - hidden_size = hf_config.n_embd - vocab_size = hf_config.vocab_size - num_kv_heads = hf_config.n_kv_head - num_hidden_layers = hf_config.n_layer - - layers_range = mapping.pp_layers(num_hidden_layers) - for l in layers_range: - if gpt_variant == 'starcoder2': - prefix = f'model.layers.{l}' - elif gpt_variant == 'persimmon': - is_fuyu = f'language_model.model.embed_tokens.weight' in model_params - prefix = f'language_model.model.layers.{l}' if is_fuyu else f'model.layers.{l}' - elif gpt_variant == 'kosmos-2': - prefix = f'text_model.model.layers.{l}' - else: - prefix = f'transformer.h.{l}' - tllm_prex = f'transformer.layers.{l-layers_range[0]}' - if gpt_variant == 'santacoder': - q_w, q_b = get_weight_and_bias(model_params, - f'{prefix}.attn.q_attn', dtype) - kv_w, kv_b = get_weight_and_bias(model_params, - f'{prefix}.attn.kv_attn', dtype) - qkv_w = torch.cat([q_w, kv_w], dim=-1) - qkv_b = torch.cat([q_b, kv_b], dim=-1) - elif gpt_variant in ['starcoder2', 'kosmos-2']: - q_w, q_b = get_weight_and_bias(model_params, - f'{prefix}.self_attn.q_proj', dtype) - k_w, k_b = get_weight_and_bias(model_params, - f'{prefix}.self_attn.k_proj', dtype) - v_w, v_b = get_weight_and_bias(model_params, - f'{prefix}.self_attn.v_proj', dtype) - qkv_w = torch.cat([q_w, k_w, v_w], dim=0) - qkv_b = torch.cat([q_b, k_b, v_b], dim=0) - elif gpt_variant == 'persimmon': - qkv_w, qkv_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.query_key_value', dtype) - else: - qkv_w, qkv_b = get_weight_and_bias(model_params, - f'{prefix}.attn.c_attn', dtype) - if gpt_variant in ['gpt2', 'santacoder', 'jais']: - qkv_w = qkv_w.t().contiguous() # transpose for Conv1D - - if gpt_variant == 'persimmon': - qkv_w = split(qkv_w, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - - qkv_b = split(qkv_b, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - else: - qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size, - hidden_size, num_attention_heads, num_kv_heads) - qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size, - hidden_size, num_attention_heads, num_kv_heads) - - weights.update( - get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', qkv_b, - use_weight_only, - plugin_weight_only_quant_type)) - - if gpt_variant == 'starcoder2': - attn_dense_w, attn_dense_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.o_proj', dtype) - elif gpt_variant == 'persimmon': - attn_dense_w, attn_dense_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.dense', dtype) - elif gpt_variant == 'kosmos-2': - attn_dense_w, attn_dense_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.out_proj', dtype) - else: - attn_dense_w, attn_dense_b = get_weight_and_bias( - model_params, f'{prefix}.attn.c_proj', dtype) - if gpt_variant in ['gpt2', 'santacoder', 'jais']: - attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D - attn_dense_w = split(attn_dense_w, - mapping.tp_rank, - mapping.tp_size, - is_column=False) - weights.update( - get_tllm_linear_weight(attn_dense_w, f'{tllm_prex}.attention.dense', - attn_dense_b, use_weight_only, - plugin_weight_only_quant_type)) - - if gpt_variant == 'persimmon': - mlp_fc_w, mlp_fc_b = get_weight_and_bias( - model_params, f'{prefix}.mlp.dense_h_to_4h', dtype) - elif gpt_variant == 'kosmos-2': - mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params, - f'{prefix}.ffn.fc1', dtype) - else: - mlp_fc_w, mlp_fc_b = 
get_weight_and_bias(model_params, - f'{prefix}.mlp.c_fc', - dtype) - if gpt_variant in ['gpt2', 'santacoder', 'jais']: - mlp_fc_w = mlp_fc_w.t().contiguous() # transpose for Conv1D - if gpt_variant in ['jais']: - mlp_fc_w = pad_array_up_to(mlp_fc_w, 0, mapping.tp_size) - mlp_fc_b = pad_array_up_to(mlp_fc_b, 0, mapping.tp_size) - mlp_fc_w = split(mlp_fc_w, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - mlp_fc_b = split(mlp_fc_b, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - if gpt_variant in ['jais']: - weights.update( - get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.gate', - mlp_fc_b, use_weight_only, - plugin_weight_only_quant_type)) - else: - weights.update( - get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', - mlp_fc_b, use_weight_only, - plugin_weight_only_quant_type)) - if gpt_variant in ['jais']: - mlp_fc2_w, mlp_fc2_b = get_weight_and_bias(model_params, - f'{prefix}.mlp.c_fc2', - dtype) - mlp_fc2_w = mlp_fc2_w.t().contiguous() - mlp_fc2_w = pad_array_up_to(mlp_fc2_w, 0, mapping.tp_size) - mlp_fc2_b = pad_array_up_to(mlp_fc2_b, 0, mapping.tp_size) - mlp_fc2_w = split(mlp_fc2_w, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - mlp_fc2_b = split(mlp_fc2_b, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - weights.update( - get_tllm_linear_weight(mlp_fc2_w, f'{tllm_prex}.mlp.fc', - mlp_fc2_b, use_weight_only, - plugin_weight_only_quant_type)) - - if gpt_variant == 'persimmon': - mlp_proj_w, mlp_proj_b = get_weight_and_bias( - model_params, f'{prefix}.mlp.dense_4h_to_h', dtype) - elif gpt_variant == 'kosmos-2': - mlp_proj_w, mlp_proj_b = get_weight_and_bias( - model_params, f'{prefix}.ffn.fc2', dtype) - else: - mlp_proj_w, mlp_proj_b = get_weight_and_bias( - model_params, f'{prefix}.mlp.c_proj', dtype) - if gpt_variant in ['gpt2', 'santacoder', 'jais']: - mlp_proj_w = mlp_proj_w.t().contiguous() # transpose for Conv1D - if gpt_variant in ['jais']: - mlp_proj_w = pad_array_up_to(mlp_proj_w, 1, mapping.tp_size) - mlp_proj_w = split(mlp_proj_w, - mapping.tp_rank, - mapping.tp_size, - is_column=False) - weights.update( - get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', - mlp_proj_b, use_weight_only, - plugin_weight_only_quant_type)) - - if gpt_variant in ['starcoder2', 'persimmon']: - input_ln_w, input_ln_b = get_weight_and_bias( - model_params, f'{prefix}.input_layernorm', dtype) - elif gpt_variant == 'kosmos-2': - input_ln_w, input_ln_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn_layer_norm', dtype) - else: - input_ln_w, input_ln_b = get_weight_and_bias( - model_params, f'{prefix}.ln_1', dtype) - weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w - if input_ln_b is not None: - weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b - - if gpt_variant in ['starcoder2', 'persimmon']: - post_ln_w, post_ln_b = get_weight_and_bias( - model_params, f'{prefix}.post_attention_layernorm', dtype) - elif gpt_variant == 'kosmos-2': - post_ln_w, post_ln_b = get_weight_and_bias( - model_params, f'{prefix}.final_layer_norm', dtype) - else: - post_ln_w, post_ln_b = get_weight_and_bias(model_params, - f'{prefix}.ln_2', dtype) - weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w - if post_ln_b is not None: - weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b - - if gpt_variant == 'persimmon': - q_layernorm_w, q_layernorm_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.q_layernorm', dtype) - - weights[f'{tllm_prex}.attention.q_layernorm.weight'] = q_layernorm_w - 
weights[f'{tllm_prex}.attention.q_layernorm.bias'] = q_layernorm_b - - k_layernorm_w, k_layernorm_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.k_layernorm', dtype) - - weights[f'{tllm_prex}.attention.k_layernorm.weight'] = k_layernorm_w - weights[f'{tllm_prex}.attention.k_layernorm.bias'] = k_layernorm_b - - if gpt_variant == 'kosmos-2': - q_layernorm_w, q_layernorm_b = get_weight_and_bias( - model_params, f'{prefix}.self_attn.inner_attn_ln', dtype) - - weights[ - f'{tllm_prex}.attention.inner_layernorm.weight'] = q_layernorm_w - weights[ - f'{tllm_prex}.attention.inner_layernorm.bias'] = q_layernorm_b - - k_layernorm_w, k_layernorm_b = get_weight_and_bias( - model_params, f'{prefix}.ffn.ffn_layernorm', dtype) - - weights[f'{tllm_prex}.mlp.inner_layernorm.weight'] = k_layernorm_w - weights[f'{tllm_prex}.mlp.inner_layernorm.bias'] = k_layernorm_b - - if mapping.is_first_pp_rank(): - if gpt_variant == 'starcoder2': - embed_w = get_weight(model_params, 'model.embed_tokens', dtype) - elif gpt_variant == 'kosmos-2': - embed_w = get_weight(model_params, 'text_model.model.embed_tokens', - dtype) - elif gpt_variant == 'persimmon': - embed_w = get_weight(model_params, - ('language_model.' if is_fuyu else '') + - 'model.embed_tokens', dtype) - else: - embed_w = get_weight(model_params, 'transformer.wte', dtype) - weights['transformer.vocab_embedding.weight'] = split_embedding( - embed_w, - mapping.tp_rank, - mapping.tp_size, - use_parallel_embedding=use_parallel_embedding, - sharding_dim=sharding_dim) - - if gpt_variant == 'kosmos-2': - padding_idx = hf_config.text_config.pad_token_id - sin_pos_embedding = hf_model.text_model.model.embed_positions.get_embedding( - padding_idx + 1 + hf_config.text_config.max_position_embeddings, - hf_config.text_config.embed_dim, - padding_idx=padding_idx) # [2 + num_embeddings, embed_dim] - pos_embed_w = sin_pos_embedding[2:].to(dtype).detach().cpu() - else: - pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype) - if pos_embed_w is not None: - weights['transformer.position_embedding.weight'] = split_embedding( - pos_embed_w, - mapping.tp_rank, - mapping.tp_size, - use_parallel_embedding=use_parallel_embedding, - sharding_dim=sharding_dim) - - if mapping.is_last_pp_rank(): - if gpt_variant == 'starcoder2': - embed_w = get_weight(model_params, 'lm_head', dtype) - if embed_w is None: - embed_w = get_weight(model_params, 'model.embed_tokens', dtype) - elif gpt_variant == 'persimmon': - embed_w = get_weight(model_params, - ('language_model.' if is_fuyu else '') + - 'lm_head', dtype) - elif gpt_variant == 'kosmos-2': - embed_w = get_weight(model_params, 'text_model.model.embed_tokens', - dtype) - else: - embed_w = get_weight(model_params, 'transformer.wte', dtype) - if not share_embedding_table: - if vocab_size % mapping.tp_size != 0: - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size - embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width), - value=0) - if hasattr(hf_config, 'logits_scale'): - embed_w *= hf_config.logits_scale - weights['lm_head.weight'] = split(embed_w.clone(), - mapping.tp_rank, - mapping.tp_size, - is_column=True) - if gpt_variant == 'starcoder2': - ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'model.norm', - dtype) - elif gpt_variant == 'persimmon': - ln_f_w, ln_f_b = get_weight_and_bias( - model_params, ('language_model.' 
if is_fuyu else '') + - 'model.final_layernorm', dtype) - elif gpt_variant == 'kosmos-2': - ln_f_w, ln_f_b = get_weight_and_bias(model_params, - 'text_model.model.layer_norm', - dtype) +def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: + '''return config dict with quantization info based on the command line args + ''' + quant_config = QuantConfig() + if args.use_weight_only: + if args.weight_only_precision == 'int8': + quant_config.quant_algo = QuantAlgo.W8A16 + elif args.weight_only_precision == 'int4': + quant_config.quant_algo = QuantAlgo.W4A16 + elif args.smoothquant: + quant_config.smoothquant_val = args.smoothquant + if args.per_channel: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN else: - ln_f_w, ln_f_b = get_weight_and_bias(model_params, - 'transformer.ln_f', dtype) - weights['transformer.ln_f.weight'] = ln_f_w - if ln_f_b is not None: - weights['transformer.ln_f.bias'] = ln_f_b - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') - return weights - - -def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): - """ - This function has two purposes: - - compute quantized weights, scaled either per-tensor or per-column - - compute scaling factors - - Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. - CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. - CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. - - Here is the list of what we need (T means per-tensor, C per-column): - - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) - - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) - - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) - - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) - to quant range (int8) (used for CUBLAS) (T, C) + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN - Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, - but then the model would change depending on the number of GPUs used. + if args.int8_kv_cache: + quant_config.kv_cache_quant_algo = QuantAlgo.INT8 - For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it - as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. - """ + return quant_config - # compute weight scaling factors for fp->int8 and int8->fp - if is_qkv and not multi_query_mode: - scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( - dim=-1, keepdims=True)[0].cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, - -1).cpu().numpy() - elif is_qkv and multi_query_mode: - raise ValueError( - f"Multi-query w/ int8 quant has not been supported yet") - else: - scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() - scale_w_orig_quant_c = 127. 
/ act_range["w"].cpu().numpy() - scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t - scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c - # compute the rest of needed scaling factors - scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) - scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) - scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) - scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_t) - scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_c) - if is_qkv: - scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, - scale_w_orig_quant_c.shape) - scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, - scale_w_orig_quant_c.shape) +def convert_and_save_hf(args): + model_dir = args.model_dir + load_model_on_cpu = args.load_model_on_cpu + world_size = args.tp_size * args.pp_size - to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) - return { - "weight.int8": to_i8(weights * scale_w_orig_quant_t), - "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), - "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), - "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), - "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), - "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), - "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), - "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), + override_fields = { + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, } + quant_config = args_to_quant_config(args) -@torch.no_grad() -def apply_smoothing(scales, - gemm_weights, - layernorm_weights=None, - layernorm_bias=None, - dtype=torch.float32, - layernorm_1p=False): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - - if layernorm_weights is not None: - assert layernorm_weights.numel() == scales.numel() - layernorm_weights.div_(scales).to(dtype) - if layernorm_bias is not None: - assert layernorm_bias.numel() == scales.numel() - layernorm_bias.div_(scales).to(dtype) - if layernorm_1p: - layernorm_weights += (1 / scales) - 1 - - for gemm in gemm_weights: - gemm.mul_(scales.view(1, -1)).to(dtype) - - -@torch.no_grad() -def smooth_gemm(gemm_weights, - act_scales, - layernorm_weights=None, - layernorm_bias=None, - alpha=0.5, - weight_scales=None): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - orig_dtype = gemm_weights[0].dtype - - for gemm in gemm_weights: - # gemm_weights are expected to be transposed - assert gemm.shape[1] == act_scales.numel() - - if weight_scales is None: - weight_scales = torch.cat( - [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], - dim=0) - weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) - scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5) - - apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, - orig_dtype) - - return scales - - -@torch.no_grad() -def capture_activation_range(model, - tokenizer, - dataset, - num_samples=512, - seq_len=512): - model.eval() - device = next(model.parameters()).device - act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None}) - - def stat_tensor(name, tensor, act_scales, key): - hidden_dim = tensor.shape[-1] - 
tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float() - - if act_scales[name][key] is None: - act_scales[name][key] = comming_max - else: - act_scales[name][key] = torch.max(act_scales[name][key], - comming_max) - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x, act_scales, "x") - stat_tensor(name, y, act_scales, "y") - - if act_scales[name]["w"] is None: - act_scales[name]["w"] = m.weight.abs().clip(1e-8, - None).max(dim=0)[0] - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear) or isinstance(m, Conv1D): - hooks.append( - m.register_forward_hook( - functools.partial(stat_input_hook, name=name))) - - for i in tqdm(range(num_samples), desc="calibrating model"): - input_ids = tokenizer(dataset[i], - return_tensors="pt", - max_length=seq_len, - truncation=True).input_ids.to(device) - model(input_ids) - - for h in hooks: - h.remove() - - return act_scales - - -@torch.no_grad() -def smooth_gpt_model(model, scales, alpha): - # Smooth the activation and weights with smoother = $\diag{s}$ - for name, module in model.named_modules(): - if not isinstance(module, GPT2Block): - continue - - # qkv_proj - layer_name = name + ".attn.c_attn" - smoother = smooth_gemm(module.attn.c_attn.weight.T, - scales[layer_name]["x"], module.ln_1.weight, - module.ln_1.bias, alpha) - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=0)[0] - - # fc1 - layer_name = name + ".mlp.c_fc" - smoother = smooth_gemm(module.mlp.c_fc.weight.T, - scales[layer_name]["x"], module.ln_2.weight, - module.ln_2.bias, alpha) - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.mlp.c_fc.weight.abs().max(dim=0)[0] - - -def get_tllm_linear_sq_weight(vals, - prefix, - shape, - tensor_parallel, - is_qkv=False, - per_token=False, - per_channel=False, - last_prefix=None, - bias=None, - smoother_value=None, - smoother_shape=None, - rank=0, - cat_dim=0, - multi_query_mode=False): - results = {} - - def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): - q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) - q_split = np.split(q, tp_size, axis=-1) - k_split = np.split(k, tp_size, axis=-1) - v_split = np.split(v, tp_size, axis=-1) - return [ - np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) - for ii in range(tp_size) - ][cur_rank] - - col_shape = shape if (is_qkv or per_channel) else [1, 1] - - if per_token: - if per_channel: - original_weights = np.array(vals["weight.int8.col"]) - else: - original_weights = np.array(vals["weight.int8"]) - local_dim = original_weights.shape[0] - head_size = (original_weights.shape[1] - local_dim) // 2 - - if multi_query_mode: - cur_weights = multi_query_split(original_weights, local_dim, - head_size, tensor_parallel, rank) - else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] - if is_qkv: - hidden_dim = cur_weights.shape[0] - cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() - if smoother_value is None: - results[last_prefix] = torch.from_numpy( - np.array([1.0], dtype=np.float32)) - - if per_channel: - cur_per_channel_value = vals["scale_w_quant_orig.col"] - if smoother_value is None: - if multi_query_mode: - cur_per_channel_value = multi_query_split( - vals["scale_w_quant_orig.col"], local_dim, 
head_size, - tensor_parallel, rank) - else: - cur_per_channel_value = np.split( - vals["scale_w_quant_orig.col"], - tensor_parallel, - axis=cat_dim)[rank] - else: - cur_per_channel_value = vals["scale_w_quant_orig"] - if is_qkv: - if multi_query_mode: - cur_per_channel_value = multi_query_split( - vals["scale_w_quant_orig"], local_dim, head_size, - tensor_parallel, rank) - else: - cur_per_channel_value = np.split(vals["scale_w_quant_orig"], - tensor_parallel, - axis=cat_dim)[rank] - - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() + if args.smoothquant is not None or args.int8_kv_cache: + mapping = Mapping(world_size=world_size, + tp_size=args.tp_size, + pp_size=args.pp_size) + GPTForCausalLM.quantize( + args.model_dir, + args.output_dir, + dtype=args.dtype, + mapping=mapping, + quant_config=quant_config, + device='cpu' if args.load_model_on_cpu else 'cuda', + calib_dataset=args.calib_dataset, + **override_fields) else: - if per_channel: - original_weights = np.array(vals["weight.int8.col"]) - else: - original_weights = np.array(vals["weight.int8"]) - local_dim = original_weights.shape[0] - head_size = (original_weights.shape[1] - local_dim) // 2 - - if multi_query_mode: - cur_weights = multi_query_split(original_weights, local_dim, - head_size, tensor_parallel, rank) - else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] - if is_qkv: - hidden_dim = cur_weights.shape[0] - cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() - - if per_channel: - cur_per_channel_value = vals["scale_y_accum_quant.col"] - if smoother_value is None: - if multi_query_mode: - cur_per_channel_value = multi_query_split( - vals["scale_y_accum_quant.col"], local_dim, head_size, - tensor_parallel, rank) - else: - cur_per_channel_value = np.split( - vals["scale_y_accum_quant.col"], - tensor_parallel, - axis=cat_dim)[rank] - else: - cur_per_channel_value = vals["scale_y_accum_quant"] - # QKV is always per_channel - if is_qkv: - if multi_query_mode: - cur_per_channel_value = multi_query_split( - vals["scale_y_accum_quant"], local_dim, head_size, - tensor_parallel, rank) - else: - cur_per_channel_value = np.split( - vals["scale_y_accum_quant"], - tensor_parallel, - axis=cat_dim)[rank] - - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array([cur_per_channel_value], - dtype=np.float32).reshape(col_shape)).contiguous() - - results[last_prefix] = torch.from_numpy( - np.array([vals['scale_x_orig_quant']], - dtype=np.float32)).contiguous() - - results[prefix + 'act_scale'] = torch.from_numpy( - np.array([[vals["scale_y_quant_orig"]]], - dtype=np.float32)).contiguous() - - if smoother_value is not None: - cur_smoother_value = np.split(smoother_value, - tensor_parallel, - axis=cat_dim)[rank] - results[prefix + 'smoother'] = cur_smoother_value.reshape( - smoother_shape).contiguous().to(torch.float32) - - if bias is not None: - results[prefix + 'bias'] = bias - - return results - - -def convert_hf_gpt_legacy(hf_model: AutoModelForCausalLM, - hf_config: AutoConfig, - gpt_variant: str, - mapping: Mapping, - dtype: str = 'float32', - use_parallel_embedding: bool = False, - sharding_dim: int = 0, - share_embedding_table: bool = False, - use_smooth_quant=False, - per_channel=False, - per_token=False, - int8_kv_cache=False, - act_range=None): - weights = {} - tik = time.time() - - model_params = 
dict(hf_model.named_parameters()) - dtype = getattr(torch, dtype) - num_attention_heads = hf_config.n_head - hidden_size = hf_config.n_embd - vocab_size = hf_config.vocab_size - num_kv_heads = hf_config.n_kv_head - num_hidden_layers = hf_config.n_layer - multi_query_mode = (num_kv_heads != num_attention_heads) - tensor_parallel = mapping.tp_size - - layers_range = mapping.pp_layers(num_hidden_layers) - for l in layers_range: - prefix = f'transformer.h.{l}' - tllm_prex = f'transformer.layers.{l-layers_range[0]}' - - if gpt_variant == 'santacoder': - q_w, q_b = get_weight_and_bias(model_params, - f'{prefix}.attn.q_attn', dtype) - kv_w, kv_b = get_weight_and_bias(model_params, - f'{prefix}.attn.kv_attn', dtype) - qkv_w = torch.cat([q_w, kv_w], dim=-1) - qkv_b = torch.cat([q_b, kv_b], dim=-1) - else: - qkv_w, qkv_b = get_weight_and_bias(model_params, - f'{prefix}.attn.c_attn', dtype) - if gpt_variant in ['gpt2', 'santacoder']: - qkv_w = qkv_w.t().contiguous() # transpose for Conv1D - - if use_smooth_quant: - qkv_out_dim = qkv_w.shape[0] - qkv_w_numpy = qkv_w.t().numpy() - if not multi_query_mode: - qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size) - int8_weights = generate_int8(qkv_w_numpy, - act_range.get(f'{prefix}.attn.c_attn'), - is_qkv=True, - multi_query_mode=multi_query_mode) - qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size, - hidden_size, num_attention_heads, num_kv_heads) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - f'{tllm_prex}.attention.qkv.', - [1, qkv_out_dim // tensor_parallel], - tensor_parallel, - is_qkv=True, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.input_layernorm.scale_to_int', - bias=qkv_b, - smoother_value=None, - smoother_shape=None, - rank=rank, - cat_dim=-1, - multi_query_mode=multi_query_mode)) - else: - qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size, - hidden_size, num_attention_heads, num_kv_heads) - qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size, - hidden_size, num_attention_heads, num_kv_heads) - weights.update( - get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', - qkv_b)) - - if int8_kv_cache: - qkv_w_numpy = qkv_w.t().numpy() - if not multi_query_mode: - qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size) - int8_weights = generate_int8(qkv_w_numpy, - act_range.get(f'{prefix}.attn.c_attn'), - is_qkv=True, - multi_query_mode=multi_query_mode) - weights[ - f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.from_numpy( - np.array([int8_weights['scale_y_quant_orig']], - dtype=np.float32)).contiguous() - - attn_dense_w, attn_dense_b = get_weight_and_bias( - model_params, f'{prefix}.attn.c_proj', dtype) - if gpt_variant in ['gpt2', 'santacoder']: - attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D - if use_smooth_quant: - attn_dense_w_numpy = attn_dense_w.t().numpy() - int8_weights = generate_int8(attn_dense_w_numpy, - act_range.get(f'{prefix}.attn.c_proj')) - # change it to the real smoother if dense layer is applied smooth quant - fake_smoother_value = torch.ones([1, hidden_size], - dtype=torch.float32) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - f'{tllm_prex}.attention.dense.', [1, hidden_size], - tensor_parallel, - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix= - f'{tllm_prex}.attention.quantization_scaling_factor', - bias=attn_dense_b, - smoother_value=fake_smoother_value, - smoother_shape=[1, hidden_size // tensor_parallel], - rank=rank, - cat_dim=0)) - else: - 
attn_dense_w = split(attn_dense_w, - mapping.tp_rank, - mapping.tp_size, - is_column=False) - weights.update( - get_tllm_linear_weight(attn_dense_w, - f'{tllm_prex}.attention.dense', - attn_dense_b)) - - mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params, - f'{prefix}.mlp.c_fc', dtype) - if gpt_variant in ['gpt2', 'santacoder']: - mlp_fc_w = mlp_fc_w.t().contiguous() # transpose for Conv1D - if use_smooth_quant: - mlp_fc_w_numpy = mlp_fc_w.t().numpy() - int8_weights = generate_int8(mlp_fc_w_numpy, - act_range.get(f'{prefix}.mlp.c_fc')) - mlp_fc_b = split(mlp_fc_b, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - f'{tllm_prex}.mlp.fc.', - [1, 4 * hidden_size // tensor_parallel], - tensor_parallel, - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.post_layernorm.scale_to_int', - bias=mlp_fc_b, - smoother_value=None, - smoother_shape=None, - rank=rank, - cat_dim=-1)) - else: - mlp_fc_w = split(mlp_fc_w, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - mlp_fc_b = split(mlp_fc_b, - mapping.tp_rank, - mapping.tp_size, - is_column=True) - weights.update( - get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', - mlp_fc_b)) - - mlp_proj_w, mlp_proj_b = get_weight_and_bias(model_params, - f'{prefix}.mlp.c_proj', - dtype) - if gpt_variant in ['gpt2', 'santacoder']: - mlp_proj_w = mlp_proj_w.t().contiguous() # transpose for Conv1D - if use_smooth_quant: - mlp_proj_w_numpy = mlp_proj_w.t().numpy() - int8_weights = generate_int8(mlp_proj_w_numpy, - act_range.get(f'{prefix}.mlp.c_proj')) - # change it to the real smoother if proj layer is applied smooth quant - fake_smoother_value = torch.ones([1, 4 * hidden_size], - dtype=torch.float32) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - f'{tllm_prex}.mlp.proj.', [1, hidden_size], - tensor_parallel, - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.mlp.quantization_scaling_factor', - bias=mlp_proj_b, - smoother_value=fake_smoother_value, - smoother_shape=[1, 4 * hidden_size // tensor_parallel], - rank=rank, - cat_dim=0)) - else: - mlp_proj_w = split(mlp_proj_w, - mapping.tp_rank, - mapping.tp_size, - is_column=False) - weights.update( - get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', - mlp_proj_b)) - - input_ln_w, input_ln_b = get_weight_and_bias(model_params, - f'{prefix}.ln_1', dtype) - weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w - if input_ln_b is not None: - weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b - - post_ln_w, post_ln_b = get_weight_and_bias(model_params, - f'{prefix}.ln_2', dtype) - weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w - if post_ln_b is not None: - weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b - - if mapping.is_first_pp_rank(): - embed_w = get_weight(model_params, 'transformer.wte', dtype) - weights['transformer.vocab_embedding.weight'] = split_embedding( - embed_w, - mapping.tp_rank, - mapping.tp_size, - use_parallel_embedding=use_parallel_embedding, - sharding_dim=sharding_dim) - - pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype) - if pos_embed_w is not None: - weights['transformer.position_embedding.weight'] = split_embedding( - pos_embed_w, - mapping.tp_rank, - mapping.tp_size, - use_parallel_embedding=use_parallel_embedding, - sharding_dim=sharding_dim) - - if mapping.is_last_pp_rank(): - embed_w = get_weight(model_params, 'transformer.wte', dtype) - if not 
share_embedding_table: - if vocab_size % mapping.tp_size != 0: - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size - embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width), - value=0) - weights['lm_head.weight'] = split(embed_w.clone(), - mapping.tp_rank, - mapping.tp_size, - is_column=True) - ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f', - dtype) - weights['transformer.ln_f.weight'] = ln_f_w - if ln_f_b is not None: - weights['transformer.ln_f.bias'] = ln_f_b - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') - return weights - - -def cpu_map_location(storage, loc): - return storage.cpu() - - -def gpu_map_location(storage, loc): - if loc.startswith("cuda"): - training_gpu_idx = int(loc.split(":")[1]) - inference_gpu_idx = training_gpu_idx % torch.cuda.device_count() - return storage.cuda(inference_gpu_idx) - elif loc.startswith("cpu"): - return storage.cpu() + hf_model = load_hf_gpt(model_dir, load_model_on_cpu) + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + model = GPTForCausalLM.from_hugging_face(hf_model, + args.dtype, + mapping=mapping, + quant_config=quant_config, + gpt_variant=args.gpt_variant, + **override_fields) + model.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del model + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) else: - raise ValueError(f"Not handled {loc}") - - -def copy_tokenizer_files(config, out_dir): - basenames = { - "model": "tokenizer", - "vocab_file": "vocab", - "merge_file": "merges", - } - - for key in basenames.keys(): - if config[key] is None: - continue - path = Path(config[key]) - if not path.exists(): - LOGGER.debug(f"Tokenizer {key}: {path} file not found") - continue - - dst_path = out_dir / f"{basenames[key]}{path.suffix}" - LOGGER.debug(f"Copy tokenizer {key}: {path}->{dst_path}") - shutil.copy(path.as_posix(), dst_path.as_posix()) - - -def update_tokenizer_paths(tokenizer_config: Dict, - tokenizer_file_paths: Dict[str, Optional[str]]): - for key, new_path in tokenizer_file_paths.items(): - old_path = tokenizer_config[key] - if old_path is None: - continue - old_path = Path(old_path) - if new_path: - LOGGER.debug(f"Update tokenizer {key} {old_path} -> {new_path}") - tokenizer_config[key] = new_path.as_posix() - elif not old_path.exists(): - LOGGER.warning( - f"Tokenizer {key}'s path {old_path} does not exists: set it to None" - ) - tokenizer_config[key] = None - return tokenizer_config - - -def unpack_nemo_ckpt(nemo_archive_path: Union[str, Path], - out_dir_path: Union[str, Path]): - nemo_archive_path = Path(nemo_archive_path) - if not nemo_archive_path.exists(): - raise FileNotFoundError(f"{nemo_archive_path} does not exist") - - for tar_mode in ["r:", "r:gz"]: - try: - with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file: - - def is_within_directory(directory, target): - - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_members(tar_file): - members = [] - for member in tar_file.getmembers(): - member_path = os.path.join(out_dir_path, member.name) - if not 
is_within_directory(out_dir_path, member_path): - raise Exception( - "Attempted Path Traversal in Tar File") - members.append(member) - return members - - tar_file.extractall(out_dir_path, - members=safe_members(tar_file), - numeric_owner=False) - - return out_dir_path - except tarfile.ReadError: - pass + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." - raise RuntimeError(f"Could not unpack {nemo_archive_path}") +def convert_and_save_nemo(args): + world_size = args.tp_size * args.pp_size + quant_config = args_to_quant_config(args) -def extract_layers_with_prefix(model_, prefix): - length_to_trim = len(prefix) - model_state = model_.get("state_dict", model_) - return { - key[length_to_trim:]: model_state[key] - for key in model_state.keys() if prefix in key + override_fields = { + 'use_parallel_embedding': True, + 'embedding_sharding_dim': 0, + 'share_embedding_table': args.use_embedding_sharing, } + nemo_ckpt_dir = os.path.join(args.output_dir, "unpacked") + nemo_ckpt_dir = unpack_nemo_ckpt(args.nemo_ckpt_path, nemo_ckpt_dir) -class UnpackedNemoCheckpointDir: - - def __init__(self, - checkpoints_dir: Union[str, Path], - load_checkpoints_to_cpu: bool = False): - self._checkpoints_dir = Path(checkpoints_dir) - self._load_checkpoints_to_cpu = load_checkpoints_to_cpu - - @property - @functools.lru_cache - def model_config(self): - model_config = None - - model_config_filename = "model_config.yaml" - model_configs_paths = list( - self._checkpoints_dir.rglob(model_config_filename)) - if model_configs_paths: - if len(model_configs_paths) > 1: - raise RuntimeError( - f"There are more than single {model_config_filename} " - f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}" - ) - model_config_path = model_configs_paths[0] - LOGGER.debug("Loading model config from %s", model_config_path) - with model_config_path.open("r") as model_config_file: - model_config = yaml.load(model_config_file, - Loader=yaml.SafeLoader) - else: - LOGGER.debug("Searching model config in checkpoints") - # try to obtain from checkpoint - checkpoint_name = self.checkpoint_name - checkpoints_paths = sorted( - self._checkpoints_dir.rglob(checkpoint_name)) - if checkpoints_paths: - # assume that parallel ranks 0 checkpoint should have model config embedded - checkpoint_path = checkpoints_paths[0] - - map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location - - model_00 = torch.load(checkpoint_path, - map_location=map_location_fn) - if "hyper_parameters" in model_00 and "cfg" in model_00[ - "hyper_parameters"]: - model_config = model_00["hyper_parameters"]["cfg"] - LOGGER.debug("Loaded model config from checkpoint %s", - checkpoint_path) - else: - LOGGER.debug("Could not find model config in checkpoint %s", - checkpoint_path) - del model_00 - - if model_config is None: - LOGGER.warning( - "Could not find checkpoint with NeMo model config in %s", - self._checkpoints_dir) - - LOGGER.debug("Loaded model config %s", model_config) - - return model_config - - @property - def checkpoints_dir(self): - return self._checkpoints_dir - - def get_checkpoints_paths(self, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1): - """ - Injects 
tensor/pipeline model parallel ranks into the filepath. - Does nothing if not using model parallelism. - """ - - checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name - - def _inject_parallel_ranks(tp_rank, pp_rank): - if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1: - if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1: - checkpoint_path = (checkpoint_path_without_rank.parent / - f"mp_rank_{tp_rank:02d}" / - checkpoint_path_without_rank.name) - else: - checkpoint_path = ( - checkpoint_path_without_rank.parent / - f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" / - checkpoint_path_without_rank.name) - return checkpoint_path - else: - return checkpoint_path_without_rank - - return [[ - _inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank) - for pp_rank in range(pipeline_model_parallel_size) - ] for tp_rank in range(tensor_model_parallel_size)] - - @property - @functools.lru_cache - def checkpoint_name(self): - patterns = [ - "model_weights.ckpt", # older megatron checkpoints - "*last.ckpt", # newer format of checkpoints - ] - for pattern in patterns: - model_files = sorted(list(self._checkpoints_dir.rglob(pattern))) - if model_files: - return model_files[0].name - - raise ValueError( - f"Could not find checkpoint files in {self._checkpoints_dir}") - - @functools.lru_cache - def get_tokenizer_file_path(self, tokenizer_key, file_key, - default_filename_pattern): - model_config = self.model_config - file_property = None - if tokenizer_key in model_config and file_key in model_config[ - tokenizer_key]: - file_property = model_config[tokenizer_key][file_key] - elif file_key in model_config: - file_property = model_config[file_key] - - LOGGER.debug("model_config[%s][%s]=%s", tokenizer_key, file_key, - file_property) - - if file_property and file_property.startswith("nemo:"): - filename = file_property.split("nemo:")[1] - filename_pattern = f"*{filename}" - elif file_property and file_property.startswith("/artifacts/"): - filename = Path(file_property).name - filename_pattern = f"*{filename}" - elif file_property is None or file_property == "None": - filename_pattern = None - else: - filename_pattern = default_filename_pattern - LOGGER.warning( - f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} " - f"looks like unsupported path. Pattern {filename_pattern} will be used." 
- ) - - file_path = None - if filename_pattern is not None: - files_paths = list(self._checkpoints_dir.glob(filename_pattern)) - if files_paths: - assert len(files_paths) == 1 - file_path = files_paths[0] - - return file_path - - @functools.lru_cache - def get_all_tokenizer_file_paths(self): - return { - "model": - self.get_tokenizer_file_path("tokenizer", "model", "*.model"), - "vocab_file": - self.get_tokenizer_file_path("tokenizer", "vocab_file", "*vocab*"), - "merge_file": - self.get_tokenizer_file_path("tokenizer", "merge_file", - "*merge*.txt"), - } - - -def load_nemo_gpt_config( - unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, - layer_rename_config: Dict[str, str] = None) -> GPT2Config: + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + model = GPTForCausalLM.from_nemo( + nemo_ckpt_dir, + dtype=args.dtype, + mapping=mapping, + quant_config=quant_config, + load_model_on_cpu=args.load_model_on_cpu, + nemo_rename_key=args.nemo_rename_key, + **override_fields) + model.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del model + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + # Copy tokenizer files + unpacked_checkpoints_dir = UnpackedNemoCheckpointDir( + nemo_ckpt_dir, load_checkpoints_to_cpu=args.load_model_on_cpu) nemo_model_config = unpacked_checkpoints_dir.model_config - - training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1) - training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1) - - checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( - training_tp_size, - training_pp_size, - ) - if unpacked_checkpoints_dir._load_checkpoints_to_cpu: - map_location_fn = cpu_map_location - else: - map_location_fn = gpu_map_location - model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn) - model_00 = rename_keys(model_00, layer_rename_config) - vocab_size = model_00[ - "model.language_model.embedding.word_embeddings.weight"].shape[ - 0] * training_tp_size - del model_00 - - hf_config = GPT2Config( - vocab_size=vocab_size, - n_positions=nemo_model_config['max_position_embeddings'], - n_embd=nemo_model_config['hidden_size'], - n_layer=nemo_model_config['num_layers'], - n_head=nemo_model_config['num_attention_heads'], - n_inner=nemo_model_config['ffn_hidden_size'], - activation_function=nemo_model_config['activation'], - layer_norm_epsilon=nemo_model_config['layernorm_epsilon'], - ) - hf_config.n_kv_head = hf_config.n_head - hf_config.bias = nemo_model_config['bias'] - # hf_config.apply_query_key_layer_scaling = nemo_model_config['apply_query_key_layer_scaling'] - hf_config.apply_query_key_layer_scaling = False - - hf_config.position_embedding_type = nemo_model_config.get( - 'position_embedding_type', 'learned_absolute') - if hf_config.position_embedding_type == 'rope': - hf_config.position_embedding_type = 'rope_gpt_neox' - hf_config.rotary_base = nemo_model_config.get('rotary_base', 10000.0) - hf_config.rotary_pct = nemo_model_config.get('rotary_percentage', 1.0) - assert hf_config.rotary_pct >= 0 and hf_config.rotary_pct <= 1 - - rotary_scaling_factor = nemo_model_config.get( - 'seq_len_interpolation_factor', None) - if rotary_scaling_factor is None: - hf_config.rotary_scaling = None - else: - assert rotary_scaling_factor > 1 - hf_config.rotary_scaling = { - 'type': 'linear', - 'factor': rotary_scaling_factor - } - tokenizer_config = update_tokenizer_paths( 
nemo_model_config["tokenizer"], unpacked_checkpoints_dir.get_all_tokenizer_file_paths()) + copy_tokenizer_files(tokenizer_config, Path(args.output_dir)) - return hf_config, tokenizer_config - - -@torch.no_grad() -def load_torch_checkpoints(checkpoints_paths, - merge_factor, - tp_rank, - pp_rank, - map_location_fn, - handle_model_level_weights, - layer_rename_config: Dict[str, str] = {}): - models = [] - for k in range(merge_factor): - rank_weights = checkpoints_paths[tp_rank * merge_factor + k][pp_rank] - model = torch.load(rank_weights, map_location=map_location_fn) - model = rename_keys(model, layer_rename_config) - handle_model_level_weights(model, tp_rank * merge_factor + k, pp_rank) - layers = extract_layers_with_prefix(model, - "model.language_model.encoder.") - models.append(layers) - return models - - -@torch.no_grad() -def convert_nemo_gpt(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, - mapping: Mapping, - dtype: str = 'float32', - layer_rename_config: Dict[str, str] = None): - nemo_model_config = unpacked_checkpoints_dir.model_config - - checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( - nemo_model_config.get("tensor_model_parallel_size", 1), - nemo_model_config.get("pipeline_model_parallel_size", 1), - ) - - if unpacked_checkpoints_dir._load_checkpoints_to_cpu: - map_location_fn = cpu_map_location - else: - map_location_fn = gpu_map_location - dtype = str_dtype_to_torch(dtype) - - # load position_embedding from rank 0 - model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn) - model_00 = model_00.get("state_dict", model_00) - model_00 = rename_keys(model_00, layer_rename_config) - has_position_embedding = "model.language_model.embedding.position_embeddings.weight" in model_00 - has_lm_head = "model.language_model.output_layer.weight" in model_00 - del model_00 - - num_layers = nemo_model_config["num_layers"] - training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1) - training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1) - inference_tp_size = mapping.tp_size - inference_tp_rank = mapping.tp_rank - - apply_layernorm_1p = (nemo_model_config.get('normalization', - '') == "layernorm1p") - split_gated_activation = ("swiglu" - in nemo_model_config.get('activation', "gelu")) - num_attention_heads = nemo_model_config["num_attention_heads"] - # use_attention_nemo_shape = True - transpose_weights = True - # multi_query_mode = False - local_dim = None - - # merge_factor: how many TP training nodes are merged into an inference TP node - # split_factor: in how many parts a TP training node is split - gcd = np.gcd(training_tp_size, inference_tp_size) - merge_factor = training_tp_size // gcd - split_factor = inference_tp_size // gcd - - model_level_weights = defaultdict(list) + # Clean up unpacked nemo checkpoint + shutil.rmtree(nemo_ckpt_dir) - def handle_model_level_weights(model, tp_idx: int, pp_idx: int): - if tp_idx == 0 and pp_idx == 0: - if has_position_embedding: - val = model[ - "model.language_model.embedding.position_embeddings.weight"] - model_level_weights[ - "transformer.position_embedding.weight"].append(val) - if pp_idx == 0: - val = model.get( - "state_dict", - model)["model.language_model.embedding.word_embeddings.weight"] - model_level_weights["transformer.vocab_embedding.weight"].append( - val) - if has_lm_head and pp_idx == training_pp_size - 1: - val = model.get("state_dict", - model)["model.language_model.output_layer.weight"] - model_level_weights["lm_head.weight"].append(val) - 
weights = {} - tik = time.time() - tp_rank = inference_tp_rank // split_factor - # for tp_rank in range(training_tp_size // merge_factor): - for pp_rank in range(training_pp_size): - models = load_torch_checkpoints(checkpoints_paths, merge_factor, - tp_rank, pp_rank, map_location_fn, - handle_model_level_weights, - layer_rename_config) - for name in list(models[0].keys()): - params = [model[name] for model in models] - if transpose_weights and params[0].ndim == 2: - params = [p.T for p in params] - if "layernorm.weight" in name and apply_layernorm_1p: - params = [p + 1.0 for p in params] - - l = retrieved_layer_index_from_name(name) - if l is not None: - new_l = l + pp_rank * num_layers // training_pp_size - prefix = f'transformer.layers.{new_l}' - - if 'attention.query_key_value' in name: - if name.endswith('weight'): - hidden_dim = params[0].shape[0] - if local_dim is None: - local_dim = params[0].shape[-1] // 3 - - # multi_query_mode = False; use_attention_nemo_shape = True - head_num = num_attention_heads // training_tp_size - size_per_head = hidden_dim // num_attention_heads - params = [ - param.reshape(hidden_dim, head_num, 3, - size_per_head) for param in params - ] - params = [param.permute(0, 2, 1, 3) for param in params] - params = [ - param.reshape(hidden_dim, 3, local_dim) - for param in params - ] - cat_dim = -1 - param = torch.concat(params, dim=cat_dim) - param = torch.chunk(param, split_factor, - dim=cat_dim)[inference_tp_rank % - split_factor] - weights[ - f'{prefix}.attention.qkv.weight'] = param.reshape( - hidden_dim, -1).t() - else: - if local_dim is None: - local_dim = params[0].shape[-1] // 3 - - # multi_query_mode = False; use_attention_nemo_shape = True - head_num = num_attention_heads // training_tp_size - size_per_head = local_dim // head_num - params = [ - param.reshape(head_num, 3, size_per_head) - for param in params - ] - params = [param.permute(1, 0, 2) for param in params] - params = [ - param.reshape(3, local_dim) for param in params - ] - cat_dim = -1 - param = torch.concat(params, dim=cat_dim) - param = torch.chunk(param, split_factor, - dim=cat_dim)[inference_tp_rank % - split_factor] - weights[f'{prefix}.attention.qkv.bias'] = param.reshape( - -1) - - elif 'attention.dense' in name: - if name.endswith('weight'): - cat_dim = 0 - param = torch.concat(params, dim=cat_dim) - param = torch.chunk(param, split_factor, - dim=cat_dim)[inference_tp_rank % - split_factor] - weights[f'{prefix}.attention.dense.weight'] = param.t() - else: - weights[f'{prefix}.attention.dense.bias'] = params[0] - - elif 'mlp.dense_h_to_4h' in name: - if name.endswith('weight'): - if split_gated_activation: - params = [torch.chunk(p, 2, dim=-1) for p in params] - params, gate_params = list(zip(*params)) - cat_dim = -1 - param = torch.concat(params, dim=cat_dim) - param = torch.chunk(param, split_factor, - dim=cat_dim)[inference_tp_rank % - split_factor] - weights[f'{prefix}.mlp.fc.weight'] = param.t() - if split_gated_activation: - gate_param = torch.concat(gate_params, dim=cat_dim) - gate_param = torch.chunk( - gate_param, split_factor, - dim=cat_dim)[inference_tp_rank % split_factor] - weights[f'{prefix}.mlp.gate.weight'] = gate_param.t( - ) - else: - if split_gated_activation: - params = [torch.chunk(p, 2, dim=-1) for p in params] - params, gate_params = list(zip(*params)) - cat_dim = -1 - param = torch.concat(params, dim=cat_dim) - param = torch.chunk(param, split_factor, - dim=cat_dim)[inference_tp_rank % - split_factor] - weights[f'{prefix}.mlp.fc.bias'] = param - if 
split_gated_activation: - gate_param = torch.concat(gate_params, dim=cat_dim) - gate_param = torch.chunk( - gate_param, split_factor, - dim=cat_dim)[inference_tp_rank % split_factor] - weights[f'{prefix}.mlp.gate.bias'] = gate_param - - elif 'mlp.dense_4h_to_h' in name: - if name.endswith('weight'): - cat_dim = 0 - param = torch.concat(params, dim=cat_dim) - param = torch.chunk(param, split_factor, - dim=cat_dim)[inference_tp_rank % - split_factor] - weights[f'{prefix}.mlp.proj.weight'] = param.t() - else: - weights[f'{prefix}.mlp.proj.bias'] = params[0] - - elif 'input_layernorm' in name: - if name.endswith('weight'): - weights[f'{prefix}.input_layernorm.weight'] = params[0] - else: - weights[f'{prefix}.input_layernorm.bias'] = params[0] - elif 'post_attention_layernorm' in name: - if name.endswith('weight'): - weights[f'{prefix}.post_layernorm.weight'] = params[0] - else: - weights[f'{prefix}.post_layernorm.bias'] = params[0] - - elif 'final_layernorm' in name: - if name.endswith('weight'): - weights['transformer.ln_f.weight'] = params[0] - else: - weights['transformer.ln_f.bias'] = params[0] - for model in models: - del model[name] - del models - for key in list(model_level_weights.keys()): - weights[key] = torch.concat(model_level_weights[key], dim=0) - del model_level_weights[key] - for key, param in weights.items(): - weights[key] = weights[key].to(dtype).contiguous() - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') - return weights - - -if __name__ == '__main__': +def main(): # TODO(qijun): Currently, the convert script depends on a torch op: # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, # which is included in tensorrt_llm Python package. Otherwise, the convert @@ -1901,209 +302,24 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): # the op with PyTorch. 
print(tensorrt_llm.__version__) args = parse_arguments() - world_size = args.tp_size * args.pp_size + args.tp_size * args.pp_size tik = time.time() if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - quant_algo = None - kv_cache_quant_algo = None - plugin_weight_only_quant_type = None - if args.use_weight_only: - if args.weight_only_precision == 'int8': - plugin_weight_only_quant_type = torch.int8 - quant_algo = QuantAlgo.W8A16 - elif args.weight_only_precision == 'int4': - plugin_weight_only_quant_type = torch.quint4x2 - quant_algo = QuantAlgo.W4A16 - elif args.smoothquant: - if args.per_token and args.per_channel: - quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN - elif not args.per_token and not args.per_channel: - quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN - elif not args.per_token and args.per_channel: - quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN - elif args.per_token and not args.per_channel: - quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN - - if args.int8_kv_cache: - kv_cache_quant_algo = QuantAlgo.INT8 - if args.model_dir is not None: - hf_config, gpt_variant = load_gpt_config(args.model_dir, args.tp_size, - args.gpt_variant) + convert_and_save_hf(args) elif args.nemo_ckpt_path is not None: - nemo_dir = Path(args.output_dir) / "unpacked" - nemo_dir = unpack_nemo_ckpt(args.nemo_ckpt_path, nemo_dir) - unpacked_checkpoints_dir = UnpackedNemoCheckpointDir( - nemo_dir, load_checkpoints_to_cpu=not args.load_nemo_on_gpu) - layer_rename_config = { - pattern.split(':')[0]: pattern.split(':')[1] - for pattern in args.nemo_rename_key - } - hf_config, tokenizer_config = load_nemo_gpt_config( - unpacked_checkpoints_dir, layer_rename_config) - copy_tokenizer_files(tokenizer_config, Path(args.output_dir)) - args.use_parallel_embedding = True - args.embedding_sharding_dim = 0 + convert_and_save_nemo(args) else: raise NotImplementedError("No source model path specified!") - config = { - 'architecture': - 'GPTForCausalLM', - 'dtype': - args.dtype, - 'num_hidden_layers': - hf_config.n_layer, - 'num_attention_heads': - hf_config.n_head, - 'num_key_value_heads': - hf_config.n_kv_head, - 'hidden_size': - hf_config.n_embd, - 'intermediate_size': - hf_config.n_inner, - 'norm_epsilon': - hf_config.layer_norm_epsilon, - 'vocab_size': - hf_config.vocab_size, - 'position_embedding_type': - getattr(hf_config, 'position_embedding_type', 'learned_absolute'), - 'max_position_embeddings': - hf_config.n_positions, - 'hidden_act': - hf_config.activation_function, - 'use_parallel_embedding': - args.use_parallel_embedding, - 'embedding_sharding_dim': - args.embedding_sharding_dim, - 'share_embedding_table': - args.use_embedding_sharing, - 'quantization': { - 'quant_algo': quant_algo, - 'kv_cache_quant_algo': kv_cache_quant_algo, - }, - 'mapping': { - 'world_size': world_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - }, - 'bias': - getattr(hf_config, 'bias', True), - 'apply_query_key_layer_scaling': - getattr(hf_config, 'apply_query_key_layer_scaling', False), - 'rotary_pct': - getattr(hf_config, 'rotary_pct', 1.0), - 'rotary_base': - getattr(hf_config, 'rotary_base', 10000.0), - 'rotary_scaling': - getattr(hf_config, 'rotary_scaling', None), - 'qk_layernorm': - args.model_dir is not None and gpt_variant == 'persimmon', - 'inner_layernorm': - args.model_dir is not None and gpt_variant == 'kosmos-2', - 'norm_before_bmm1': - args.model_dir is not None and gpt_variant == 'kosmos-2', - 'q_scaling': - getattr(hf_config, 'q_scaling', 1), - 
'embedding_scale': - getattr(hf_config, 'embeddings_scale', None), - } - - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - - if args.model_dir is not None: - if gpt_variant == 'kosmos-2': - hf_model = AutoModelForVision2Seq.from_pretrained( - args.model_dir, trust_remote_code=True) - else: - hf_model = AutoModelForCausalLM.from_pretrained( - args.model_dir, - trust_remote_code=True, - device_map="auto", - torch_dtype="auto") - if args.smoothquant is not None or args.int8_kv_cache: - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - dataset = load_calib_dataset(args.calib_dataset, - cache_dir=args.dataset_cache_dir) - act_range = capture_activation_range(hf_model, tokenizer, dataset) - if args.smoothquant is not None: - smooth_gpt_model(hf_model, act_range, args.smoothquant) - - def convert_and_save(rank): - mapping = Mapping(world_size=world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size) - - if args.model_dir is not None: - if args.smoothquant is not None or args.int8_kv_cache: - weights = convert_hf_gpt_legacy( - hf_model, - hf_config, - gpt_variant, - mapping, - dtype=args.dtype, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - use_smooth_quant=(args.smoothquant is not None), - per_channel=args.per_channel, - per_token=args.per_token, - int8_kv_cache=args.int8_kv_cache, - act_range=act_range, - ) - else: - weights = convert_hf_gpt( - hf_model, - hf_config, - gpt_variant, - mapping, - dtype=args.dtype, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - use_weight_only=args.use_weight_only, - plugin_weight_only_quant_type=plugin_weight_only_quant_type, - ) - - elif args.nemo_ckpt_path is not None: - weights = convert_nemo_gpt(unpacked_checkpoints_dir, mapping, - args.dtype, layer_rename_config) - - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) - - if args.workers == 1: - for rank in range(world_size): - convert_and_save(rank) - else: - with ThreadPoolExecutor(max_workers=args.workers) as p: - futures = [ - p.submit(convert_and_save, rank) for rank in range(world_size) - ] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." 
- - if args.model_dir is not None: - del hf_model - elif args.nemo_ckpt_path is not None: - shutil.rmtree(nemo_dir) - tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py b/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py index f272c88..dabfe9f 100644 --- a/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py +++ b/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py @@ -5,14 +5,14 @@ import traceback from concurrent.futures import ThreadPoolExecutor, as_completed +from transformers import AutoConfig + import tensorrt_llm from tensorrt_llm._utils import release_gc from tensorrt_llm.layers import MoeConfig +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import LLaMAConfig, LLaMAForCausalLM -from tensorrt_llm.models.convert_utils import has_safetensors -from tensorrt_llm.models.llama.convert import (load_hf_llama, - load_weights_from_gptq) +from tensorrt_llm.models import LLaMAForCausalLM from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization import QuantAlgo @@ -55,6 +55,8 @@ def parse_arguments(): parser.add_argument('--n_kv_head', type=int, default=None) parser.add_argument('--n_embd', type=int, default=4096) parser.add_argument('--inter_size', type=int, default=11008) + parser.add_argument('--multiple_of', type=int, default=None) + parser.add_argument('--ffn_dim_multiplier', type=float, default=None) parser.add_argument('--rms_norm_eps', type=float, default=1e-06) parser.add_argument( @@ -120,11 +122,22 @@ def parse_arguments(): help= 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' ) + parser.add_argument( + '--fp8_kv_cache', + default=False, + action="store_true", + help= + 'By default, we use dtype for KV cache. fp8_kv_cache chooses int8 quantization for KV' + ) parser.add_argument( '--quant_ckpt_path', type=str, default=None, help='Path of a quantized model checkpoint in .safetensors format') + parser.add_argument("--use_fp8_rowwise", + action="store_true", + default=False, + help="Enable Fp8 per-token per-channel quantization") parser.add_argument( '--per_group', @@ -208,6 +221,14 @@ def parse_arguments(): help= 'Only save the model config w/o read and converting weights, be careful, this is for debug only' ) + parser.add_argument( + '--remove_duplicated_kv_heads', + action="store_true", + default=False, + help= + 'Only used to remove the duplicated kv heads of llama-3.1 405B HF model.' + ) + parser.add_argument('--log_level', type=str, default='info') args = parser.parse_args() # changing the default to be consistent as the cli help said. @@ -237,10 +258,17 @@ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN else: quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN + elif args.use_fp8_rowwise: + quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN + # this will be overwritten if specified in the hf config. 
+ quant_config.clamp_val = [-1200.0, 1200.0] if args.int8_kv_cache: quant_config.kv_cache_quant_algo = QuantAlgo.INT8 + if args.fp8_kv_cache: + quant_config.kv_cache_quant_algo = QuantAlgo.FP8 + if args.weight_only_precision == 'int4_gptq': quant_config.group_size = args.group_size quant_config.has_zero_point = True @@ -250,6 +278,21 @@ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: return quant_config +def update_quant_config_from_hf(quant_config, hf_config) -> QuantConfig: + hf_config_dict = hf_config.to_dict() + if hf_config_dict.get('quantization_config'): + # update the quant_algo, and clamp_val. + if hf_config_dict['quantization_config'].get( + 'quant_method') == 'fbgemm_fp8': + logger.info( + "Load quantization configs from huggingface model_config.") + quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN + activation_scale_ub = hf_config_dict['quantization_config'].get( + 'activation_scale_ub', 1200.0) + quant_config.clamp_val = [-activation_scale_ub, activation_scale_ub] + return quant_config + + def convert_and_save_meta(args, rank): mapping = Mapping(world_size=args.tp_size * args.pp_size, tp_size=args.tp_size, @@ -257,11 +300,10 @@ def convert_and_save_meta(args, rank): moe_tp_size=args.moe_tp_size, moe_ep_size=args.moe_ep_size, rank=rank) - assert not args_to_quant_config(args).quant_mode.has_any_quant(), \ - "quantization from meta checkpoint or empty model were never supported" llama = LLaMAForCausalLM.from_meta_ckpt( args.meta_ckpt_dir, args.dtype, + quant_config=args_to_quant_config(args), mapping=mapping, use_parallel_embedding=args.use_parallel_embedding, embedding_sharding_dim=args.embedding_sharding_dim) @@ -274,7 +316,10 @@ def args_to_build_options(args): 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, 'disable_weight_only_quant_plugin': - args.disable_weight_only_quant_plugin + args.disable_weight_only_quant_plugin, + 'remove_duplicated_kv_heads': args.remove_duplicated_kv_heads, + 'quant_ckpt_path': args.quant_ckpt_path, + 'load_model_on_cpu': args.load_model_on_cpu, } @@ -288,6 +333,8 @@ def from_cli_args(args): 'num_attention_heads': args.n_head, 'hidden_size': args.n_embd, 'intermediate_size': args.inter_size, + 'ffn_dim_multiplier': args.ffn_dim_multiplier, + 'multiple_of': args.multiple_of, 'num_key_value_heads': n_kv_head, 'vocab_size': args.vocab_size, 'position_embedding_type': 'rope_gpt_neox', @@ -315,7 +362,6 @@ def from_cli_args(args): def convert_and_save_hf(args): model_dir = args.model_dir - load_model_on_cpu = args.load_model_on_cpu load_by_shard = args.load_by_shard world_size = args.tp_size * args.pp_size # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. @@ -326,15 +372,21 @@ def convert_and_save_hf(args): quant_config = args_to_quant_config(args) + try: + hf_config = AutoConfig.from_pretrained(model_dir, + trust_remote_code=True) + quant_config = update_quant_config_from_hf(quant_config, hf_config) + except: + # llava_llama needs its own defined config. 
+ logger.warning("AutoConfig cannot load the huggingface config.") + if args.smoothquant is not None or args.int8_kv_cache: assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported" - mapping = Mapping( - world_size=world_size, - rank=-1, #intentinoally make -1 to avoid mistake - tp_size=args.tp_size, - pp_size=args.pp_size, - moe_tp_size=args.moe_tp_size, - moe_ep_size=args.moe_ep_size) + mapping = Mapping(world_size=world_size, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size) # TODO: support moe quantization for tp + ep LLaMAForCausalLM.quantize( args.model_dir, @@ -348,15 +400,6 @@ def convert_and_save_hf(args): else: # When not loading by shard, preload one complete model and then slice per rank weights from this # this saves the disk reloading time - - hf_model = None - if "vila" in model_dir or "llava" in model_dir: - hf_model = load_hf_llama(model_dir, load_model_on_cpu) - elif not (args.load_by_shard or - (has_safetensors(model_dir) - and not quant_config.quant_mode.has_any_quant())): - hf_model = load_hf_llama(model_dir, load_model_on_cpu) - def convert_and_save_rank(args, rank): mapping = Mapping(world_size=world_size, rank=rank, @@ -365,7 +408,7 @@ def convert_and_save_rank(args, rank): moe_tp_size=args.moe_tp_size, moe_ep_size=args.moe_ep_size) llama = LLaMAForCausalLM.from_hugging_face( - model_dir if hf_model is None else hf_model, + model_dir, args.dtype, mapping=mapping, quant_config=quant_config, @@ -379,23 +422,6 @@ def convert_and_save_rank(args, rank): release_gc() -def convert_and_save_gptq(args, rank): - mapping = Mapping(world_size=args.tp_size * args.pp_size, - tp_size=args.tp_size, - rank=rank, - pp_size=args.pp_size) - config = LLaMAConfig.from_hugging_face( - args.model_dir, - args.dtype, - mapping=mapping, - quant_config=args_to_quant_config(args), - ) - model = LLaMAForCausalLM(config) - weights = load_weights_from_gptq(args.quant_ckpt_path, config) - model.load(weights) - model.save_checkpoint(args.output_dir, rank == 0) - - def execute(workers, func, args): if workers == 1: for rank, f in enumerate(func): @@ -418,6 +444,7 @@ def execute(workers, func, args): def main(): print(tensorrt_llm.__version__) args = parse_arguments() + logger.set_level(args.log_level) world_size = args.tp_size * args.pp_size if (args.moe_tp_size == -1 and args.moe_ep_size == -1): @@ -443,13 +470,12 @@ def main(): elif args.meta_ckpt_dir is not None: assert args.model_dir is None, "Shall not specify both meta checkpoint dir and hugging face dir" execute(args.workers, [convert_and_save_meta] * world_size, args) - elif args.weight_only_precision == 'int4_gptq': - assert args.model_dir is not None - assert args.quant_ckpt_path is not None - execute(args.workers, [convert_and_save_gptq] * world_size, args) - else: # all other non-gptq paths from hf model + else: # all other paths from hf model assert args.model_dir is not None - assert args.quant_ckpt_path is None, "only gptq weights only needs this option" + assert ( + args.quant_ckpt_path is not None + and args.weight_only_precision == 'int4_gptq' + ) or args.quant_ckpt_path is None, "only gptq weights only needs this option" convert_and_save_hf(args) tok = time.time() diff --git a/tests/test_cli.py b/tests/test_cli.py index a144568..2883e70 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -117,7 +117,9 @@ def test_triton_profile(self, mocker, monkeypatch): monkeypatch.setattr("sys.argv", 
test_args) args = parse_args() args.func(args) - mock_run.assert_called_once_with(["genai-perf", "-m", "add_sub"], check=True) + mock_run.assert_called_once_with( + ["genai-perf", "profile", "-m", "add_sub"], check=True + ) @pytest.mark.parametrize("model", ["mock_llm"]) def test_triton_metrics(self, model):
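
---

Note (illustrative, not part of the patch): after this change both vendored converter scripts (`gpt2/convert_checkpoint.py` and `llama/convert_checkpoint.py`) funnel per-rank work through the same `execute()` helper — one conversion callable per rank, run serially when `--workers 1` or fanned out over a thread pool otherwise, with any rank failure surfaced as a single assertion after all tracebacks are printed. The sketch below reconstructs that pattern for reference only; the toy `convert_rank` function and the dict-based `args` are stand-ins for the real `convert_and_save_rank(args, rank)` and the parsed argparse namespace.

```python
# Illustrative sketch of the shared per-rank conversion pattern (not part of
# this patch). `convert_rank` and the dict-based `args` are stand-ins for the
# real convert_and_save_rank(args, rank) and the parsed CLI arguments.
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed


def execute(workers, funcs, args):
    # Serial path: run each rank's conversion in order.
    if workers == 1:
        for rank, f in enumerate(funcs):
            f(args, rank)
        return
    # Parallel path: one future per rank; collect exceptions instead of
    # failing fast so every rank's traceback is printed before aborting.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(f, args, rank) for rank, f in enumerate(funcs)]
        exceptions = []
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                exceptions.append(e)
    assert not exceptions, "Checkpoint conversion failed, please check error log."


def convert_rank(args, rank):
    # Stand-in for convert_and_save_rank(): build the Mapping for `rank`,
    # convert that rank's weights, and save them under args["output_dir"].
    print(f"[rank {rank}] converting with dtype={args['dtype']}")


if __name__ == "__main__":
    world_size = 2  # tp_size * pp_size
    execute(workers=2, funcs=[convert_rank] * world_size,
            args={"dtype": "float16", "output_dir": "/tmp/ckpt"})
```

Running the sketch with `workers=2` submits one future per rank and joins them as they complete; this mirrors how `convert_and_save_hf` and `convert_and_save_nemo` in the patched scripts reuse a single code path regardless of the `--workers` value.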