Skip to content

Commit

Permalink
Add initial tests for repo subcommand (#21)
Browse files Browse the repository at this point in the history
Co-authored-by: Francesco Petrini <[email protected]>
  • Loading branch information
rmccorm4 and fpetrini15 authored Jan 10, 2024
1 parent 45d2f36 commit cfb4bbb
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 14 deletions.
15 changes: 12 additions & 3 deletions src/triton_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,28 @@

from triton_cli import parser

import sys
import logging

logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("triton")


def run(argv=None):
    """Parse *argv* and dispatch to the selected subcommand handler.

    Kept separate from main() so tests can call it directly and assert on
    the exact exception types/messages raised by subcommands.

    Args:
        argv: Optional list of CLI arguments; defaults to sys.argv when None.
    """
    parsed = parser.parse_args(argv)
    parsed.func(parsed)


def main():
    """Console entry point for the triton CLI.

    Interactive use will catch exceptions and log formatted errors rather
    than tracebacks; tests should call run() instead to observe exceptions.
    """
    try:
        run()
    except Exception as e:
        # Show a clean one-line error to the user instead of a traceback.
        logger.error(f"{e}")


if __name__ == "__main__":
    # main() returns None, so sys.exit() reports a zero (success) status.
    sys.exit(main())
5 changes: 3 additions & 2 deletions src/triton_cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ def parse_args_bench(subcommands):
return bench_run


def parse_args():
# Optional argv used for testing - will default to sys.argv if None.
def parse_args(argv=None):
parser = argparse.ArgumentParser(
prog="triton", description="CLI to interact with Triton Inference Server"
)
Expand All @@ -536,5 +537,5 @@ def parse_args():
_ = parse_args_repo(subcommands)
_ = parse_args_server(subcommands)
_ = parse_args_bench(subcommands)
args = parser.parse_args()
args = parser.parse_args(argv)
return args
10 changes: 9 additions & 1 deletion src/triton_cli/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import json
import subprocess
from dataclasses import dataclass
from itertools import pairwise
from itertools import tee
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -75,6 +75,14 @@
}


# Built-in to itertools in Python 3.10+; provided here for older interpreters.
def pairwise(iterable):
    """Yield overlapping consecutive pairs: s -> (s0, s1), (s1, s2), ..."""
    left, right = tee(iterable)  # two independent iterators over the input
    next(right, None)  # advance the trailing copy by one element
    return zip(left, right)


@dataclass
class ProfileResults:
prompt_size: int
Expand Down
8 changes: 4 additions & 4 deletions src/triton_cli/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def add(
verbose=True,
):
if not source:
raise Exception("Non-empty model source must be provided")
raise ValueError("Non-empty model source must be provided")

