diff --git a/.github/workflows/python-package.yaml b/.github/workflows/python-package.yaml
index a1b2a21..bc474d6 100644
--- a/.github/workflows/python-package.yaml
+++ b/.github/workflows/python-package.yaml
@@ -41,7 +41,7 @@ jobs:
       fail-fast: false
       matrix:
         os: ["ubuntu-22.04"]
-        python-version: ["3.8", "3.10"]
+        python-version: ["3.10"]

     steps:
     - uses: actions/checkout@v3
diff --git a/.gitignore b/.gitignore
index 179a6ec..c9a18b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@ __pycache__/

 # Distribution / packaging
 .Python
+artifacts/
 build/
 develop-eggs/
 dist/
@@ -17,12 +18,20 @@ eggs/
 lib64/
 parts/
 sdist/
+tests/checkpoints/
+tests/output_dir/
+tests/output_model_repository/
+tests/plots/
+tests/reports/
+tests/results/
+tmp/
 var/
 wheels/
 share/python-wheels/
 *.egg-info/
 .installed.cfg
 *.egg
+llm_inputs.json
 MANIFEST

 # PyInstaller
diff --git a/README.md b/README.md
index 2ecfb9e..9d2d62d 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ Server.
 ## Pre-requisites

 When using Triton and related tools on your host (outside of a Triton container
-image) there are a number of additional dependencies that may be required for
+image), there are a number of additional dependencies that may be required for
 various workflows. Most system dependency issues can be resolved by installing
 and running the CLI from within the latest corresponding `tritonserver`
 container image, which should have all necessary system dependencies installed.
@@ -162,7 +162,7 @@ triton start
 # Interact with model
 triton infer -m llama-3-8b-instruct --prompt "machine learning is"

-# Profile model with Perf Analyzer
+# Profile model with GenAI-Perf
 triton profile -m llama-3-8b-instruct --backend vllm
 ```

@@ -232,7 +232,7 @@ triton start
 # Interact with model
 triton infer -m llama-3-8b-instruct --prompt "machine learning is"

-# Profile model with Perf Analyzer
+# Profile model with GenAI-Perf
 triton profile -m llama-3-8b-instruct --backend tensorrtllm
 ```
 ## Additional Dependencies for Custom Environments
