POC: Background Server #15

Merged: 3 commits, merged on Dec 15, 2023
15 changes: 14 additions & 1 deletion src/triton_cli/client/client.py
@@ -1,6 +1,7 @@
import json
import queue
import logging
import time
import numpy as np
from functools import partial

@@ -53,6 +54,9 @@ def is_server_ready(self):
def is_server_live(self):
return self.client.is_server_live()

def ready_for_inference(self):
return self.client.is_server_live() and self.client.is_server_ready()

def get_server_health(self):
live = self.is_server_live()
if not live:
@@ -115,7 +119,7 @@ def __create_triton_input(self, name, shape, dtype, data):
return _input

def generate_data(self, config: dict, data_mode: str):
logger.info("Generating input data...")
logger.debug("Generating input data...")
inputs = [i for i in config["input"]]

infer_inputs = []
@@ -270,3 +274,12 @@ def __process_infer_result(self, result):
# Used for decoupled purposes to determine when requests are finished
# Only applicable to GRPC streaming at this time
return is_final_response

# Junk function to show a server can be launched and queried in a single terminal.
# TODO: Remove.
def benchmark_model(self, model: str):
start = time.time_ns()
self.infer(model, "random", "This CLI is ")
end = time.time_ns()
total_time_ms = (end - start) / (10**6)
logger.info(f"Inference time: {total_time_ms} ms")
194 changes: 162 additions & 32 deletions src/triton_cli/parser.py
@@ -1,17 +1,84 @@
import time
import logging
import argparse
from pathlib import Path

from rich import print as rich_print
from rich.progress import Progress

from triton_cli.constants import DEFAULT_MODEL_REPO
from triton_cli.client.client import TritonClient
from triton_cli.constants import DEFAULT_MODEL_REPO, LOGGER_NAME
from triton_cli.client.client import InferenceServerException, TritonClient
from triton_cli.metrics import MetricsClient
from triton_cli.repository import ModelRepository
from triton_cli.server.server_factory import TritonServerFactory
from triton_cli.profiler import Profiler

logger = logging.getLogger("triton")
logger = logging.getLogger(LOGGER_NAME)

# TODO: Move to config file approach
# TODO: Per-GPU mappings for TRT LLM models
KNOWN_MODEL_SOURCES = {
"llama-2-7b": "ngc:whw3rcpsilnj/playground/llama2_7b_trt_a100:0.1",
"llama-2-13b": "ngc:whw3rcpsilnj/playground/llama2_13b_trt_a100:0.1",
"gpt2": "hf:gpt2",
"opt125m": "hf:facebook/opt125-m",
"mistral-7b": "hf:mistralai/Mistral-7B-v0.1",
}


# TODO: Move out of parser
# TODO: rich progress bar
def wait_for_ready(timeout, server, client):
with Progress(transient=True) as progress:
_ = progress.add_task("[green]Waiting for server startup...", total=None)
for _ in range(timeout):
# The client health check allows an early exit if the server is ready;
# errors may occur while the server is starting up, so ignore them.
try:
if client.is_server_ready():
return
except InferenceServerException:
pass

# Server health will throw exception if error occurs on server side
server.health()
time.sleep(1)
raise Exception("Timed out waiting for server to startup.")


def add_server_start_args(subcommands):
for subcommand in subcommands:
subcommand.add_argument(
"--mode",
choices=["local", "docker"],
type=str,
default="local",
required=False,
help="Mode to start Triton with. (Default: 'local')",
)
subcommand.add_argument(
"--image",
type=str,
required=False,
default="nvcr.io/nvidia/tritonserver:23.11-vllm-python-py3",
help="Image to use when starting Triton with 'docker' mode",
)
# TODO: Delete once world-size can be parsed from a known
# config file location.
subcommand.add_argument(
"--world-size",
type=int,
required=False,
default=-1,
help="Number of devices to deploy a tensorrtllm model.",
)
subcommand.add_argument(
"--server-timeout",
type=int,
required=False,
default=100,
help="Maximum number of seconds to wait for server startup. (Default: 100)",
)


def add_client_args(subcommands):
@@ -113,14 +180,8 @@ def handle_model(args: argparse.Namespace):

def handle_server(args: argparse.Namespace):
if args.subcommand == "start":
# TODO: No support for specifying GPUs for now, default to all available.
gpus = []
server = TritonServerFactory.get_server_handle(args, gpus)
logger.debug(server)
server = TritonServerFactory.get_server_handle(args)
try:
logger.info(
f"Starting server with model repository: [{args.model_repository}]..."
)
server.start()
logger.info("Reading server output...")
server.logs()
@@ -241,28 +302,7 @@ def parse_args_server(subcommands):
server.set_defaults(func=handle_server)
server_commands = server.add_subparsers(required=True, dest="subcommand")
server_start = server_commands.add_parser("start", help="Start a Triton server")
server_start.add_argument(
"--mode",
choices=["local", "docker"],
type=str,
default="docker",
required=False,
help="Mode to start Triton with. (Default: 'docker')",
)
server_start.add_argument(
"--image",
type=str,
required=False,
default="nvcr.io/nvidia/tritonserver:23.11-vllm-python-py3",
help="Image to use when starting Triton with 'docker' mode",
)
server_start.add_argument(
"--world-size",
type=int,
required=False,
default=-1,
help="Number of devices to deploy a tensorrtllm model.",
)
add_server_start_args([server_start])
add_repo_args([server_start])

server_metrics = server_commands.add_parser(
@@ -278,6 +318,95 @@ def parse_args_server(subcommands):
return server


def handle_bench(args: argparse.Namespace):
if args.verbose:
logger.setLevel(logging.DEBUG)

if args.subcommand == "run":
### Add model to repo
repo = ModelRepository(args.model_repository)
# Handle common models for convenience
if not args.source:
if args.model in KNOWN_MODEL_SOURCES:
args.source = KNOWN_MODEL_SOURCES[args.model]
logger.info(
f"Known model source found for '{args.model}': '{args.source}'"
)
else:
logger.error(
f"No known source for model: '{args.model}'. Known sources: {list(KNOWN_MODEL_SOURCES.keys())}"
)
raise Exception("Please use a known model, or provide a --source.")

repo.add(
args.model,
version=1,
source=args.source,
backend=None,
verbose=args.verbose,
)

### Start server
server = TritonServerFactory.get_server_handle(args)
try:
server.start()
client = TritonClient(url=args.url, port=args.port, protocol=args.protocol)
wait_for_ready(args.server_timeout, server, client)
### Profile model
logger.info("Server is ready for inference. Starting benchmark...")
client.benchmark_model(model=args.model)
except KeyboardInterrupt:
print()
except Exception as ex:
# Catch timeout exception
logger.error(ex)

logger.info("Stopping server...")
server.stop()
else:
raise NotImplementedError(f"bench subcommand {args.subcommand} not supported")


def parse_args_bench(subcommands):
# Model benchmarking
bench = subcommands.add_parser(
"bench", help="Run benchmarks on a model loaded into the Triton server."
)
bench.set_defaults(func=handle_bench)
bench_commands = bench.add_subparsers(required=True, dest="subcommand")
bench_run = bench_commands.add_parser(
"run", help="Start a Triton benchmarking session."
)
bench_run.add_argument(
"-m",
"--model",
type=str,
required=True,
help="The name of the model to benchmark",
)
bench_run.add_argument(
"-s",
"--source",
type=str,
required=False,
help="Local model path or model identifier. Use prefix 'hf:' to specify a HuggingFace model ID, or 'ngc:' for NGC model ID. "
"NOTE: HuggingFace models are currently limited to vLLM, and NGC models are currently limited to TRT-LLM",
)
bench_run.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Enable verbose logging",
)

add_server_start_args([bench_run])
add_repo_args([bench_run])
add_client_args([bench_run])

return bench


def parse_args():
parser = argparse.ArgumentParser(
prog="triton", description="CLI to interact with Triton Inference Server"
@@ -286,5 +415,6 @@ def parse_args():
_ = parse_args_model(subcommands)
_ = parse_args_repo(subcommands)
_ = parse_args_server(subcommands)
_ = parse_args_bench(subcommands)
args = parser.parse_args()
return args
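Taken together, the parser changes route "bench run" through set_defaults(func=handle_bench). A rough sketch of the equivalent programmatic flow is below; the "triton" entry-point name placed in sys.argv is an assumption for illustration.

# Illustrative only: roughly what a `triton bench run -m gpt2` invocation does.
import sys
from triton_cli.parser import parse_args

sys.argv = ["triton", "bench", "run", "-m", "gpt2"]  # assumed CLI entry-point name
args = parse_args()
args.func(args)  # dispatches to handle_bench: add model, start server, wait, benchmark, stop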
35 changes: 19 additions & 16 deletions src/triton_cli/repository.py
@@ -7,10 +7,10 @@

from directory_tree import display_tree

from triton_cli.constants import DEFAULT_MODEL_REPO
from triton_cli.constants import DEFAULT_MODEL_REPO, LOGGER_NAME
from triton_cli.json_parser import parse_and_substitute

logger = logging.getLogger("triton")
logger = logging.getLogger(LOGGER_NAME)

# For now, generated model configs will be limited to only backends
# that can be fully autocompleted for a simple deployment.
@@ -57,13 +57,13 @@ def __generate_config(self, org="", team="", api_key="", format_type="ascii"):
config_dir = Path.home() / ".ngc"
config_file = config_dir / "config"
if config_file.exists():
logger.info("Found existing NGC config, skipping config generation")
logger.debug("Found existing NGC config, skipping config generation")
return

logger.info("Generating NGC config")
if not config_dir.exists():
config_dir.mkdir(exist_ok=True)

logger.debug("Generating NGC config")
config = NGC_CONFIG_TEMPLATE.format(
api_key=api_key, format_type=format_type, org=org, team=team
)
@@ -83,7 +83,7 @@ def download_model(self, model, ngc_model_name, dest):
return

cmd = f"ngc registry model download-version {model} --dest {dest}"
logger.info(f"Running '{cmd}'")
logger.debug(f"Running '{cmd}'")
output = subprocess.run(cmd.split())
if output.returncode:
err = output.stderr.decode("utf-8")
@@ -101,9 +101,9 @@ def __init__(self, path: str = None):
# OK if model repo already exists, support adding multiple models
try:
self.repo.mkdir(parents=True, exist_ok=False)
logger.info(f"Created new model repository: {self.repo}")
logger.debug(f"Created new model repository: {self.repo}")
except FileExistsError:
logger.info(f"Using existing model repository: {self.repo}")
logger.debug(f"Using existing model repository: {self.repo}")

def list(self):
logger.info(f"Current repo at {self.repo}:")
@@ -115,6 +115,7 @@ def add(
version: int = 1,
source: str = None,
backend: str = None,
verbose=True,
):
if not source:
raise Exception("Non-empty model source must be provided")
@@ -126,16 +127,16 @@

# HuggingFace models
if source.startswith(SOURCE_PREFIX_HUGGINGFACE):
logger.info("HuggingFace prefix detected, parsing HuggingFace ID")
logger.debug("HuggingFace prefix detected, parsing HuggingFace ID")
source_type = "huggingface"
# NGC models
elif source.startswith(SOURCE_PREFIX_NGC):
logger.info("NGC prefix detected, parsing NGC ID")
logger.debug("NGC prefix detected, parsing NGC ID")
source_type = "ngc"
backend = "tensorrtllm"
# Local model path
else:
logger.info("No supported prefix detected, assuming local path")
logger.debug("No supported prefix detected, assuming local path")
source_type = "local"
model_path = Path(source)
if not model_path.exists():
@@ -162,10 +163,11 @@
# point to downloaded engines, etc.
self.__generate_trtllm_model(name, ngc_model_name)
else:
logger.info(f"Copying {model_path} to {version_dir}")
logger.debug(f"Copying {model_path} to {version_dir}")
shutil.copy(model_path, version_dir)

self.list()
if verbose:
self.list()

def clear(self):
logger.info(f"Clearing all contents from {self.repo}...")
@@ -174,13 +176,14 @@
# No support for removing individual versions for now
# TODO: remove doesn't support removing groups of models like TRT LLM at this time
# Use "clear" instead to clean up the repo as a WAR.
def remove(self, name: str):
def remove(self, name: str, verbose=True):
model_dir = self.repo / name
if not model_dir.exists():
raise FileNotFoundError(f"No model folder exists at {model_dir}")
logger.info(f"Removing model {name} at {model_dir}...")
shutil.rmtree(model_dir)
self.list()
if verbose:
self.list()

def __add_huggingface_model(
self, model_dir: Path, version_dir: Path, huggingface_id: str
@@ -243,10 +246,10 @@ def __create_model_repository(
dirs_exist_ok=True,
ignore=shutil.ignore_patterns("__pycache__"),
)
logger.info(f"Adding TensorRT-LLM models at: {self.repo}")
logger.debug(f"Adding TensorRT-LLM models at: {self.repo}")
else:
version_dir.mkdir(parents=True, exist_ok=False)
logger.info(f"Adding new model to repo at: {version_dir}")
logger.debug(f"Adding new model to repo at: {version_dir}")
except FileExistsError:
logger.warning(f"Overwriting existing model in repo at: {version_dir}")

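For context on the new verbose flags, a small sketch of driving ModelRepository.add/remove quietly from another code path. The "hf:gpt2" source comes from this PR's KNOWN_MODEL_SOURCES; calling the constructor with no path is assumed to fall back to DEFAULT_MODEL_REPO.

# Illustrative only: add and remove a model without printing the repo tree,
# using the verbose=False paths introduced in this diff.
from triton_cli.repository import ModelRepository

repo = ModelRepository()  # assumed to default to DEFAULT_MODEL_REPO when no path is given
repo.add("gpt2", version=1, source="hf:gpt2", backend=None, verbose=False)
repo.remove("gpt2", verbose=False)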