diff --git a/README.md b/README.md index cc96537..0514105 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@ and running the CLI from within the latest corresponding `tritonserver` container image, which should have all necessary system dependencies installed. For vLLM and TRT-LLM, you can use their respective images: -- `nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3` -- `nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3` +- `nvcr.io/nvidia/tritonserver:24.09-vllm-python-py3` +- `nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3` If you decide to run the CLI on the host or in a custom image, please see this list of [additional dependencies](#additional-dependencies-for-custom-environments) @@ -38,13 +38,14 @@ matrix below: | Triton CLI Version | TRT-LLM Version | Triton Container Tag | |:------------------:|:---------------:|:--------------------:| +| 0.1.0 | v0.13.0 | 24.09 | | 0.0.11 | v0.12.0 | 24.08 | | 0.0.10 | v0.11.0 | 24.07 | -| 0.0.9 | v0.10.0 | 24.06 | -| 0.0.8 | v0.9.0 | 24.05 | -| 0.0.7 | v0.9.0 | 24.04 | -| 0.0.6 | v0.8.0 | 24.02, 24.03 | -| 0.0.5 | v0.7.1 | 24.01 | +| 0.0.9 | v0.10.0 | 24.06 | +| 0.0.8 | v0.9.0 | 24.05 | +| 0.0.7 | v0.9.0 | 24.04 | +| 0.0.6 | v0.8.0 | 24.02, 24.03 | +| 0.0.5 | v0.7.1 | 24.01 | ### Install from GitHub @@ -58,7 +59,7 @@ It is also possible to install from a specific branch name, a commit hash or a tag name. For example to install `triton_cli` with a specific tag: ```bash -GIT_REF="0.0.11" +GIT_REF="0.1.0" pip install git+https://github.com/triton-inference-server/triton_cli.git@${GIT_REF} ``` @@ -93,7 +94,7 @@ triton -h triton import -m gpt2 # Start server pointing at the default model repository -triton start --image nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3 +triton start --image nvcr.io/nvidia/tritonserver:24.09-vllm-python-py3 # Infer with CLI triton infer -m gpt2 --prompt "machine learning is" @@ -119,26 +120,50 @@ minutes. > in Huggingface through either `huggingface-cli login` or setting the `HF_TOKEN` > environment variable. +### Model Sources -### Serving a vLLM Model + -vLLM models will be downloaded at runtime when starting the server if not found -locally in the HuggingFace cache. No offline engine building step is required, -but you can pre-download the model in advance to avoid downloading at server -startup time. +The `triton import` command helps automate the process of creating a model repository +to serve with Triton Inference Server. When preparing models, a `--source` is required +to point at the location containing a model/weights. This argument is overloaded to support +a few types of locations: +- HuggingFace (`--source hf:`) +- Local Filesystem (`--source local:`) + +#### Model Source Aliases -The following models have currently been tested for vLLM through the CLI: + + +For convenience, the Triton CLI supports short aliases for a handful +of models which will automatically set the correct `--source` for you. 
+A full list of aliases can be found in `KNOWN_MODEL_SOURCES` within `parser.py`;
+some examples are listed below:
 - `gpt2`
 - `opt125m`
 - `mistral-7b`
-- `falcon-7b`
-- `llama-2-7b`
 - `llama-2-7b-chat`
-- `llama-3-8b`
 - `llama-3-8b-instruct`
-- `llama-3.1-8b`
 - `llama-3.1-8b-instruct`
 
+For example, this command will import Llama 3.1 8B Instruct from HuggingFace:
+```bash
+triton import -m llama-3.1-8b-instruct
+
+# Equivalent command without alias:
+# triton import --model llama-3.1-8b-instruct --source "hf:meta-llama/Llama-3.1-8B-Instruct"
+```
+
+For full control and flexibility, you can always manually specify the `--source`.
+
+### Serving a vLLM Model
+
+vLLM models will be downloaded at runtime when starting the server if not found
+locally in the HuggingFace cache. No offline engine building step is required,
+but you can pre-download the model in advance to avoid downloading at server
+startup time.
+
+The following models are supported by vLLM: https://docs.vllm.ai/en/latest/models/supported_models.html
 
 #### Example
 
@@ -149,10 +174,10 @@ docker run -ti \
   --shm-size=1g --ulimit memlock=-1 \
   -v ${HOME}/models:/root/models \
   -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
-  nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3
+  nvcr.io/nvidia/tritonserver:24.09-vllm-python-py3
 
 # Install the Triton CLI
-pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.11
+pip install git+https://github.com/triton-inference-server/triton_cli.git@0.1.0
 
 # Authenticate with huggingface for restricted models like Llama-2 and Llama-3
 huggingface-cli login
@@ -189,15 +214,7 @@ triton profile -m llama-3-8b-instruct --backend vllm
 > see [here](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#instance-groups).
 
 The following models are currently supported for automating TRT-LLM
-engine builds through the CLI:
-- `gpt2`
-- `opt125m`
-- `llama-2-7b`
-- `llama-2-7b-chat`
-- `llama-3-8b`
-- `llama-3-8b-instruct`
-- `llama-3.1-8b`
-- `llama-3.1-8b-instruct`
+engine builds through the CLI: https://nvidia.github.io/TensorRT-LLM/llm-api-examples/index.html#supported-models
 
 > [!NOTE]
 > 1. Building a TRT-LLM engine for Llama-2-7B, Llama-3-8B, or Llama-3.1-8B
@@ -222,10 +239,10 @@ docker run -ti \
   -v /tmp:/tmp \
   -v ${HOME}/models:/root/models \
   -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
-  nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3
+  nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3
 
 # Install the Triton CLI
-pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.11
+pip install git+https://github.com/triton-inference-server/triton_cli.git@0.1.0
 
 # Authenticate with huggingface for restricted models like Llama-2 and Llama-3
 huggingface-cli login
@@ -282,5 +299,3 @@ and may not be as optimized as possible for your system or use case.
 - Triton CLI currently uses the TRT-LLM dependencies installed in its environment
   to build TRT-LLM engines, so you must take care to match the build-time and
   run-time versions of TRT-LLM.
-- Triton CLI currently does not support launching the server as a background
-  process.
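The README changes above document the new `local:` source prefix but only show a HuggingFace-based example. The snippet below is a minimal usage sketch for importing a model from the local filesystem; the model name `my-model` and the directory `/root/models/my-model` are hypothetical placeholders, not values taken from this change.

```bash
# Import a model from a local directory using the new 'local:' prefix.
# The model name and path here are hypothetical examples.
triton import --model my-model --source "local:/root/models/my-model"

# An unprefixed path is also treated as a local path by the fallback
# logic added to repository.py, so this form should behave the same.
triton import --model my-model --source "/root/models/my-model"
```

Either form relies on the local-path handling added to `repository.py` in this diff; if the path does not exist, the CLI raises a `TritonCLIException`.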
diff --git a/pyproject.toml b/pyproject.toml index 080aa17..dfacafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,10 +47,14 @@ keywords = [] requires-python = ">=3.10,<4" # TODO: Add [gpu] set of dependencies for trtllm once it's available on pypi dependencies = [ - "grpcio>=1.65.5", + # Client deps - generally versioned together + "grpcio>=1.66.1", + # Use explicit client version matching genai-perf version for tagged release + "tritonclient[all] == 2.50", + "genai-perf @ git+https://github.com/triton-inference-server/perf_analyzer.git@r24.09#subdirectory=genai-perf", + # Misc deps "directory-tree == 0.0.4", # may remove in future "docker == 6.1.3", - "genai-perf @ git+https://github.com/triton-inference-server/perf_analyzer.git@r24.08#subdirectory=genai-perf", # TODO: rely on tritonclient to pull in protobuf and numpy dependencies? "numpy >=1.21,<2", "protobuf>=3.7.0", @@ -58,8 +62,6 @@ dependencies = [ "psutil >= 5.9.5", # may remove later "rich == 13.5.2", # TODO: Test on cpu-only machine if [cuda] dependency is an issue, - # Use explicit client version matching genai-perf version for tagged release - "tritonclient[all] == 2.49", "huggingface-hub >= 0.19.4", # Testing "pytest >= 8.1.1", # may remove later diff --git a/src/triton_cli/__init__.py b/src/triton_cli/__init__.py index 1adf4f9..dae273d 100644 --- a/src/triton_cli/__init__.py +++ b/src/triton_cli/__init__.py @@ -24,4 +24,4 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -__version__ = "0.0.11" +__version__ = "0.1.0" diff --git a/src/triton_cli/docker/Dockerfile b/src/triton_cli/docker/Dockerfile index 7bc4c87..4701f21 100644 --- a/src/triton_cli/docker/Dockerfile +++ b/src/triton_cli/docker/Dockerfile @@ -1,9 +1,9 @@ # TRT-LLM image contains engine building and runtime dependencies -FROM nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3 +FROM nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3 # Setup vLLM Triton backend RUN mkdir -p /opt/tritonserver/backends/vllm && \ - git clone -b r24.08 https://github.com/triton-inference-server/vllm_backend.git /tmp/vllm_backend && \ + git clone -b r24.09 https://github.com/triton-inference-server/vllm_backend.git /tmp/vllm_backend && \ cp -r /tmp/vllm_backend/src/* /opt/tritonserver/backends/vllm && \ rm -r /tmp/vllm_backend diff --git a/src/triton_cli/parser.py b/src/triton_cli/parser.py index 077a1ce..9a6b085 100755 --- a/src/triton_cli/parser.py +++ b/src/triton_cli/parser.py @@ -228,8 +228,7 @@ def parse_args_repo(parser): "--source", type=str, required=False, - help="Local model path or model identifier. Use prefix 'hf:' to specify a HuggingFace model ID. " - "NOTE: HuggingFace model support is currently limited to Transformer models through the vLLM backend.", + help="Local model path or model identifier. 
Use prefix 'hf:' to specify a HuggingFace model ID, or 'local:' prefix to specify a file path to a model.", ) repo_remove = parser.add_parser("remove", help="Remove model from model repository") diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py index 7e475ac..66ecbdf 100644 --- a/src/triton_cli/repository.py +++ b/src/triton_cli/repository.py @@ -30,7 +30,6 @@ import logging import subprocess from pathlib import Path -from rich.console import Console from directory_tree import display_tree @@ -41,7 +40,6 @@ TritonCLIException, ) from triton_cli.trt_llm.engine_config_parser import parse_and_substitute -from triton_cli.trt_llm.builder import TRTLLMBuilder from huggingface_hub import snapshot_download from huggingface_hub import utils as hf_utils @@ -66,6 +64,7 @@ SOURCE_PREFIX_HUGGINGFACE = "hf:" SOURCE_PREFIX_NGC = "ngc:" +SOURCE_PREFIX_LOCAL = "local:" TRT_TEMPLATES_PATH = Path(__file__).parent / "templates" / "trt_llm" @@ -75,35 +74,6 @@ HF_TOKEN_PATH = Path.home() / ".cache" / "huggingface" / "token" -# TODO: Improve this flow and reduce hard-coded model check locations -SUPPORTED_TRT_LLM_BUILDERS = { - "facebook/opt-125m": { - "hf_allow_patterns": ["*.bin", "*.json", "*.txt"], - }, - "meta-llama/Llama-2-7b-hf": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - }, - "meta-llama/Llama-2-7b-chat-hf": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - }, - "meta-llama/Meta-Llama-3-8B": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - }, - "meta-llama/Meta-Llama-3-8B-Instruct": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - }, - "meta-llama/Meta-Llama-3.1-8B": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - }, - "meta-llama/Meta-Llama-3.1-8B-Instruct": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - }, - "gpt2": { - "hf_allow_patterns": ["*.safetensors", "*.json"], - "hf_ignore_patterns": ["onnx/*"], - }, -} - # NOTE: Thin wrapper around NGC CLI is a WAR for now. # TODO: Move out to generic files/interface for remote model stores @@ -206,11 +176,19 @@ def add( backend = "tensorrtllm" # Local model path else: - logger.debug("No supported prefix detected, assuming local path") + if source.startswith(SOURCE_PREFIX_LOCAL): + logger.debug("Local prefix detected, parsing local file path") + else: + logger.info( + "No supported --source prefix detected, assuming local path" + ) + source_type = "local" model_path = Path(source) if not model_path.exists(): - raise TritonCLIException(f"{model_path} does not exist") + raise TritonCLIException( + f"Local file path '{model_path}' provided by --source does not exist" + ) model_dir, version_dir = self.__create_model_repository(name, version, backend) @@ -349,23 +327,15 @@ def __generate_ngc_model(self, name: str, source: str): str(self.repo), name, engines_path, engines_path, "auto", dry_run=False ) - def __generate_trtllm_model(self, name, huggingface_id): - builder_info = SUPPORTED_TRT_LLM_BUILDERS.get(huggingface_id) - if not builder_info: - raise TritonCLIException( - f"Building a TRT LLM engine for {huggingface_id} is not currently supported." - ) - + def __generate_trtllm_model(self, name: str, huggingface_id: str): engines_path = ENGINE_DEST_PATH + "/" + name - hf_download_path = ENGINE_DEST_PATH + "/" + name + "/hf_download" - engines = [engine for engine in Path(engines_path).glob("*.engine")] if engines: logger.warning( f"Found existing engine(s) at {engines_path}, skipping build." 
) else: - self.__build_trtllm_engine(huggingface_id, hf_download_path, engines_path) + self.__build_trtllm_engine(huggingface_id, engines_path) # NOTE: In every case, the TRT LLM template should be filled in with values. # If the model exists, the CLI will raise an exception when creating the model repo. @@ -375,30 +345,26 @@ def __generate_trtllm_model(self, name, huggingface_id): triton_model_dir=str(self.repo), bls_model_name=name, engine_dir=engines_path, - token_dir=hf_download_path, + token_dir=engines_path, token_type="auto", dry_run=False, ) - def __build_trtllm_engine(self, huggingface_id, hf_download_path, engines_path): - builder_info = SUPPORTED_TRT_LLM_BUILDERS.get(huggingface_id) - hf_allow_patterns = builder_info["hf_allow_patterns"] - hf_ignore_patterns = builder_info.get("hf_ignore_patterns", None) - self.__download_hf_model( - huggingface_id, - hf_download_path, - allow_patterns=hf_allow_patterns, - ignore_patterns=hf_ignore_patterns, - ) - - builder = TRTLLMBuilder( - huggingface_id=huggingface_id, - hf_download_path=hf_download_path, - engine_output_path=engines_path, - ) - console = Console() - with console.status(f"Building TRT-LLM engine for {huggingface_id}..."): - builder.build() + def __build_trtllm_engine(self, huggingface_id: str, engines_path: Path): + from tensorrt_llm import LLM, BuildConfig + + # NOTE: Given config.json, can read from 'build_config' section and from_dict + config = BuildConfig() + # TODO: Expose more build args to user + # TODO: Discuss LLM API BuildConfig defaults + # NOTE: Using some defaults from trtllm-build because LLM API defaults are too low + config.max_input_len = 1024 + config.max_seq_len = 8192 + config.max_batch_size = 256 + + engine = LLM(huggingface_id, build_config=config) + # TODO: Investigate if LLM is internally saving a copy to a temp dir + engine.save(str(engines_path)) def __create_model_repository( self, name: str, version: int = 1, backend: str = None diff --git a/src/triton_cli/trt_llm/builder.py b/src/triton_cli/trt_llm/builder.py deleted file mode 100644 index 80d2482..0000000 --- a/src/triton_cli/trt_llm/builder.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import logging -import subprocess -from pathlib import Path - -from triton_cli.common import LOGGER_NAME - -logger = logging.getLogger(LOGGER_NAME) - -CHECKPOINT_MODULE_MAP = { - "meta-llama/Llama-2-7b-hf": "llama", - "meta-llama/Llama-2-7b-chat-hf": "llama", - "meta-llama/Meta-Llama-3-8B": "llama", - "meta-llama/Meta-Llama-3-8B-Instruct": "llama", - "meta-llama/Meta-Llama-3.1-8B": "llama", - "meta-llama/Meta-Llama-3.1-8B-Instruct": "llama", - "facebook/opt-125m": "opt", - "gpt2": "gpt2", -} - - -class TRTLLMBuilder: - def __init__(self, huggingface_id, hf_download_path, engine_output_path): - self.checkpoint_id = CHECKPOINT_MODULE_MAP[huggingface_id] - self.hf_download_path = hf_download_path - self.converted_weights_path = self.hf_download_path + "/converted_weights" - self.engine_output_path = engine_output_path - - # TODO: User should be able to specify a what parameters they want to use to build a - # TRT LLM engine. A input JSON should be suitable for this goal. - def build(self): - self._convert_checkpoint() - self._trtllm_build() - - def _convert_checkpoint(self): - if Path(self.converted_weights_path).exists(): - logger.info( - f"Converted weights path {self.converted_weights_path} already exists, skipping checkpoint conversion." 
- ) - return - - weight_conversion_args = [ - "--model_dir", - self.hf_download_path, - "--output_dir", - self.converted_weights_path, - "--dtype=float16", - ] - - # Need to specify gpt variant for gpt models - if self.checkpoint_id in ["gpt2"]: - weight_conversion_args += ["--gpt_variant", self.checkpoint_id] - - ckpt_script = ( - Path(__file__).resolve().parent - / "checkpoint_scripts" - / self.checkpoint_id - / "convert_checkpoint.py" - ) - cmd = ["python3", str(ckpt_script)] + weight_conversion_args - cmd_str = " ".join(cmd) - logger.info(f"Running '{cmd_str}'") - subprocess.run(cmd, check=True) - - def _trtllm_build(self): - # TODO: Move towards config-driven build args per-model - build_args = [ - f"--checkpoint_dir={self.converted_weights_path}", - f"--output_dir={self.engine_output_path}", - "--gpt_attention_plugin=float16", - "--gemm_plugin=float16", - ] - - cmd = ["trtllm-build"] + build_args - cmd_str = " ".join(cmd) - logger.info(f"Running '{cmd_str}'") - subprocess.run(cmd, check=True) diff --git a/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py b/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py deleted file mode 100644 index 2038a61..0000000 --- a/src/triton_cli/trt_llm/checkpoint_scripts/gpt2/convert_checkpoint.py +++ /dev/null @@ -1,325 +0,0 @@ -import argparse -import os -import shutil -import time -import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path - -import tensorrt_llm -from tensorrt_llm._utils import release_gc -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import GPTForCausalLM -from tensorrt_llm.models.gpt.convert import (UnpackedNemoCheckpointDir, - copy_tokenizer_files, load_hf_gpt, - unpack_nemo_ckpt, - update_tokenizer_paths) -from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization import QuantAlgo - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument('--nemo_ckpt_path', type=str, default=None) - parser.add_argument( - '--gpt_variant', - default=None, - choices=[ - None, 'gpt2', 'santacoder', 'starcoder', 'starcoder2', 'persimmon', - 'kosmos-2' - ], - help= - "By default the script will try to infer the gpt_variant from model_dir. " - "Or users may overwrite gpt_variant by explicitly passing the variant.") - parser.add_argument('--tp_size', - type=int, - default=1, - help='N-way tensor parallelism size') - parser.add_argument('--pp_size', - type=int, - default=1, - help='N-way pipeline parallelism size') - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float32', 'bfloat16', 'float16']) - parser.add_argument("--load_model_on_cpu", action="store_true") - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=0, - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). 
' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) - parser.add_argument( - '--use_embedding_sharing', - action="store_true", - default=False, - help= - 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' - 'Note: the flag might not take effect when the criteria are not met.') - - parser.add_argument( - '--use_weight_only', - default=False, - action="store_true", - help='Quantize weights for the various GEMMs to INT4/INT8.' - 'See --weight_only_precision to set the precision') - parser.add_argument( - '--weight_only_precision', - const='int8', - type=str, - nargs='?', - default='int8', - choices=['int8', 'int4'], - help= - 'Define the precision for the weights when using weight-only quantization.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - - parser.add_argument( - '--calib_dataset', - type=str, - default='lambada', - help= - "The huggingface dataset name or the local directory of the dataset for calibration." - ) - parser.add_argument( - '--int8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' - ) - parser.add_argument( - '--per_channel', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' - 'per_channel instead uses a different static scaling factor for each channel. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--per_token', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' - 'per_token chooses at run time, and for each token, a custom scaling factor. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - "--smoothquant", - "-sq", - type=float, - default=None, - help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" - " to Smoothquant the model, and output int8 weights." - " A good first try is 0.5. Must be in [0, 1]") - parser.add_argument("--dataset_cache_dir", - type=str, - default=None, - help="cache dir to load the hugging face dataset") - parser.add_argument('--output_dir', - type=str, - default='tllm_checkpoint', - help='The path to save the TensorRT-LLM checkpoint') - parser.add_argument( - '--workers', - type=int, - default=1, - help='The number of workers for converting checkpoint in parallel') - parser.add_argument('--log_level', type=str, default='info') - parser.add_argument( - '--nemo_rename_key', - type=str, - nargs='+', - default=[], - help= - "Change a layer name when loading a NeMo checkpoint. 
Should follow :" - ) - - args = parser.parse_args() - - tensorrt_llm.logger.set_level(args.log_level) - return args - - -def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: - '''return config dict with quantization info based on the command line args - ''' - quant_config = QuantConfig() - if args.use_weight_only: - if args.weight_only_precision == 'int8': - quant_config.quant_algo = QuantAlgo.W8A16 - elif args.weight_only_precision == 'int4': - quant_config.quant_algo = QuantAlgo.W4A16 - elif args.smoothquant: - quant_config.smoothquant_val = args.smoothquant - if args.per_channel: - if args.per_token: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN - else: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN - else: - if args.per_token: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN - else: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN - - if args.int8_kv_cache: - quant_config.kv_cache_quant_algo = QuantAlgo.INT8 - - return quant_config - - -def convert_and_save_hf(args): - model_dir = args.model_dir - load_model_on_cpu = args.load_model_on_cpu - world_size = args.tp_size * args.pp_size - - override_fields = { - 'use_parallel_embedding': args.use_parallel_embedding, - 'embedding_sharding_dim': args.embedding_sharding_dim, - 'share_embedding_table': args.use_embedding_sharing, - } - - quant_config = args_to_quant_config(args) - - if args.smoothquant is not None or args.int8_kv_cache: - mapping = Mapping(world_size=world_size, - tp_size=args.tp_size, - pp_size=args.pp_size) - GPTForCausalLM.quantize( - args.model_dir, - args.output_dir, - dtype=args.dtype, - mapping=mapping, - quant_config=quant_config, - device='cpu' if args.load_model_on_cpu else 'cuda', - calib_dataset=args.calib_dataset, - **override_fields) - else: - hf_model = load_hf_gpt(model_dir, load_model_on_cpu) - - def convert_and_save_rank(args, rank): - mapping = Mapping(world_size=world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size) - model = GPTForCausalLM.from_hugging_face(hf_model, - args.dtype, - mapping=mapping, - quant_config=quant_config, - gpt_variant=args.gpt_variant, - **override_fields) - model.save_checkpoint(args.output_dir, save_config=(rank == 0)) - del model - - execute(args.workers, [convert_and_save_rank] * world_size, args) - release_gc() - - -def execute(workers, func, args): - if workers == 1: - for rank, f in enumerate(func): - f(args, rank) - else: - with ThreadPoolExecutor(max_workers=workers) as p: - futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." 
- - -def convert_and_save_nemo(args): - world_size = args.tp_size * args.pp_size - quant_config = args_to_quant_config(args) - - override_fields = { - 'use_parallel_embedding': True, - 'embedding_sharding_dim': 0, - 'share_embedding_table': args.use_embedding_sharing, - } - - nemo_ckpt_dir = os.path.join(args.output_dir, "unpacked") - nemo_ckpt_dir = unpack_nemo_ckpt(args.nemo_ckpt_path, nemo_ckpt_dir) - - def convert_and_save_rank(args, rank): - mapping = Mapping(world_size=world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size) - model = GPTForCausalLM.from_nemo( - nemo_ckpt_dir, - dtype=args.dtype, - mapping=mapping, - quant_config=quant_config, - load_model_on_cpu=args.load_model_on_cpu, - nemo_rename_key=args.nemo_rename_key, - **override_fields) - model.save_checkpoint(args.output_dir, save_config=(rank == 0)) - del model - - execute(args.workers, [convert_and_save_rank] * world_size, args) - release_gc() - - # Copy tokenizer files - unpacked_checkpoints_dir = UnpackedNemoCheckpointDir( - nemo_ckpt_dir, load_checkpoints_to_cpu=args.load_model_on_cpu) - nemo_model_config = unpacked_checkpoints_dir.model_config - tokenizer_config = update_tokenizer_paths( - nemo_model_config["tokenizer"], - unpacked_checkpoints_dir.get_all_tokenizer_file_paths()) - copy_tokenizer_files(tokenizer_config, Path(args.output_dir)) - - # Clean up unpacked nemo checkpoint - shutil.rmtree(nemo_ckpt_dir) - - -def main(): - # TODO(qijun): Currently, the convert script depends on a torch op: - # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, - # which is included in tensorrt_llm Python package. Otherwise, the convert - # script does not need to import tensorrt_llm. Will remove it after reimplementing - # the op with PyTorch. - print(tensorrt_llm.__version__) - args = parse_arguments() - args.tp_size * args.pp_size - - tik = time.time() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - if args.model_dir is not None: - convert_and_save_hf(args) - elif args.nemo_ckpt_path is not None: - convert_and_save_nemo(args) - else: - raise NotImplementedError("No source model path specified!") - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Total time of converting checkpoints: {t}') - - -if __name__ == '__main__': - main() diff --git a/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py b/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py deleted file mode 100644 index dabfe9f..0000000 --- a/src/triton_cli/trt_llm/checkpoint_scripts/llama/convert_checkpoint.py +++ /dev/null @@ -1,487 +0,0 @@ -import argparse -import json -import os -import time -import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed - -from transformers import AutoConfig - -import tensorrt_llm -from tensorrt_llm._utils import release_gc -from tensorrt_llm.layers import MoeConfig -from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import LLaMAForCausalLM -from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization import QuantAlgo - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument('--meta_ckpt_dir', type=str, default=None) - - parser.add_argument('--tp_size', - type=int, - default=1, - help='N-way tensor parallelism size') - parser.add_argument('--pp_size', - type=int, - default=1, - help='N-way pipeline 
parallelism size') - parser.add_argument( - '--moe_tp_size', - type=int, - default=-1, - help= - 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' - ) - parser.add_argument( - '--moe_ep_size', - type=int, - default=-1, - help= - 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' - ) - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float32', 'bfloat16', 'float16']) - parser.add_argument('--vocab_size', type=int, default=32000) - parser.add_argument('--n_positions', type=int, default=2048) - parser.add_argument('--n_layer', type=int, default=32) - parser.add_argument('--n_head', type=int, default=32) - parser.add_argument('--n_kv_head', type=int, default=None) - parser.add_argument('--n_embd', type=int, default=4096) - parser.add_argument('--inter_size', type=int, default=11008) - parser.add_argument('--multiple_of', type=int, default=None) - parser.add_argument('--ffn_dim_multiplier', type=float, default=None) - parser.add_argument('--rms_norm_eps', type=float, default=1e-06) - - parser.add_argument( - '--use_weight_only', - default=False, - action="store_true", - help='Quantize weights for the various GEMMs to INT4/INT8.' - 'See --weight_only_precision to set the precision') - parser.add_argument( - '--disable_weight_only_quant_plugin', - default=False, - action="store_true", - help= - 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - parser.add_argument( - '--weight_only_precision', - const='int8', - type=str, - nargs='?', - default='int8', - choices=['int8', 'int4', 'int4_gptq'], - help= - 'Define the precision for the weights when using weight-only quantization.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - parser.add_argument( - '--calib_dataset', - type=str, - default='ccdv/cnn_dailymail', - help= - "The huggingface dataset name or the local directory of the dataset for calibration." - ) - parser.add_argument( - "--smoothquant", - "-sq", - type=float, - default=None, - help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" - " to Smoothquant the model, and output int8 weights." - " A good first try is 0.5. Must be in [0, 1]") - parser.add_argument( - '--per_channel', - action="store_true", - default=False, - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' - 'per_channel instead uses a different static scaling factor for each channel. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--per_token', - action="store_true", - default=False, - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' - 'per_token chooses at run time, and for each token, a custom scaling factor. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--int8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' - ) - parser.add_argument( - '--fp8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. 
fp8_kv_cache chooses int8 quantization for KV' - ) - parser.add_argument( - '--quant_ckpt_path', - type=str, - default=None, - help='Path of a quantized model checkpoint in .safetensors format') - parser.add_argument("--use_fp8_rowwise", - action="store_true", - default=False, - help="Enable Fp8 per-token per-channel quantization") - - parser.add_argument( - '--per_group', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor to scale weights in the int4 range. ' - 'per_group chooses at run time, and for each group, a custom scaling factor. ' - 'The flag is built for GPTQ/AWQ quantization.') - - parser.add_argument('--load_by_shard', - action='store_true', - help='Load a pretrained model shard-by-shard.') - parser.add_argument('--hidden_act', type=str, default='silu') - - parser.add_argument('--rotary_base', type=float, default=10000.0) - - parser.add_argument('--group_size', - type=int, - default=128, - help='Group size used in GPTQ quantization.' - ) # AWQ is only supported by quantize.py script - - parser.add_argument("--load_model_on_cpu", action="store_true") - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=0, - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) - parser.add_argument( - '--use_embedding_sharing', - action="store_true", - default=False, - help= - 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' - 'Note: the flag might not take effect when the criteria are not met.') - parser.add_argument('--output_dir', - type=str, - default='tllm_checkpoint', - help='The path to save the TensorRT-LLM checkpoint') - parser.add_argument( - '--workers', - type=int, - default=1, - help='The number of workers for converting checkpoint in parallel') - parser.add_argument( - '--moe_num_experts', - default=0, - type=int, - help='Specify the number of experts to use for MOE layers') - parser.add_argument( - '--moe_top_k', - default=0, - type=int, - help= - 'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set' - ) - parser.add_argument( - '--moe_renorm_mode', - default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, - type=int, - help= - 'Controls renormalization after gate logits. Check layers/moe.py for accepted values', - ) - parser.add_argument( - '--save_config_only', - action="store_true", - default=False, - help= - 'Only save the model config w/o read and converting weights, be careful, this is for debug only' - ) - parser.add_argument( - '--remove_duplicated_kv_heads', - action="store_true", - default=False, - help= - 'Only used to remove the duplicated kv heads of llama-3.1 405B HF model.' - ) - parser.add_argument('--log_level', type=str, default='info') - - args = parser.parse_args() - # changing the default to be consistent as the cli help said. 
- if args.moe_num_experts and args.moe_top_k == 0: - args.moe_top_k = 1 - return args - - -def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: - '''return config dict with quantization info based on the command line args - ''' - quant_config = QuantConfig() - if args.use_weight_only: - if args.weight_only_precision == 'int8': - quant_config.quant_algo = QuantAlgo.W8A16 - elif args.weight_only_precision == 'int4': - quant_config.quant_algo = QuantAlgo.W4A16 - elif args.smoothquant: - quant_config.smoothquant_val = args.smoothquant - if args.per_channel: - if args.per_token: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN - else: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN - else: - if args.per_token: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN - else: - quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN - elif args.use_fp8_rowwise: - quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN - # this will be overwritten if specified in the hf config. - quant_config.clamp_val = [-1200.0, 1200.0] - - if args.int8_kv_cache: - quant_config.kv_cache_quant_algo = QuantAlgo.INT8 - - if args.fp8_kv_cache: - quant_config.kv_cache_quant_algo = QuantAlgo.FP8 - - if args.weight_only_precision == 'int4_gptq': - quant_config.group_size = args.group_size - quant_config.has_zero_point = True - quant_config.pre_quant_scale = False - quant_config.quant_algo = QuantAlgo.W4A16_GPTQ - - return quant_config - - -def update_quant_config_from_hf(quant_config, hf_config) -> QuantConfig: - hf_config_dict = hf_config.to_dict() - if hf_config_dict.get('quantization_config'): - # update the quant_algo, and clamp_val. - if hf_config_dict['quantization_config'].get( - 'quant_method') == 'fbgemm_fp8': - logger.info( - "Load quantization configs from huggingface model_config.") - quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN - activation_scale_ub = hf_config_dict['quantization_config'].get( - 'activation_scale_ub', 1200.0) - quant_config.clamp_val = [-activation_scale_ub, activation_scale_ub] - return quant_config - - -def convert_and_save_meta(args, rank): - mapping = Mapping(world_size=args.tp_size * args.pp_size, - tp_size=args.tp_size, - pp_size=args.pp_size, - moe_tp_size=args.moe_tp_size, - moe_ep_size=args.moe_ep_size, - rank=rank) - llama = LLaMAForCausalLM.from_meta_ckpt( - args.meta_ckpt_dir, - args.dtype, - quant_config=args_to_quant_config(args), - mapping=mapping, - use_parallel_embedding=args.use_parallel_embedding, - embedding_sharding_dim=args.embedding_sharding_dim) - llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) - - -def args_to_build_options(args): - return { - 'use_parallel_embedding': args.use_parallel_embedding, - 'embedding_sharding_dim': args.embedding_sharding_dim, - 'share_embedding_table': args.use_embedding_sharing, - 'disable_weight_only_quant_plugin': - args.disable_weight_only_quant_plugin, - 'remove_duplicated_kv_heads': args.remove_duplicated_kv_heads, - 'quant_ckpt_path': args.quant_ckpt_path, - 'load_model_on_cpu': args.load_model_on_cpu, - } - - -def from_cli_args(args): - n_kv_head = args.n_kv_head if args.n_kv_head is not None else args.n_head - config = { - 'architecture': "LlamaForCausalLM", - 'dtype': args.dtype, - 'logits_dtype': 'float32', - 'num_hidden_layers': args.n_layer, - 'num_attention_heads': args.n_head, - 'hidden_size': args.n_embd, - 'intermediate_size': args.inter_size, - 'ffn_dim_multiplier': args.ffn_dim_multiplier, - 
'multiple_of': args.multiple_of, - 'num_key_value_heads': n_kv_head, - 'vocab_size': args.vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': args.n_positions, - 'hidden_act': args.hidden_act, - 'rotary_base': args.rotary_base, - 'norm_epsilon': args.rms_norm_eps, - 'moe': { - 'num_experts': args.moe_num_experts, - 'top_k': args.moe_top_k, - 'normalization_mode': args.moe_renorm_mode, - }, - 'mapping': { - 'world_size': args.tp_size * args.pp_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - 'moe_tp_size': args.moe_tp_size, - 'moe_ep_size': args.moe_ep_size, - }, - 'quantization': args_to_quant_config(args).to_dict() - } - config.update(args_to_build_options(args)) - return config - - -def convert_and_save_hf(args): - model_dir = args.model_dir - load_by_shard = args.load_by_shard - world_size = args.tp_size * args.pp_size - # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. - # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, - # before the refactor is done. - override_fields = {} - override_fields.update(args_to_build_options(args)) - - quant_config = args_to_quant_config(args) - - try: - hf_config = AutoConfig.from_pretrained(model_dir, - trust_remote_code=True) - quant_config = update_quant_config_from_hf(quant_config, hf_config) - except: - # llava_llama needs its own defined config. - logger.warning("AutoConfig cannot load the huggingface config.") - - if args.smoothquant is not None or args.int8_kv_cache: - assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported" - mapping = Mapping(world_size=world_size, - tp_size=args.tp_size, - pp_size=args.pp_size, - moe_tp_size=args.moe_tp_size, - moe_ep_size=args.moe_ep_size) - # TODO: support moe quantization for tp + ep - LLaMAForCausalLM.quantize( - args.model_dir, - args.output_dir, - dtype=args.dtype, - mapping=mapping, - quant_config=quant_config, - device='cpu' if args.load_model_on_cpu else 'cuda', - calib_dataset=args.calib_dataset, - **override_fields) - else: - # When not loading by shard, preload one complete model and then slice per rank weights from this - # this saves the disk reloading time - def convert_and_save_rank(args, rank): - mapping = Mapping(world_size=world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size, - moe_tp_size=args.moe_tp_size, - moe_ep_size=args.moe_ep_size) - llama = LLaMAForCausalLM.from_hugging_face( - model_dir, - args.dtype, - mapping=mapping, - quant_config=quant_config, - load_by_shard=load_by_shard, - **override_fields, - ) - llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) - del llama - - execute(args.workers, [convert_and_save_rank] * world_size, args) - release_gc() - - -def execute(workers, func, args): - if workers == 1: - for rank, f in enumerate(func): - f(args, rank) - else: - with ThreadPoolExecutor(max_workers=workers) as p: - futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." 
- - -def main(): - print(tensorrt_llm.__version__) - args = parse_arguments() - logger.set_level(args.log_level) - - world_size = args.tp_size * args.pp_size - if (args.moe_tp_size == -1 and args.moe_ep_size == -1): - # moe default to tp-only - args.moe_tp_size = args.tp_size - args.moe_ep_size = 1 - elif (args.moe_tp_size == -1): - args.moe_tp_size = args.tp_size // args.moe_ep_size - elif (args.moe_ep_size == -1): - args.moe_ep_size = args.tp_size // args.moe_tp_size - assert (args.moe_tp_size * args.moe_ep_size == args.tp_size - ), "moe_tp_size * moe_ep_size must equal to tp_size" - tik = time.time() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - if (args.model_dir is None - and args.meta_ckpt_dir is None): # generate fake config.json - config = from_cli_args(args) - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - elif args.meta_ckpt_dir is not None: - assert args.model_dir is None, "Shall not specify both meta checkpoint dir and hugging face dir" - execute(args.workers, [convert_and_save_meta] * world_size, args) - else: # all other paths from hf model - assert args.model_dir is not None - assert ( - args.quant_ckpt_path is not None - and args.weight_only_precision == 'int4_gptq' - ) or args.quant_ckpt_path is None, "only gptq weights only needs this option" - convert_and_save_hf(args) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Total time of converting checkpoints: {t}') - - -if __name__ == '__main__': - main() diff --git a/src/triton_cli/trt_llm/checkpoint_scripts/opt/convert_checkpoint.py b/src/triton_cli/trt_llm/checkpoint_scripts/opt/convert_checkpoint.py deleted file mode 100644 index 5d075ab..0000000 --- a/src/triton_cli/trt_llm/checkpoint_scripts/opt/convert_checkpoint.py +++ /dev/null @@ -1,411 +0,0 @@ -import argparse -import json -import os -import time -import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed - -import safetensors -import torch -from transformers import AutoModelForCausalLM, Blip2ForConditionalGeneration - -import tensorrt_llm -from tensorrt_llm._utils import pad_vocab_size -from tensorrt_llm.quantization import QuantAlgo - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument( - '--model_type', - type=str, - default='opt', - choices=['opt', 'blip2'], - help= - 'Multimodal type when this script is used for multimodal conversion.') - parser.add_argument('--tp_size', - type=int, - default=1, - help='N-way tensor parallelism size') - parser.add_argument('--pp_size', - type=int, - default=1, - help='N-way pipeline parallelism size') - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float32', 'bfloat16', 'float16']) - parser.add_argument( - '--use_weight_only', - default=False, - action="store_true", - help='Quantize weights for the various GEMMs to INT4/INT8.' - 'See --weight_only_precision to set the precision') - parser.add_argument( - '--weight_only_precision', - const='int8', - type=str, - nargs='?', - default='int8', - choices=['int8', 'int4'], - help= - 'Define the precision for the weights when using weight-only quantization.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. 
By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=0, - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) - parser.add_argument( - '--use_embedding_sharing', - action="store_true", - default=False, - help= - 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' - 'Note: the flag might not take effect when the criteria are not met.') - parser.add_argument('--output_dir', - type=str, - default='tllm_checkpoint', - help='The path to save the TensorRT-LLM checkpoint') - parser.add_argument( - '--workers', - type=int, - default=1, - help='The number of workers for converting checkpoint in parallel') - args = parser.parse_args() - - return args - - -def split(v, tp_size, idx, dim=0): - if tp_size == 1: - return v - if len(v.shape) == 1: - return torch.chunk(v, tp_size)[idx].contiguous() - else: - return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() - - -def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank): - """ - Splits the QKV matrix according to tensor parallelism - """ - v = v.reshape(3, n_hidden, n_hidden) - split_v = split(v, tensor_parallel, rank, dim=1) - split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden) - return split_v.contiguous() - - -def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): - """ - Splits the QKV bias according to tensor parallelism - """ - v = v.reshape(3, n_hidden) - split_v = split(v, tensor_parallel, rank, dim=1) - split_v = split_v.reshape(3 * (n_hidden // tensor_parallel)) - return split_v.contiguous() - - -def split_matrix_tp(v, tensor_parallel, rank, dim): - return split(v, tensor_parallel, rank, dim=dim) - - -def split_embedding( - param: torch.Tensor, - tp_size: int, - tp_rank: int, - use_parallel_embedding: bool = False, - sharding_dim: int = 0, -) -> torch.Tensor: - if param is None: - return None - if not use_parallel_embedding: - return param - - vocab_size, hidden_size = param.size() - if sharding_dim == 0: - if vocab_size % tp_size != 0: - vocab_size_padded = pad_vocab_size(vocab_size, tp_size) - pad_width = vocab_size_padded - vocab_size - param = torch.nn.functional.pad(param, (0, 0, 0, pad_width), - value=0) - else: - assert hidden_size % tp_size == 0 - return split(param, tp_size, tp_rank, dim=sharding_dim) - - -def get_weight(config, prefix, dtype): - return config[prefix + '.weight'].to(dtype).detach() - - -def get_bias(config, prefix, dtype): - return config[prefix + '.bias'].to(dtype).detach() - - -def get_weight_and_bias(config, prefix, dtype): - return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype) - - -def get_tllm_linear_weight(weight, - prefix, - bias=None, - use_weight_only=False, - plugin_weight_only_quant_type=torch.int8): - results = {} - if use_weight_only: - v = weight.t().contiguous() - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - v, plugin_weight_only_quant_type) - results[prefix + 'weight'] = processed_torch_weights - results[prefix + 'per_channel_scale'] = torch_weight_scales - else: - results[prefix + 'weight'] = weight.contiguous() - - if bias is not None: - results[prefix + 'bias'] = bias - - return results - - -def convert_hf_opt(hf_model, - rank=0, 
- tensor_parallel=1, - dtype='float32', - use_parallel_embedding=False, - sharding_dim=0, - share_embedding_table=False, - use_weight_only=False, - plugin_weight_only_quant_type=torch.int8): - - weights = {} - tik = time.time() - - model_params = dict(hf_model.named_parameters()) - dtype = getattr(torch, dtype) - do_layer_norm_before = hf_model.config.do_layer_norm_before - num_attention_heads = hf_model.config.num_attention_heads - hidden_size = hf_model.config.hidden_size - - for l in range(hf_model.config.num_hidden_layers): - prefix = f'model.decoder.layers.{l}.' - tllm_prex = f'transformer.layers.{l}.' - - q_weight, q_bias = get_weight_and_bias(model_params, - prefix + 'self_attn.q_proj', - dtype) - k_weight, k_bias = get_weight_and_bias(model_params, - prefix + 'self_attn.k_proj', - dtype) - v_weight, v_bias = get_weight_and_bias(model_params, - prefix + 'self_attn.v_proj', - dtype) - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - split_v = split_qkv_tp(qkv_weight, num_attention_heads, hidden_size, - tensor_parallel, rank) - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) - bias = split_qkv_bias_tp(qkv_bias, num_attention_heads, hidden_size, - tensor_parallel, rank) - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'attention.qkv.', bias, - use_weight_only, - plugin_weight_only_quant_type)) - - attn_dense_weight, attn_dense_bias = get_weight_and_bias( - model_params, prefix + 'self_attn.out_proj', dtype) - split_v = split_matrix_tp(attn_dense_weight, - tensor_parallel, - rank, - dim=1) - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.', - attn_dense_bias, use_weight_only, - plugin_weight_only_quant_type)) - - mlp_fc_weight, mlp_fc_bias = get_weight_and_bias( - model_params, prefix + 'fc1', dtype) - split_v = split_matrix_tp(mlp_fc_weight, tensor_parallel, rank, dim=0) - bias = split_matrix_tp(mlp_fc_bias, tensor_parallel, rank, dim=0) - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', bias, - use_weight_only, - plugin_weight_only_quant_type)) - - mlp_proj_weight, mlp_proj_bias = get_weight_and_bias( - model_params, prefix + 'fc2', dtype) - split_v = split_matrix_tp(mlp_proj_weight, tensor_parallel, rank, dim=1) - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', - mlp_proj_bias, use_weight_only, - plugin_weight_only_quant_type)) - - # Layer norms do not use tensor parallelism - input_ln_weight, input_ln_bias = get_weight_and_bias( - model_params, prefix + 'self_attn_layer_norm', dtype) - weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight - weights[tllm_prex + 'input_layernorm.bias'] = input_ln_bias - - post_ln_weight, post_ln_bias = get_weight_and_bias( - model_params, prefix + 'final_layer_norm', dtype) - weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight - weights[tllm_prex + 'post_layernorm.bias'] = post_ln_bias - - embed_w = get_weight(model_params, 'model.decoder.embed_tokens', dtype) - if 'model.decoder.project_in.weight' in model_params.keys(): - project_in = get_weight(model_params, 'model.decoder.project_in', dtype) - project_out = get_weight(model_params, 'model.decoder.project_out', - dtype) - lm_head_w = torch.matmul(embed_w.float(), project_out.float()).to(dtype) - embed_w = torch.matmul(embed_w.float(), - project_in.t().float()).to(dtype) - else: - lm_head_w = embed_w.clone() - - if not share_embedding_table: - weights['lm_head.weight'] = split_matrix_tp(lm_head_w, - tensor_parallel, - rank, - dim=0) - - 
weights['transformer.vocab_embedding.weight'] = split_embedding( - embed_w, - tp_size=tensor_parallel, - tp_rank=rank, - use_parallel_embedding=use_parallel_embedding, - sharding_dim=sharding_dim) - - embed_p = get_weight(model_params, 'model.decoder.embed_positions', dtype) - weights['transformer.position_embedding.weight'] = split_embedding( - embed_p[2:, :], - tp_size=tensor_parallel, - tp_rank=rank, - use_parallel_embedding=use_parallel_embedding, - sharding_dim=sharding_dim) - - if do_layer_norm_before: - ln_f_w, ln_f_b = get_weight_and_bias(model_params, - 'model.decoder.final_layer_norm', - dtype) - weights['transformer.ln_f.weight'] = ln_f_w - weights['transformer.ln_f.bias'] = ln_f_b - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') - return weights - - -if __name__ == '__main__': - # TODO(qijun): Currently, the convert script depends on a torch op: - # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, - # which is included in tensorrt_llm Python package. Otherwise, the convert - # script does not need to import tensorrt_llm. Will remove it after reimplementing - # the op with PyTorch. - print(tensorrt_llm.__version__) - args = parse_arguments() - world_size = args.tp_size * args.pp_size - assert args.pp_size == 1, "Pipeline parallelism is not supported." - - tik = time.time() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - if args.model_type == 'opt': - hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, - torch_dtype="auto") - elif args.model_type == 'blip2': - hf_model = Blip2ForConditionalGeneration.from_pretrained( - args.model_dir, torch_dtype="auto").language_model - - hf_config = hf_model.config - if hf_config.hidden_size != hf_config.word_embed_proj_dim: - args.use_embedding_sharing = False - args.use_parallel_embedding = False - - quant_algo = None - plugin_weight_only_quant_type = None - if args.use_weight_only and args.weight_only_precision == 'int8': - plugin_weight_only_quant_type = torch.int8 - quant_algo = QuantAlgo.W8A16 - elif args.use_weight_only and args.weight_only_precision == 'int4': - plugin_weight_only_quant_type = torch.quint4x2 - quant_algo = QuantAlgo.W4A16 - - config = { - 'architecture': hf_config.architectures[0], - 'dtype': args.dtype, - 'num_hidden_layers': hf_config.num_hidden_layers, - 'num_attention_heads': hf_config.num_attention_heads, - 'hidden_size': hf_config.hidden_size, - 'vocab_size': hf_config.vocab_size, - 'position_embedding_type': 'learned_absolute', - 'max_position_embeddings': hf_config.max_position_embeddings, - 'hidden_act': hf_config.activation_function, - 'quantization': { - 'quant_algo': quant_algo - }, - 'mapping': { - 'world_size': world_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - }, - 'use_parallel_embedding': args.use_parallel_embedding, - 'embedding_sharding_dim': args.embedding_sharding_dim, - 'share_embedding_table': args.use_embedding_sharing, - 'do_layer_norm_before': hf_config.do_layer_norm_before, - } - - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - - def covert_and_save(rank): - weights = convert_hf_opt( - hf_model, - rank, - world_size, - dtype=args.dtype, - use_weight_only=args.use_weight_only, - plugin_weight_only_quant_type=plugin_weight_only_quant_type, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing) - 
safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) - - if args.workers == 1: - for rank in range(world_size): - covert_and_save(rank) - else: - with ThreadPoolExecutor(max_workers=args.workers) as p: - futures = [ - p.submit(covert_and_save, rank) for rank in range(world_size) - ] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Total time of converting checkpoints: {t}') diff --git a/tests/utils.py b/tests/utils.py index 6f42251..88bd5ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -73,6 +73,15 @@ def _infer(model, prompt=None, protocol=None): run(args) def _profile(model, backend): + # FIXME: WAR for genai-perf bug in 24.09, remove in 24.10 + import genai_perf + + if genai_perf.__version__ == "0.0.6dev": + print( + "[WARNING] Skipping call to 'triton profile' due to known issue in genai-perf" + ) + return + args = ["profile", "-m", model, "--backend", backend] # NOTE: With default parameters, genai-perf may take upwards of 1m30s or 2m to run, # so limit the genai-perf run with --request-count to reduce time for testing purposes.