@@ -269,10 +269,6 @@ sudo apt install libopenmpi-dev
 ```

 ## Known Limitations
-- Triton CLI's `profile` command currently only supports TRT-LLM and vLLM models.
-- Triton CLI's `profile` command will be migrating to use
-[genai-perf](https://github.com/triton-inference-server/client/tree/main/src/c++/perf_analyzer/genai-perf)
-as the backbone for LLM profiling soon.
 - Models and configurations generated by Triton CLI are focused on ease-of-use,
 and may not be as optimized as possible for your system or use case.
 - Triton CLI currently uses the TRT-LLM dependencies installed in its environment
diff --git a/pyproject.toml b/pyproject.toml
index b56931e..4e2d257 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,8 +37,6 @@ classifiers = [
     "Topic :: Scientific/Engineering",
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Operating System :: Unix",
@@ -46,11 +44,12 @@
 authors = []
 maintainers = []
 keywords = []
-requires-python = ">=3.8,<4"
+requires-python = ">=3.10,<4"
 # TODO: Add [gpu] set of dependencies for trtllm once it's available on pypi
 dependencies = [
     "directory-tree == 0.0.4", # may remove in future
     "docker == 6.1.3",
+    "genai-perf @ git+https://github.com/triton-inference-server/client.git@r24.04#subdirectory=src/c++/perf_analyzer/genai-perf",
     # TODO: rely on tritonclient to pull in protobuf and numpy dependencies?
     "numpy >= 1.21",
     "protobuf>=3.7.0",
@@ -58,11 +57,12 @@ dependencies = [
     "psutil >= 5.9.5", # may remove later
     "rich == 13.5.2",
     # TODO: Test on cpu-only machine if [cuda] dependency is an issue
-    "tritonclient[all] >= 2.38",
+    "tritonclient[all] >= 2.45",
     "huggingface-hub >= 0.19.4",
     # Testing
     "pytest >= 8.1.1", # may remove later
     "pytest-timeout", # may remove later
+    "pytest-mock >= 3.13.0", # may remove later
 ]

 # CLI Entrypoint
@@ -81,6 +81,9 @@ build-backend = "hatchling.build"
 [tool.hatch.version]
 path = "src/triton_cli/__init__.py"

+[tool.hatch.metadata]
+allow-direct-references = true
+
 # Pre-commit hook tool configs
 [tool.codespell]
 # note: pre-commit passes explicit lists of files here, which this skip file list doesn't override -
diff --git a/src/triton_cli/__init__.py b/src/triton_cli/__init__.py
index fd0db64..8fe212f 100644
--- a/src/triton_cli/__init__.py
+++ b/src/triton_cli/__init__.py
@@ -24,4 +24,4 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-__version__ = "0.0.7"
+__version__ = "0.0.8dev"
diff --git a/src/triton_cli/parser.py b/src/triton_cli/parser.py
index 6339092..045c030 100755
--- a/src/triton_cli/parser.py
+++ b/src/triton_cli/parser.py
@@ -26,6 +26,8 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import json
+import subprocess
+import sys
 import time
 import logging
 import argparse
@@ -41,9 +43,9 @@
 )
 from triton_cli.client.client import InferenceServerException, TritonClient
 from triton_cli.metrics import MetricsClient
+from triton_cli.profile import add_unknown_args_to_args, build_command
 from triton_cli.repository import ModelRepository
 from triton_cli.server.server_factory import TritonServerFactory
-from triton_cli.profiler import Profiler


 logger = logging.getLogger(LOGGER_NAME)
@@ -159,41 +161,6 @@ def add_model_args(subcommands):
     )


-def add_profile_args(subcommands):
-    for subcommand in subcommands:
-        subcommand.add_argument(
-            "-b",
-            "--batch-size",
-            type=int,
-            default=1,
-            required=False,
-            help="The batch size / concurrency to benchmark. (Default: 1)",
-        )
-        subcommand.add_argument(
-            "--input-length",
-            type=int,
-            default=128,
-            required=False,
-            help="The input length (tokens) to use for benchmarking LLMs. (Default: 128)",
-        )
-        subcommand.add_argument(
-            "--output-length",
-            type=int,
-            default=128,
-            required=False,
-            help="The output length (tokens) to use for benchmarking LLMs. (Default: 128)",
-        )
-        # TODO: Revisit terminology here. Online/offline vs streaming, etc.
-        subcommand.add_argument(
-            "--profile-mode",
-            type=str,
-            choices=["online", "offline"],
-            default="online",
-            required=False,
-            help="Profiling mode: offline means one full response will be generated, online means response will be streaming tokens as they are generated.",
-        )
-
-
 def add_client_args(subcommands):
     # Add protocol/url/port to all client-based subcommands
     for subcommand in subcommands:
@@ -396,49 +363,17 @@ def handle_infer(args: argparse.Namespace):
 # Profile
 # ================================================
 def parse_args_profile(parser):
-    profile = parser.add_parser(
-        "profile", help="Profile LLM models using Perf Analyzer"
-    )
+    profile = parser.add_parser("profile", help="Profile models", add_help=False)
     profile.set_defaults(func=handle_profile)
-    add_model_args([profile])
-    add_profile_args([profile])
-    add_backend_args([profile])
-    add_client_args([profile])
+    profile.add_argument(
+        "--help", action="store_true", help="Show help message and exit"
+    )


 def handle_profile(args: argparse.Namespace):
-    client = TritonClient(url=args.url, port=args.port, protocol=args.protocol)
-    profile_model(args, client)
-
-
-# TODO: Move to utils? <-- Delete?
-def profile_model(args: argparse.Namespace, client: TritonClient):
-    if args.protocol != "grpc":
-        raise Exception("Profiler only supports 'grpc' protocol at this time.")
-
-    if not args.port:
-        args.port = 8001 if args.protocol == "grpc" else 8000
-
-    # TODO: Consider python(BLS)/ensemble case for the model
-    # receiving requests in the case of TRT-LLM. For now, TRT-LLM
-    # should be manually specified.
-    backend = args.backend
-    if not args.backend:
-        # Profiler needs to know TRT-LLM vs vLLM to form correct payload
-        backend = client.get_model_backend(args.model)
-
-    logger.info(f"Running Perf Analyzer profiler on '{args.model}'...")
-    Profiler.profile(
-        model=args.model,
-        backend=backend,
-        batch_size=args.batch_size,
-        url=f"{args.url}:{args.port}",
-        input_length=args.input_length,
-        output_length=args.output_length,
-        # Should be "online" for IFB / streaming, and "offline" for non-streaming
-        offline=(args.profile_mode == "offline"),
-        verbose=args.verbose,
-    )
+    cmd = build_command(args, "genai-perf")
+    logger.info(f"Running: '{' '.join(cmd)}'")
+    subprocess.run(cmd, check=True)


 # ================================================
@@ -502,5 +437,12 @@ def parse_args(argv=None):
     parse_args_profile(subcommands)
     parse_args_utils(subcommands)
     add_verbose_args([parser])
-    args = parser.parse_args(argv)
+
+    argv_ = argv if argv is not None else sys.argv[1:]
+    # Add special argparse handling for passthrough to genai-perf CLI
+    if argv_[0] == "profile":
+        args, unknown_args = parser.parse_known_args(argv_)
+        args = add_unknown_args_to_args(args, unknown_args)
+    else:
+        args = parser.parse_args(argv_)
     return args
diff --git a/src/triton_cli/profile.py b/src/triton_cli/profile.py
new file mode 100755
index 0000000..e5799c8
--- /dev/null
+++ b/src/triton_cli/profile.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import argparse
+from typing import List
+
+
+# ================================================
+# Helper functions
+# ================================================
+def build_command(args: argparse.Namespace, executable: str):
+    skip_args = ["func"]
+    cmd = [executable]
+    for arg, value in vars(args).items():
+        if arg in skip_args:
+            pass
+        elif value is False:
+            pass
+        elif value is True:
+            if len(arg) == 1:
+                cmd += [f"-{arg}"]
+            else:
+                cmd += [f"--{arg}"]
+        # [DLIS-6656] - Remove backend renaming.
+        # This allows "tensorrtllm" to be used as the backend for consistency.
+        # Once GenAI-Perf releases 24.05, "tensorrtllm" as the backend value
+        # will be supported by default.
+        elif arg == "backend" and value in ["tensorrtllm", "trtllm"]:
+            cmd += ["--backend", "trtllm"]
+        else:
+            if len(arg) == 1:
+                cmd += [f"-{arg}", f"{value}"]
+            else:
+                cmd += [f"--{arg}", f"{value}"]
+    return cmd
+
+
+def add_unknown_args_to_args(args: argparse.Namespace, unknown_args: List[str]):
+    """Add unknown args to args list"""
+    unknown_args_dict = turn_unknown_args_into_dict(unknown_args)
+    for key, value in unknown_args_dict.items():
+        setattr(args, key, value)
+    return args
+
+
+def turn_unknown_args_into_dict(unknown_args: List[str]):
+    """Convert list of unknown args to dictionary"""
+    it = iter(unknown_args)
+    unknown_args_dict = {}
+    try:
+        while True:
+            arg = next(it)
+            if arg.startswith(("-", "--")):
+                key = arg.lstrip("-")
+                # Peek to see if next item is a value or another flag
+                next_arg = next(it, None)
+                if next_arg and not next_arg.startswith(("-", "--")):
+                    unknown_args_dict[key] = next_arg
+                else:
+                    unknown_args_dict[key] = True
+                    if next_arg:
+                        it = iter([next_arg] + list(it))
+            else:
+                raise ValueError(f"Argument does not start with a '-' or '--': {arg}")
+    except StopIteration:
+        pass
+    return unknown_args_dict
diff --git a/src/triton_cli/profiler.py b/src/triton_cli/profiler.py
deleted file mode 100644
index 8311591..0000000
--- a/src/triton_cli/profiler.py
+++ /dev/null
@@ -1,708 +0,0 @@
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-#  * Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#  * Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the distribution.
-#  * Neither the name of NVIDIA CORPORATION nor the names of its
-#    contributors may be used to endorse or promote products derived
-#    from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import json
-import logging
-import subprocess
-from dataclasses import dataclass
-from itertools import tee
-from pathlib import Path
-from typing import Optional
-
-from rich.progress import Progress
-
-import numpy as np
-
-from triton_cli.constants import LOGGER_NAME
-
-logger = logging.getLogger(LOGGER_NAME)
-
-
-INPUT_FILENAME = "generated_input_data.json"
-METRIC_FIELDS = {
-    # "max_first_token_latency": ("Max first token latency", "ms"),
-    # "min_first_token_latency": ("Min first token latency", "ms"),
-    "avg_first_token_latency": ("Avg first token latency", "ms"),
-    # "p50_first_token_latency": ("p50 first token latency", "ms"),
-    # "p90_first_token_latency": ("p90 first token latency", "ms"),
-    # "p95_first_token_latency": ("p95 first token latency", "ms"),
-    # "p99_first_token_latency": ("p99 first token latency", "ms"),
-    # "max_gen_latency": ("Max generation latency", "ms"),
-    # "min_gen_latency": ("Min generation latency", "ms"),
-    # "avg_gen_latency": ("Avg generation latency", "ms"),
-    # "p50_gen_latency": ("p50 generation latency", "ms"),
-    # "p90_gen_latency": ("p90 generation latency", "ms"),
-    # "p95_gen_latency": ("p95 generation latency", "ms"),
-    # "p99_gen_latency": ("p99 generation latency", "ms"),
-    # "avg_output_token_latency": ("Avg output token latency", "ms/output token"),
-    # "avg_total_t2t_latency": ("Avg total token-to-token latency", "ms"),
-    # "max_e2e_latency": ("Max end-to-end latency", "ms"),
-    # "min_e2e_latency": ("Min end-to-end latency", "ms"),
-    "avg_e2e_latency": ("Avg end-to-end latency", "ms"),
-    # "p50_e2e_latency": ("p50 end-to-end latency", "ms"),
-    # "p90_e2e_latency": ("p90 end-to-end latency", "ms"),
-    # "p95_e2e_latency": ("p95 end-to-end latency", "ms"),
-    # "p99_e2e_latency": ("p99 end-to-end latency", "ms"),
-    # "max_e2e_throughput": ("Max end-to-end throughput", "tokens/s"),
-    # "min_e2e_throughput": ("Min end-to-end throughput", "tokens/s"),
-    "avg_e2e_throughput": ("Avg end-to-end throughput", "tokens/s"),
-    # "p50_e2e_throughput": ("p50 end-to-end throughput", "tokens/s"),
-    # "p90_e2e_throughput": ("p90 end-to-end throughput", "tokens/s"),
-    # "p95_e2e_throughput": ("p95 end-to-end throughput", "tokens/s"),
-    # "p99_e2e_throughput": ("p99 end-to-end throughput", "tokens/s"),
-    # "max_gen_throughput": ("Max generation throughput", "output tokens/s"),
-    # "min_gen_throughput": ("Min generation throughput", "output tokens/s"),
-    "avg_gen_throughput": ("Avg generation throughput", "output tokens/s"),
-    # "p50_gen_throughput": ("p50 generation throughput", "output tokens/s"),
-    # "p90_gen_throughput": ("p90 generation throughput", "output tokens/s"),
-    # "p95_gen_throughput": ("p95 generation throughput", "output tokens/s"),
-    # "p99_gen_throughput": ("p99 generation throughput", "output tokens/s"),
-}
-
-
-# Built-in to itertools in Python 3.10+
-def pairwise(iterable):
-    # n=2 for pairs
-    a, b = tee(iterable, 2)
-    next(b, None)
-    return zip(a, b)
-
-
-@dataclass
-class ProfileResults:
-    prompt_size: int
-    # max_first_token_latency: Optional[float] = None
-    # min_first_token_latency: Optional[float] = None
-    avg_first_token_latency: Optional[float] = None
-    # p50_first_token_latency: Optional[float] = None
-    # p90_first_token_latency: Optional[float] = None
-    # p95_first_token_latency: Optional[float] = None
-    # p99_first_token_latency: Optional[float] = None
-    # max_gen_latency: Optional[float] = None
-    # min_gen_latency: Optional[float] = None
-    # avg_gen_latency: Optional[float] = None
-    # p50_gen_latency: Optional[float] = None
-    # p90_gen_latency: Optional[float] = None
-    # p95_gen_latency: Optional[float] = None
-    # p99_gen_latency: Optional[float] = None
-    # avg_output_token_latency: Optional[float] = None
-    # avg_total_t2t_latency: Optional[float] = None
-    # avg_periodic_t2t_latencies: Optional[list[float]] = None
-    # max_e2e_latency: Optional[float] = None
-    # min_e2e_latency: Optional[float] = None
-    avg_e2e_latency: Optional[float] = None
-    # p50_e2e_latency: Optional[float] = None
-    # p90_e2e_latency: Optional[float] = None
-    # p95_e2e_latency: Optional[float] = None
-    # p99_e2e_latency: Optional[float] = None
-    # max_e2e_throughput: Optional[float] = None
-    # min_e2e_throughput: Optional[float] = None
-    avg_e2e_throughput: Optional[float] = None
-    # p50_e2e_throughput: Optional[float] = None
-    # p90_e2e_throughput: Optional[float] = None
-    # p95_e2e_throughput: Optional[float] = None
-    # p99_e2e_throughput: Optional[float] = None
-    # max_gen_throughput: Optional[float] = None
-    # min_gen_throughput: Optional[float] = None
-    avg_gen_throughput: Optional[float] = None
-    # p50_gen_throughput: Optional[float] = None
-    # p90_gen_throughput: Optional[float] = None
-    # p95_gen_throughput: Optional[float] = None
-    # p99_gen_throughput: Optional[float] = None
-
-
-def load_json_data(filename):
-    with open(filename) as f:
-        return json.load(f)
-
-
-def save_json_data(data, filename):
-    with open(filename, "w") as f:
-        json.dump(data, f)
-
-
-def get_postfix(args, prompt_size):
-    """Generate postfix for profile export filename and plot.
-
-    e.g.
-    - trtllm-ensemble-prompt100-maxtokens256
-    - trtllm-ensemble-prompt100-periodic1_100_1-period32-maxtokens1024
-    """
-    stream_type = "offline" if args.offline else "online"
-    postfix = f"{args.backend}-{args.model}-{stream_type}-prompt{prompt_size}-"
-    if args.periodic_concurrency_range:
-        start, end, step = args.periodic_concurrency_range
-        postfix += f"periodic{start}_{end}_{step}-period{args.request_period}-"
-    postfix += f"maxtokens{args.max_tokens}"
-    return postfix
-
-
-def get_export_filename(args, prompt_size):
-    postfix = get_postfix(args, prompt_size)
-    filename = f"profile_export-{postfix}.json"
-    return filename
-
-
-def get_plot_filename(args, prompt_size):
-    postfix = get_postfix(args, prompt_size)
-    filename = f"inflight_batching_benchmark-{postfix}.png"
-    return filename
-
-
-def print_benchmark_summary(profile_results):
-    print("[ BENCHMARK SUMMARY ]")
-    for pr in profile_results:
-        for metric, (name, unit) in METRIC_FIELDS.items():
-            if getattr(pr, metric):
-                print(f" * {name}: {getattr(pr, metric):.4f} {unit}")
-        print("")
-
-
-def plot_results(latencies, filename="inflight_batching_benchmark.png"):
-    """Plot in-flight batching LLM bencharmark results."""
-    import matplotlib.pyplot as plt  # Lazy import
-
-    periods = np.arange(1, len(latencies) + 1)
-    fig, ax = plt.subplots()
-    ax.plot(periods, latencies)
-
-    # Set pyplot parameters
-    ax.grid(linestyle="--")
-    ax.set_xlabel("i-th Request Period", fontsize=12)
-    ax.set_ylabel("Avg Token-to-Token Latency (ms)", fontsize=12)
-    ax.set_title("In-Flight Batching Benchmark Summary", fontsize=14)
-    ax.set_ylim(bottom=0.0)
-
-    fig.savefig(filename, dpi=300)
-
-
-def add_latencies_to_bins(bins, pos, responses, request_period):
-    """Add token-to-token latencies into the corresponding bin.
-
-    Given the responses of a single request, calculate token-to-token
-    latency and add it into bin. Update the bin position to the next
-    for every request period.
-    """
-    for response_id, (prev_res, res) in enumerate(pairwise(responses)):
-        bins[pos].append(res - prev_res)
-        if (response_id + 1) % request_period == 0:
-            pos += 1
-
-
-def update_start_position(request_id, start_pos, initial_requests, step):
-    """Shift the start position of the bin.
-
-    Once we iterate through the entire requests, we shift
-    the start position. Then, we shift the start position for every
-    requests.
-    """
-    if (request_id + 1) >= initial_requests:
-        num_requests_after_start = request_id + 1 - initial_requests
-        if num_requests_after_start % step == 0:
-            start_pos += 1
-    return start_pos
-
-
-def collect_periodic_latencies(args, export_data):
-    """Split the entire benchmark results into segments with size
-    of request period and collect latencies for each segment.
-    """
-    start, end, step = args.periodic_concurrency_range
-
-    num_bins = args.max_tokens // args.request_period + (end - start) // step
-    if args.max_tokens % args.request_period != 0:
-        num_bins += 1  # extra bin
-
-    bins = [[] for _ in range(num_bins)]
-    bin_start_position = 0
-    requests = export_data["experiments"][0]["requests"]
-
-    for i, r in enumerate(requests):
-        add_latencies_to_bins(
-            bins=bins,
-            pos=bin_start_position,
-            responses=r["response_timestamps"],
-            request_period=args.request_period,
-        )
-        bin_start_position = update_start_position(
-            request_id=i,
-            start_pos=bin_start_position,
-            initial_requests=start,
-            step=step,
-        )
-    return bins
-
-
-def calculate_avg_periodic_latencies(args, profile_result, export_data):
-    """Calculate average token-to-token latency for each request period."""
-    bins = collect_periodic_latencies(args, export_data)
-
-    latencies = []
-    for bin in bins:
-        latencies.append(np.mean(bin) / 1_000_000)
-
-    profile_result.avg_periodic_t2t_latencies = latencies
-
-
-def collect_online_metrics(export_data, output_tokens):
-    # Example json demonstrating format:
-    # see client/src/c++/perf_analyzer/docs/examples/decoupled_output_file.json
-    first_token_latencies = []
-    generation_latencies = []
-    token_to_token_latencies = []
-    generation_throughputs = []
-    requests = export_data["experiments"][0]["requests"]
-
-    for r in requests:
-        init_request, responses = r["timestamp"], r["response_timestamps"]
-        first_token_latency = (responses[0] - init_request) / 1_000_000
-        first_token_latencies.append(first_token_latency)
-        if output_tokens > 1:
-            generation_latency_ms = (responses[-1] - responses[0]) / 1_000_000  # msec
-            generation_latency_s = (responses[-1] - responses[0]) / 1_000_000_000  # sec
-            generation_latencies.append(generation_latency_ms)
-            generation_throughputs.append(output_tokens / generation_latency_s)
-            for prev_res, res in pairwise(responses):
-                token_to_token_latencies.append((res - prev_res) / 1_000_000)
-    return (
-        first_token_latencies,
-        generation_latencies,
-        token_to_token_latencies,
-        generation_throughputs,
-    )
-
-
-# TODO: take concurrency > 1 into account for all metrics
-def calculate_online_metrics(args, profile_result, export_data):
-    """Calculate online metrics for more fine-grained performance information."""
-    latencies = collect_online_metrics(export_data, args.max_tokens)
-    (
-        first_token_latencies,
-        generation_latencies,
-        token_to_token_latencies,
-        generation_throughputs,
-    ) = latencies
-
-    profile_result.max_first_token_latency = max(first_token_latencies)
-    profile_result.min_first_token_latency = min(first_token_latencies)
-    profile_result.avg_first_token_latency = np.mean(first_token_latencies)
-    profile_result.p50_first_token_latency = np.percentile(
-        first_token_latencies, 50, method="lower"
-    )
-    profile_result.p90_first_token_latency = np.percentile(
-        first_token_latencies, 90, method="lower"
-    )
-    profile_result.p95_first_token_latency = np.percentile(
-        first_token_latencies, 95, method="lower"
-    )
-    profile_result.p99_first_token_latency = np.percentile(
-        first_token_latencies, 99, method="lower"
-    )
-
-    if args.max_tokens > 1:
-        profile_result.avg_total_t2t_latency = np.mean(token_to_token_latencies)
-
-        profile_result.max_gen_latency = max(generation_latencies)
-        profile_result.min_gen_latency = min(generation_latencies)
-        profile_result.avg_gen_latency = np.mean(generation_latencies)
-        profile_result.p50_gen_latency = np.percentile(
-            generation_latencies, 50, method="lower"
-        )
-        profile_result.p90_gen_latency = np.percentile(
-            generation_latencies, 90, method="lower"
-        )
-        profile_result.p95_gen_latency = np.percentile(
-            generation_latencies, 95, method="lower"
-        )
-        profile_result.p99_gen_latency = np.percentile(
-            generation_latencies, 99, method="lower"
-        )
-
-        token_latencies = [t / args.max_tokens for t in generation_latencies]
-        profile_result.avg_output_token_latency = np.mean(token_latencies)
-
-        profile_result.max_gen_throughput = max(generation_throughputs)
-        profile_result.min_gen_throughput = min(generation_throughputs)
-        # profile_result.avg_gen_throughput = np.mean(generation_throughputs)
-        avg_gen_throughput = (
-            args.concurrency * args.max_tokens / np.mean(generation_latencies)
-        )
-        profile_result.avg_gen_throughput = avg_gen_throughput * 1000  # msec to sec
-
-        profile_result.p50_gen_throughput = np.percentile(
-            generation_throughputs, 50, method="lower"
-        )
-        profile_result.p90_gen_throughput = np.percentile(
-            generation_throughputs, 90, method="lower"
-        )
-        profile_result.p95_gen_throughput = np.percentile(
-            generation_throughputs, 95, method="lower"
-        )
-        profile_result.p99_gen_throughput = np.percentile(
-            generation_throughputs, 99, method="lower"
-        )
-
-
-def collect_offline_metrics(export_data, sequence_len):
-    latencies = []
-    throughputs = []
-    requests = export_data["experiments"][0]["requests"]
-
-    for request in requests:
-        total_time = request["response_timestamps"][-1] - request["timestamp"]
-        time_s = total_time / 1_000_000_000  # sec
-        time_ms = total_time / 1_000_000  # msec
-        latencies.append(time_ms)
-        throughputs.append(sequence_len / time_s)
-    return throughputs, latencies
-
-
-def calculate_offline_metrics(args, profile_result, export_data):
-    """Calculate offline metrics that show end-to-end performance."""
-    throughputs, latencies = collect_offline_metrics(
-        export_data, sequence_len=profile_result.prompt_size + args.max_tokens
-    )
-
-    # profile_result.max_e2e_latency = max(latencies)
-    # profile_result.min_e2e_latency = min(latencies)
-    profile_result.avg_e2e_latency = np.mean(latencies)
-    # profile_result.p50_e2e_latency = np.percentile(latencies, 50, method="lower")
-    # profile_result.p90_e2e_latency = np.percentile(latencies, 90, method="lower")
-    # profile_result.p95_e2e_latency = np.percentile(latencies, 95, method="lower")
-    # profile_result.p99_e2e_latency = np.percentile(latencies, 99, method="lower")
-
-    # profile_result.max_e2e_throughput = max(throughputs)
-    # profile_result.min_e2e_throughput = min(throughputs)
-    profile_result.avg_e2e_throughput = np.mean(throughputs)
-    # profile_result.p50_e2e_throughput = np.percentile(throughputs, 50, method="lower")
-    # profile_result.p90_e2e_throughput = np.percentile(throughputs, 90, method="lower")
-    # profile_result.p95_e2e_throughput = np.percentile(throughputs, 95, method="lower")
-    # profile_result.p99_e2e_throughput = np.percentile(throughputs, 99, method="lower")
-
-
-def calculate_metrics(args, profile_result, export_data):
-    # Sanity check the number of responses received from backend
-    if args.ignore_eos:
-        requests = export_data["experiments"][0]["requests"]
-        for request in requests:
-            # Expect number of responses to match tokens only in online mode
-            # Offline mode will just receive 1-2 responses (full response + empty final).
-            if not args.offline:
-                if len(request["response_timestamps"]) == args.max_tokens:
-                    # Assume FINAL flag is returned with final token response
-                    pass
-                elif len(request["response_timestamps"]) == args.max_tokens + 1:
-                    # Assume FINAL flag was returned with an empty response after
-                    # the final token
-                    pass
-                else:
-                    raise ValueError(
-                        f"Expecting {args.max_tokens} tokens but received "
-                        f"{len(request['response_timestamps'])} tokens. "
-                        f"This could be due to an unsupported sequence length. "
-                        f"Please double check the input and output length."
-                    )
-
-    calculate_offline_metrics(args, profile_result, export_data)
-    if not args.offline:
-        calculate_online_metrics(args, profile_result, export_data)
-
-    if args.periodic_concurrency_range:
-        calculate_avg_periodic_latencies(args, profile_result, export_data)
-        plot_results(
-            latencies=profile_result.avg_periodic_t2t_latencies,
-            filename=get_plot_filename(args, profile_result.prompt_size),
-        )
-
-
-def summarize_profile_results(args, prompts):
-    results = []
-    for prompt in prompts:
-        prompt_size = len(prompt.split())
-        export_file = get_export_filename(args, prompt_size)
-        export_data = load_json_data(export_file)
-
-        profile_result = ProfileResults(prompt_size=prompt_size)
-        calculate_metrics(args, profile_result, export_data)
-        results.append(profile_result)
-
-    print_benchmark_summary(results)
-    if args.periodic_concurrency_range:
-        print(
-            "Saved in-flight batching benchmark plots "
-            "@ 'inflight_batching_benchmark-*.png'."
-        )
-
-
-def profile(args, export_file):
-    command = (
-        f"perf_analyzer -m {args.model} -i grpc --async --streaming "
-        f"-u {args.url} "
-        f"--input-data={INPUT_FILENAME} "
-        f"--profile-export-file={export_file} "
-    )
-    if args.backend == "tensorrtllm":
-        command += (
-            "--shape=text_input:1 "
-            "--shape=max_tokens:1 "
-            "--shape=bad_words:1 "
-            "--shape=stop_words:1 "
-        )
-    if args.periodic_concurrency_range:
-        start, end, step = args.periodic_concurrency_range
-        command += (
-            f"--periodic-concurrency-range={start}:{end}:{step} "
-            f"--request-period={args.request_period}"
-        )
-    else:
-        command += (
-            "--measurement-mode=count_windows "
-            "--measurement-request-count=10 "
-            "--stability-percentage=999 "
-            f"--concurrency-range={args.concurrency}"
-        )
-
-    if args.verbose:
-        logger.info(f"Running the following command: {command}")
-    proc = subprocess.run(args=[command], shell=True, capture_output=True)
-
-    if args.verbose:
-        logger.info(f"Perf Analyzer output:\n{proc.stdout.decode('utf-8')}")
-    if proc.returncode:
-        raise RuntimeError(
-            "Encountered the following error while running Perf Analyzer:\n"
-            f"{proc.stderr.decode('utf-8').rstrip()}"
-        )
-
-
-def prepare_export_file(args, prompt):
-    prompt_size = len(prompt.split())
-    filename = get_export_filename(args, prompt_size)
-
-    # If exists, clean up
-    export_file = Path(filename)
-    export_file.unlink(missing_ok=True)
-    return export_file
-
-
-def prepare_input_data(input_data, prompt):
-    """Insert the prompt to send into input JSON data."""
-    input_data["data"][0]["text_input"] = [prompt]
-    save_json_data(input_data, INPUT_FILENAME)
-
-
-def generate_prompts(args, input_data):
-    """Generate dummy prompts if not specified by input JSON file."""
-    prompt = input_data["data"][0]["text_input"][0]
-
-    if not prompt:  # Generate dummy prompt
-        assert args.prompt_size_range, "Must specify --prompt-size-range."
-        start, end, step = args.prompt_size_range
-        return [" ".join(["hi"] * size) for size in range(start, end + 1, step)]
-    return [prompt]
-
-
-def construct_vllm_input_data(args):
-    """Construct input data that contains input tensors and parameters for vLLM.
-
-    Parse the input JSON file (if exists) to construct the input data.
-    When user sets parameters through command line, overwrite the
-    parameters set by input JSON file.
-    """
-    # Default sampling parameters
-    sampling_params = {
-        "max_tokens": 256,
-        "ignore_eos": False,
-    }
-
-    if args.input_data:
-        input_data = load_json_data(filename=args.input_data)
-        if "sampling_parameters" in input_data["data"][0]:
-            loaded_params = input_data["data"][0]["sampling_parameters"][0]
-            loaded_params = json.loads(loaded_params or "null")
-            sampling_params = loaded_params if loaded_params else sampling_params
-    else:
-        # Default input JSON
-        input_data = {
-            "data": [
-                {
-                    "text_input": [""],
-                    "stream": [True],
-                    "sampling_parameters": [""],
-                }
-            ]
-        }
-
-    # If command line option is specified, overwrite
-    if args.offline:
-        input_data["data"][0]["stream"] = [False]
-    elif not input_data["data"][0]["stream"]:
-        args.offline = True
-
-    if args.max_tokens:
-        sampling_params["max_tokens"] = args.max_tokens
-    elif "max_tokens" in sampling_params:
-        args.max_tokens = sampling_params["max_tokens"]
-    else:
-        args.max_tokens = 256  # default
-        sampling_params["max_tokens"] = args.max_tokens
-
-    if args.ignore_eos:
-        sampling_params["ignore_eos"] = args.ignore_eos
-    elif "ignore_eos" in sampling_params:
-        args.ignore_eos = sampling_params["ignore_eos"]
-    else:
-        args.ignore_eos = False  # default
-        sampling_params["ignore_eos"] = args.ignore_eos
-
-    input_data["data"][0]["sampling_parameters"] = [json.dumps(sampling_params)]
-    return input_data
-
-
-def construct_trtllm_input_data(args):
-    """Construct input data that contains input tensors and parameters for TRT-LLM.
-
-    Parse the input JSON file (if exists) to construct the input data.
-    When user sets parameters through command line, overwrite the
-    parameters set by input JSON file.
-    """
-    if args.input_data:
-        input_data = load_json_data(filename=args.input_data)
-    else:
-        # Default input JSON
-        input_data = {
-            "data": [
-                {
-                    "text_input": [""],
-                    "max_tokens": [128],
-                    "stream": [True],
-                    "bad_words": [""],
-                    "stop_words": [""],
-                }
-            ]
-        }
-
-    # If command line option is specified, overwrite
-    if args.offline:
-        input_data["data"][0]["stream"] = [False]
-    elif not input_data["data"][0]["stream"]:
-        args.offline = True
-
-    if args.max_tokens:
-        input_data["data"][0]["max_tokens"] = [args.max_tokens]
-    else:
-        args.max_tokens = input_data["data"][0]["max_tokens"][0]
-
-    return input_data
-
-
-def main(args, should_summarize=True):
-    if args.backend == "tensorrtllm":
-        input_data = construct_trtllm_input_data(args)
-    elif args.backend == "vllm":
-        input_data = construct_vllm_input_data(args)
-    else:
-        raise ValueError(
-            f"Unknown backend specified: '{args.backend}'. Supported backend types are: 'tensorrtllm' "
-            "and 'vllm'."
-        )
-
-    prompts = generate_prompts(args, input_data)
-
-    for prompt in prompts:
-        prepare_input_data(input_data, prompt)
-        export_file = prepare_export_file(args, prompt)
-
-        if args.verbose:
-            logger.info(f"Input Data:\n{input_data}")
-
-        # Run Perf Analyzer
-        profile(args, export_file)
-
-    if should_summarize:
-        summarize_profile_results(args, prompts)
-
-
-class Args:
-    backend = "vllm"
-    model = ""
-    periodic_concurrency_range = []
-    request_period = None
-    max_tokens = 128
-    prompt_size_range = [2048, 2048, 1]
-    input_data = ""
-    ignore_eos = True
-    offline = False
-    url = "localhost:8001"
-    concurrency = 1
-    verbose = False
-
-
-class Profiler:
-    @staticmethod
-    def profile(
-        model,
-        backend,
-        batch_size,
-        url,
-        input_length=128,
-        output_length=128,
-        offline=False,
-        verbose=False,
-    ):
-        args = Args()
-        args.model = model
-        args.backend = backend
-        args.concurrency = batch_size  # inflight batch size
-        args.url = url
-        args.prompt_size_range = [input_length, input_length, 1]
-        args.max_tokens = output_length
-        args.offline = offline
-        args.verbose = verbose
-
-        start, end, step = args.prompt_size_range
-        assert start == end and step == 1  # no sweeping for now
-
-        with Progress(transient=True) as progress:
-            _ = progress.add_task("[green]Warming up...", total=None)
-            main(args, should_summarize=False)  # warm-up
-
-        mode = "offline" if offline else "online"
-        logger.info(
-            "Warmed up, profiling with the following config:\n"
-            "[ PROFILE CONFIGURATIONS ]\n"
-            f" * Model: {args.model}\n"
-            f" * Backend: {args.backend}\n"
-            f" * Profiling Mode: {mode}\n"
-            f" * Batch size: {args.concurrency}\n"
-            f" * Input tokens: {args.prompt_size_range[0]}\n"
-            f" * Output tokens: {args.max_tokens}\n"
-            ""
-        )
-
-        with Progress(transient=True) as progress:
-            _ = progress.add_task("[green]Profiling...", total=None)
-            main(args)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 785dfc8..812593c 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -27,7 +27,7 @@
 import os
 import pytest
 from triton_cli.main import run
-from triton_cli.parser import KNOWN_MODEL_SOURCES
+from triton_cli.parser import KNOWN_MODEL_SOURCES, parse_args

 KNOWN_MODELS = KNOWN_MODEL_SOURCES.keys()
 KNOWN_SOURCES = KNOWN_MODEL_SOURCES.values()
@@ -129,3 +129,13 @@ def test_remove_nonexistent(self):
     @pytest.mark.parametrize("repo", TEST_REPOS)
     def test_list(self, repo):
         self._list(repo)
+
+    # This test uses mock system args and a mock subprocess call
+    # to ensure that the correct subprocess call is made for profile.
+    def test_triton_profile(self, mocker, monkeypatch):
+        test_args = ["triton", "profile", "-m", "add_sub"]
+        mock_run = mocker.patch("subprocess.run")
+        monkeypatch.setattr("sys.argv", test_args)
+        args = parse_args()
+        args.func(args)
+        mock_run.assert_called_once_with(["genai-perf", "-m", "add_sub"], check=True)
diff --git a/tests/test_e2e.py b/tests/test_e2e.py
index e00f5aa..d321aeb 100644
--- a/tests/test_e2e.py
+++ b/tests/test_e2e.py
@@ -59,12 +59,8 @@ def _infer(self, model, prompt=None, protocol=None):
             args += ["-i", protocol]
         run(args)

-    def _profile(self, model, protocol=None, backend=None):
-        args = ["profile", "-m", model]
-        if protocol:
-            args += ["-i", protocol]
-        if backend:
-            args += ["--backend", backend]
+    def _profile(self, model, backend):
+        args = ["profile", "-m", model, "--backend", backend]
         run(args)

     class KillServerByPid:
@@ -115,7 +111,7 @@ def test_tensorrtllm_e2e(self, protocol, setup_and_teardown):
         utils.wait_for_server_ready()

         self._infer(model, prompt=PROMPT, protocol=protocol)
-        self._profile(model, backend="tensorrtllm", protocol=protocol)
+        self._profile(model, backend="tensorrtllm")

     @pytest.mark.skipif(
         os.environ.get("IMAGE_KIND") != "VLLM", reason="Only run for VLLM image"
@@ -146,7 +142,7 @@ def test_vllm_e2e(self, protocol, setup_and_teardown):
         utils.wait_for_server_ready(timeout=300)

         self._infer(model, prompt=PROMPT, protocol=protocol)
-        self._profile(model, protocol=protocol)
+        self._profile(model, backend="vllm")

     @pytest.mark.parametrize("protocol", ["grpc", "http"])
     def test_non_llm(self, protocol, setup_and_teardown):
@@ -159,11 +155,6 @@ def test_non_llm(self, protocol, setup_and_teardown):
         model = "add_sub"
         # infer should work without a prompt for non-LLM models
         self._infer(model, protocol=protocol)
-        # profile should fail for non-LLM models
-        with pytest.raises(Exception):
-            if protocol == "http":
-                pytest.xfail("Profiler does not support http protocol at this time")
-            self._profile(model, protocol=protocol)

     @pytest.mark.parametrize("protocol", ["grpc", "http"])
     def test_mock_llm(self, protocol, setup_and_teardown):
@@ -178,4 +169,6 @@ def test_mock_llm(self, protocol, setup_and_teardown):
         self._infer(model, prompt=PROMPT, protocol=protocol)
         # infer should fail without a prompt for LLM models
         with pytest.raises(Exception):
-            self._profile(model, protocol=protocol)
+            self._infer(model, protocol=protocol)
+        # profile should work without a prompt for LLM models
+        self._profile(model, backend="tensorrtllm")
diff --git a/tests/test_models/mock_llm/1/model.py b/tests/test_models/mock_llm/1/model.py
index fb55e6e..a382bf8 100644
--- a/tests/test_models/mock_llm/1/model.py
+++ b/tests/test_models/mock_llm/1/model.py
@@ -38,7 +38,7 @@ def initialize(self, args):
     def execute(self, requests):
         responses = []
         for request in requests:
-            in_0 = pb_utils.get_input_tensor_by_name(request, "TEXT_INPUT")
-            out_tensor_0 = pb_utils.Tensor("TEXT_OUTPUT", in_0.as_numpy())
+            in_0 = pb_utils.get_input_tensor_by_name(request, "text_input")
+            out_tensor_0 = pb_utils.Tensor("text_output", in_0.as_numpy())
             responses.append(pb_utils.InferenceResponse([out_tensor_0]))
         return responses
diff --git a/tests/test_models/mock_llm/config.pbtxt b/tests/test_models/mock_llm/config.pbtxt
index f78b815..91b4bac 100644
--- a/tests/test_models/mock_llm/config.pbtxt
+++ b/tests/test_models/mock_llm/config.pbtxt
@@ -29,14 +29,19 @@ backend: "python"

 input [
   {
-    name: "TEXT_INPUT"
+    name: "text_input"
     data_type: TYPE_STRING
     dims: [ 1 ]
+  },
+  {
+    name: "max_tokens"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
   }
 ]
 output [
   {
-    name: "TEXT_OUTPUT"
+    name: "text_output"
     data_type: TYPE_STRING
     dims: [ 1 ]
   }