if backend:
raise NotImplementedError(
Expand Down Expand Up @@ -216,9 +216,9 @@ def __add_huggingface_model(
self, model_dir: Path, version_dir: Path, huggingface_id: str
):
if not model_dir or not model_dir.exists():
raise Exception("Model directory must be provided and exist")
raise ValueError("Model directory must be provided and exist")
if not huggingface_id:
raise Exception("HuggingFace ID must be non-empty")
raise ValueError("HuggingFace ID must be non-empty")

# TODO: Add generic support for HuggingFace models with HF API.
# For now, use vLLM as a means of deploying HuggingFace Transformers
Expand Down Expand Up @@ -262,7 +262,7 @@ def __create_model_repository(
if backend == "tensorrtllm":
# Don't allow existing files for TRT-LLM for now in case we delete large engine files
if model_dir.exists():
raise Exception(
raise ValueError(
f"Found existing model at {version_dir}, skipping repo add."
)

Expand Down
3 changes: 2 additions & 1 deletion src/triton_cli/server/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import subprocess
import logging
import json
from typing import Union
from pathlib import Path
import tritonclient.grpc.model_config_pb2 as mc
from google.protobuf import json_format, text_format
Expand All @@ -42,7 +43,7 @@ def get_launch_command(
server_config: TritonServerConfig,
cmd_as_list: bool,
env_cmds=[],
) -> str | list:
) -> Union[str, list]:
"""
Parameters
----------
Expand Down
126 changes: 126 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import pytest
from triton_cli.main import run
from triton_cli.parser import KNOWN_MODEL_SOURCES

# Model names and sources that the CLI ships built-in mappings for.
KNOWN_MODELS = KNOWN_MODEL_SOURCES.keys()
KNOWN_SOURCES = KNOWN_MODEL_SOURCES.values()
# Exercise both the default repository location (None) and an explicit path.
TEST_REPOS = [None, os.path.join("tmp", "models")]

# (model name, source) pairs for the vLLM backend path.
CUSTOM_VLLM_MODEL_SOURCES = [("vllm-model", "hf:gpt2")]

# (model name, source) pairs for the TRT-LLM backend path.
CUSTOM_TRTLLM_MODEL_SOURCES = [("trtllm-model", "hf:gpt2")]

# TODO: Add public NGC model for testing
CUSTOM_NGC_MODEL_SOURCES = [("my-llm", "ngc:does-not-exist")]


class TestRepo:
    """Smoke tests for the `repo` subcommand, driven through the CLI run() entry point."""

    def _invoke(self, cmd, repo=None):
        # Append the optional --repo flag, then dispatch through the CLI.
        if repo:
            cmd = cmd + ["--repo", repo]
        run(cmd)

    def repo_list(self, repo=None):
        self._invoke(["repo", "list"], repo=repo)

    def repo_clear(self, repo=None):
        self._invoke(["repo", "clear"], repo=repo)

    def repo_add(self, model, source=None, repo=None):
        cmd = ["repo", "add", "-m", model]
        if source:
            cmd += ["--source", source]
        self._invoke(cmd, repo=repo)

    def repo_remove(self, model, repo=None):
        self._invoke(["repo", "remove", "-m", model], repo=repo)

    @pytest.mark.parametrize("repo", TEST_REPOS)
    def test_repo_clear(self, repo):
        self.repo_clear(repo)

    # TODO: Add pre/post repo clear to a fixture for setup/teardown
    @pytest.mark.parametrize("model", KNOWN_MODELS)
    @pytest.mark.parametrize("repo", TEST_REPOS)
    def test_repo_add_known_model(self, model, repo):
        self.repo_clear(repo)
        self.repo_add(model, repo=repo)
        self.repo_clear(repo)

    @pytest.mark.parametrize("source", KNOWN_SOURCES)
    @pytest.mark.parametrize("repo", TEST_REPOS)
    def test_repo_add_known_source(self, source, repo):
        self.repo_clear(repo)
        self.repo_add("known_source", source=source, repo=repo)
        self.repo_clear(repo)

    @pytest.mark.parametrize("model,source", CUSTOM_VLLM_MODEL_SOURCES)
    def test_repo_add_vllm(self, model, source):
        self.repo_clear()
        self.repo_add(model, source=source)
        # TODO: Parse repo to find model, with vllm backend in config
        self.repo_clear()

    @pytest.mark.skip(reason="TRT-LLM engine build support not ready to test")
    def test_repo_add_trtllm_build(self, model, source):
        # TODO: Parse repo to find TRT-LLM models and backend in config
        pass

    @pytest.mark.skip(reason="Pre-built TRT-LLM engines not available")
    def test_repo_add_trtllm_prebuilt(self, model, source):
        # TODO: Parse repo to find TRT-LLM models and backend in config
        pass

    def test_repo_add_no_source(self):
        # TODO: Investigate idiomatic way to assert failures for CLIs
        with pytest.raises(
            Exception, match="Please use a known model, or provide a --source"
        ):
            self.repo_add("no_source", source=None)

    def test_repo_remove(self):
        self.repo_add("gpt2", source="hf:gpt2")
        self.repo_remove("gpt2")

    # TODO: Find a way to raise well-typed errors for testing purposes, without
    # always dumping traceback to user-facing output.
    def test_repo_remove_nonexistent(self):
        with pytest.raises(FileNotFoundError, match="No model folder exists"):
            self.repo_remove("does-not-exist")

    @pytest.mark.parametrize("repo", TEST_REPOS)
    def test_repo_list(self, repo):
        self.repo_list(repo)
7 changes: 4 additions & 3 deletions tests/test_example.py → tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# NOTE(review): the diff residue here interleaved the removed test_import()
# (which used a function-local import) with the added module-level import and
# test_version(); this reconstructs the coherent post-commit version.
import triton_cli


# Placeholder to add real tests in the future
def test_version():
    """Smoke test: importing the package works and it exposes __version__."""
    print(triton_cli.__version__)

0 comments on commit cfb4bbb

Please sign in to comment.