From f97e5c2e5a4dbc51421bd145972f57f4f46a9c50 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Tue, 9 Jan 2024 19:08:32 -0800 Subject: [PATCH] Fix negative tests to expect and catch failures --- src/triton_cli/main.py | 6 +++++- src/triton_cli/repository.py | 8 ++++---- tests/test_cli.py | 20 +++++++++++++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/triton_cli/main.py b/src/triton_cli/main.py index e217647..c4e9dd7 100755 --- a/src/triton_cli/main.py +++ b/src/triton_cli/main.py @@ -28,6 +28,7 @@ from triton_cli import parser +import sys import logging logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s") @@ -41,7 +42,10 @@ def main(argv=None): args.func(args) except Exception as e: logger.error(f"{e}") + # TODO: Find a way to raise well-typed errors for testing purposes, + # without always dumping traceback to user-facing output. + raise e if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py index 419e208..41b5357 100644 --- a/src/triton_cli/repository.py +++ b/src/triton_cli/repository.py @@ -144,7 +144,7 @@ def add( verbose=True, ): if not source: - raise Exception("Non-empty model source must be provided") + raise ValueError("Non-empty model source must be provided") if backend: raise NotImplementedError( @@ -216,9 +216,9 @@ def __add_huggingface_model( self, model_dir: Path, version_dir: Path, huggingface_id: str ): if not model_dir or not model_dir.exists(): - raise Exception("Model directory must be provided and exist") + raise ValueError("Model directory must be provided and exist") if not huggingface_id: - raise Exception("HuggingFace ID must be non-empty") + raise ValueError("HuggingFace ID must be non-empty") # TODO: Add generic support for HuggingFace models with HF API. # For now, use vLLM as a means of deploying HuggingFace Transformers @@ -262,7 +262,7 @@ def __create_model_repository( if backend == "tensorrtllm": # Don't allow existing files for TRT-LLM for now in case we delete large engine files if model_dir.exists(): - raise Exception( + raise ValueError( f"Found existing model at {version_dir}, skipping repo add." ) diff --git a/tests/test_cli.py b/tests/test_cli.py index d734513..c834d66 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -78,7 +78,7 @@ def test_repo_clear(self, repo): @pytest.mark.parametrize("repo", TEST_REPOS) def test_repo_add_known_model(self, model, repo): self.repo_clear(repo) - self.repo_add(model, repo) + self.repo_add(model, repo=repo) self.repo_clear(repo) @pytest.mark.parametrize("source", KNOWN_SOURCES) @@ -87,13 +87,13 @@ def test_repo_add_known_source(self, source, repo): self.repo_clear(repo) # Random model name, since we don't care about it here model = str(uuid.uuid4()) - self.repo_add(model, source, repo) + self.repo_add(model, source=source, repo=repo) self.repo_clear(repo) @pytest.mark.parametrize("model,source", CUSTOM_VLLM_MODEL_SOURCES) def test_repo_add_vllm(self, model, source): self.repo_clear() - self.repo_add(model, source) + self.repo_add(model, source=source) # TODO: Parse repo to find model, with vllm backend in config self.repo_clear() @@ -107,12 +107,22 @@ def test_repo_add_trtllm_prebuilt(self, model, source): # TODO: Parse repo to find TRT-LLM models and backend in config pass + def test_repo_add_no_source(self): + # TODO: Investigate idiomatic way to assert failures for CLIs + with pytest.raises( + Exception, match="Please use a known model, or provide a --source" + ): + self.repo_add("no_source", source=None) + def test_repo_remove(self): - self.repo_add("gpt2") + self.repo_add("gpt2", source="hf:gpt2") self.repo_remove("gpt2") + # TODO: Find a way to raise well-typed errors for testing purposes, without + # always dumping traceback to user-facing output. def test_repo_remove_nonexistent(self): - self.repo_remove("does-not-exist") + with pytest.raises(FileNotFoundError, match="No model folder exists"): + self.repo_remove("does-not-exist") @pytest.mark.parametrize("repo", TEST_REPOS) def test_repo_list(self, repo):