diff --git a/.gitignore b/.gitignore index f1b69cb25e..bce94b6830 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,13 @@ __pycache__ tmp *.log +*.xml test_results.txt +artifacts +cprofile +*.prof + +# Test exclusions +qa/L0_openai/openai +tensorrtllm_models +custom_tokenizer diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 663a36d631..b11dc007bd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -73,12 +73,13 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace -- repo: local - hooks: - - id: add-license - name: Add License - entry: python tools/add_copyright.py - language: python - stages: [pre-commit] - verbose: true - require_serial: true +# FIXME: Only run on changed files when triggered by GitHub Actions +#- repo: local +# hooks: +# - id: add-license +# name: Add License +# entry: python tools/add_copyright.py +# language: python +# stages: [pre-commit] +# verbose: true +# require_serial: true diff --git a/python/openai/README.md b/python/openai/README.md new file mode 100644 index 0000000000..f6647da12a --- /dev/null +++ b/python/openai/README.md @@ -0,0 +1,206 @@ + +# OpenAI-Compatible Frontend for Triton Inference Server + +## Pre-requisites + +1. Docker + NVIDIA Container Runtime +2. A correctly configured `HF_TOKEN` for access to HuggingFace models. + - The current examples and testing primarily use the + [`meta-llama/Meta-Llama-3.1-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) + model, but you can manually bring your own models and adjust accordingly. + +## VLLM + +1. Launch the container and install dependencies: + - Mounts the `~/.huggingface/cache` for re-use of downloaded models across runs, containers, etc. + - Sets the [`HF_TOKEN`](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hftoken) environment variable to + access gated models, make sure this is set in your local environment if needed. + +```bash +docker run -it --net=host --gpus all --rm \ + -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ + -e HF_TOKEN \ + nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3 +``` + +2. Install dependencies inside the container: +```bash +# Install python bindings for tritonserver and tritonfrontend +pip install /opt/tritonserver/python/triton*.whl + +# Install application requirements +git clone https://github.com/triton-inference-server/server.git +cd server/python/openai/ +pip install -r requirements.txt +``` + +3. Launch the OpenAI-compatible Triton Inference Server: +```bash +# NOTE: Adjust the --tokenizer based on the model being used +python3 openai_frontend/main.py --model-repository tests/vllm_models --tokenizer meta-llama/Meta-Llama-3.1-8B-Instruct +``` + +4. Send a `/v1/chat/completions` request: + - Note the use of `jq` is optional, but provides a nicely formatted output for JSON responses. +```bash +MODEL="llama-3.1-8b-instruct" +curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ + "model": "'${MODEL}'", + "messages": [{"role": "user", "content": "Say this is a test!"}] +}' | jq +``` + +5. Send a `/v1/completions` request: + - Note the use of `jq` is optional, but provides a nicely formatted output for JSON responses. +```bash +MODEL="llama-3.1-8b-instruct" +curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json' -d '{ + "model": "'${MODEL}'", + "prompt": "Machine learning is" +}' | jq +``` + +6. 
Benchmark with `genai-perf`: +```bash +MODEL="llama-3.1-8b-instruct" +TOKENIZER="meta-llama/Meta-Llama-3.1-8B-Instruct" +genai-perf \ + --model ${MODEL} \ + --tokenizer ${TOKENIZER} \ + --service-kind openai \ + --endpoint-type chat \ + --synthetic-input-tokens-mean 256 \ + --synthetic-input-tokens-stddev 0 \ + --output-tokens-mean 256 \ + --output-tokens-stddev 0 \ + --streaming +``` + +7. Use the OpenAI python client directly: +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:9000/v1", + api_key="EMPTY", +) + +model = "llama-3.1-8b-instruct" +completion = client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + {"role": "user", "content": "What are LLMs?"}, + ], + max_tokens=256, +) + +print(completion.choices[0].message.content) +``` + +8. Run tests (NOTE: The server should not be running, the tests will handle starting/stopping the server as necessary): +```bash +cd server/python/openai/ +pip install -r requirements-test.txt + +pytest -v tests/ +``` + +## TensorRT-LLM + +0. Prepare your model repository for serving a TensorRT-LLM model: + https://github.com/triton-inference-server/tensorrtllm_backend?tab=readme-ov-file#quick-start + +1. Launch the container: + - Mounts the `~/.huggingface/cache` for re-use of downloaded models across runs, containers, etc. + - Sets the [`HF_TOKEN`](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hftoken) environment variable to + access gated models, make sure this is set in your local environment if needed. + +```bash +docker run -it --net=host --gpus all --rm \ + -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ + -e HF_TOKEN \ + nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3 +``` + +2. Install dependencies inside the container: +```bash +# Install python bindings for tritonserver and tritonfrontend +pip install /opt/tritonserver/python/triton*.whl + +# Install application requirements +git clone https://github.com/triton-inference-server/server.git +cd server/python/openai/ +pip install -r requirements.txt +``` + +2. Launch the OpenAI server: +```bash +# NOTE: Adjust the --tokenizer based on the model being used +python3 openai_frontend/main.py --model-repository tests/tensorrtllm_models --tokenizer meta-llama/Meta-Llama-3.1-8B-Instruct +``` + +3. Send a `/v1/chat/completions` request: + - Note the use of `jq` is optional, but provides a nicely formatted output for JSON responses. +```bash +MODEL="tensorrt_llm_bls" +curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ + "model": "'${MODEL}'", + "messages": [{"role": "user", "content": "Say this is a test!"}] +}' | jq +``` + +The other examples should be the same as vLLM, except that you should set `MODEL="tensorrt_llm_bls"`, +everywhere applicable as seen in the example request above. + +## KServe Frontends + +To support serving requests through both the OpenAI-Compatible and +KServe Predict v2 frontends to the same running Triton Inference Server, +the `tritonfrontend` python bindings are included for optional use in this +application as well. 
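+
+For reference, the `tritonfrontend` bindings can also be used on their own against an
+in-process `tritonserver` instance. The snippet below is an illustrative sketch based on
+the `tritonfrontend` documentation linked at the end of this section; confirm the exact
+class and option names there before relying on it:
+
+```python
+import tritonserver
+from tritonfrontend import KServeHttp
+
+# Start Triton in-process (model repository path is illustrative)
+server = tritonserver.Server(
+    tritonserver.Options(model_repository="tests/vllm_models")
+)
+server.start(wait_until_ready=True)
+
+# Attach a KServe HTTP frontend to the running server
+# (8000 is the conventional default port; adjust as needed)
+http_service = KServeHttp(server, KServeHttp.Options(port=8000))
+http_service.start()
+
+# ... serve KServe Predict v2 requests, e.g. GET /v2/health/ready ...
+
+http_service.stop()
+server.stop()
+```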
+ +You can opt-in to including these additional frontends, assuming `tritonfrontend` +is installed, with `--enable-kserve-frontends` like below: + +``` +python3 openai_frontend/main.py \ + --model-repository tests/vllm_models \ + --tokenizer meta-llama/Meta-Llama-3.1-8B-Instruct \ + --enable-kserve-frontends +``` + +See `python3 openai_frontend/main.py --help` for more information on the +available arguments and default values. + +For more information on the `tritonfrontend` python bindings, see the docs +[here](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/tritonfrontend.md). diff --git a/python/openai/openai_frontend/__init__.py b/python/openai/openai_frontend/__init__.py new file mode 100644 index 0000000000..dc1c939c66 --- /dev/null +++ b/python/openai/openai_frontend/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/openai_frontend/engine/__init__.py b/python/openai/openai_frontend/engine/__init__.py new file mode 100644 index 0000000000..77855ae979 --- /dev/null +++ b/python/openai/openai_frontend/engine/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/openai_frontend/engine/engine.py b/python/openai/openai_frontend/engine/engine.py new file mode 100644 index 0000000000..9c90dec25e --- /dev/null +++ b/python/openai/openai_frontend/engine/engine.py @@ -0,0 +1,94 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from __future__ import annotations + +from typing import Iterator, List, Protocol + +from schemas.openai import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + CreateCompletionRequest, + CreateCompletionResponse, + Model, +) + + +class LLMEngine(Protocol): + """ + Interface for an OpenAI-aware inference engine to be attached to an + OpenAI-compatible frontend. + + NOTE: This interface is subject to change, and may land on something more + generic rather than the current 1:1 with OpenAI endpoints over time. + """ + + def ready(self) -> bool: + """ + Returns True if the engine is ready to accept inference requests, or False otherwise. + """ + pass + + def metrics(self) -> str: + """ + Returns the engine's metrics in a Prometheus-compatible string format. + """ + pass + + def models(self) -> List[Model]: + """ + Returns a List of OpenAI Model objects. 
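+        Each Model includes the model id (name), creation time, object type,
+        and owner (owned_by), matching the OpenAI models schema.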
+ """ + pass + + def chat( + self, request: CreateChatCompletionRequest + ) -> CreateChatCompletionResponse | Iterator[str]: + """ + If request.stream is True, this returns an Iterator (or Generator) that + produces server-sent-event (SSE) strings in the following form: + 'data: {CreateChatCompletionStreamResponse}\n\n' + ... + 'data: [DONE]\n\n' + + If request.stream is False, this returns a CreateChatCompletionResponse. + """ + pass + + def completion( + self, request: CreateCompletionRequest + ) -> CreateCompletionResponse | Iterator[str]: + """ + If request.stream is True, this returns an Iterator (or Generator) that + produces server-sent-event (SSE) strings in the following form: + 'data: {CreateCompletionResponse}\n\n' + ... + 'data: [DONE]\n\n' + + If request.stream is False, this returns a CreateCompletionResponse. + """ + pass diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py new file mode 100644 index 0000000000..e6f4f0bbe3 --- /dev/null +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -0,0 +1,434 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
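+
+# TritonLLMEngine adapts an in-process tritonserver.Server to the LLMEngine
+# protocol: OpenAI chat/completions requests are converted into Triton
+# inference requests (vLLM- or TRT-LLM-style inputs), and Triton responses
+# are mapped back to the OpenAI response and streaming (SSE) schemas.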
+ + +from __future__ import annotations + +import time +import uuid +from dataclasses import dataclass +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, List, Optional + +import tritonserver +from engine.engine import LLMEngine +from engine.utils.tokenizer import get_tokenizer +from engine.utils.triton import ( + _create_trtllm_inference_request, + _create_vllm_inference_request, + _get_output, + _validate_triton_responses_non_streaming, +) +from schemas.openai import ( + ChatCompletionChoice, + ChatCompletionFinishReason, + ChatCompletionResponseMessage, + ChatCompletionStreamingResponseChoice, + ChatCompletionStreamResponseDelta, + Choice, + CreateChatCompletionRequest, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, + CreateCompletionRequest, + CreateCompletionResponse, + FinishReason, + Model, + ObjectType, +) + + +# TODO: Improve type hints +@dataclass +class TritonModelMetadata: + # Name used in Triton model repository + name: str + # Name of backend used by Triton + backend: str + # Triton model object handle + model: tritonserver.Model + # Tokenizers used for chat templates + tokenizer: Optional[Any] + # Time that model was loaded by Triton + create_time: int + # Conversion format between OpenAI and Triton requests + request_converter: Callable + + +class TritonLLMEngine(LLMEngine): + def __init__( + self, server: tritonserver.Server, tokenizer: str, backend: Optional[str] = None + ): + # Assume an already configured and started server + self.server = server + self.tokenizer = self._get_tokenizer(tokenizer) + # TODO: Reconsider name of "backend" vs. something like "request_format" + self.backend = backend + + # NOTE: Creation time and model metadata will be static at startup for + # now, and won't account for dynamically loading/unloading models. 
+ self.create_time = int(time.time()) + self.model_metadata = self._get_model_metadata() + + def ready(self) -> bool: + return self.server.ready() + + def metrics(self) -> str: + return self.server.metrics() + + def models(self) -> List[Model]: + models = [] + for metadata in self.model_metadata.values(): + models.append( + Model( + id=metadata.name, + created=metadata.create_time, + object=ObjectType.model, + owned_by="Triton Inference Server", + ), + ) + + return models + + async def chat( + self, request: CreateChatCompletionRequest + ) -> CreateChatCompletionResponse | AsyncIterator[str]: + metadata = self.model_metadata.get(request.model) + self._validate_chat_request(request, metadata) + + conversation = [ + {"role": str(message.role), "content": str(message.content)} + for message in request.messages + ] + add_generation_prompt = True + + prompt = metadata.tokenizer.apply_chat_template( + conversation=conversation, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + + # Convert to Triton request format and perform inference + responses = metadata.model.async_infer( + metadata.request_converter(metadata.model, prompt, request) + ) + + # Prepare and send responses back to client in OpenAI format + request_id = f"cmpl-{uuid.uuid1()}" + created = int(time.time()) + default_role = "assistant" + role = self._get_first_response_role( + conversation, add_generation_prompt, default_role + ) + + if request.stream: + return self._streaming_chat_iterator( + request_id, created, request.model, role, responses + ) + + # Response validation with decoupled models in mind + responses = [response async for response in responses] + _validate_triton_responses_non_streaming(responses) + response = responses[0] + text = _get_output(response) + + return CreateChatCompletionResponse( + id=request_id, + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionResponseMessage( + content=text, role=role, function_call=None + ), + logprobs=None, + finish_reason=ChatCompletionFinishReason.stop, + ) + ], + created=created, + model=request.model, + system_fingerprint=None, + object=ObjectType.chat_completion, + ) + + async def completion( + self, request: CreateCompletionRequest + ) -> CreateCompletionResponse | AsyncIterator[str]: + # Validate request and convert to Triton format + metadata = self.model_metadata.get(request.model) + self._validate_completion_request(request, metadata) + + # Convert to Triton request format and perform inference + responses = metadata.model.async_infer( + metadata.request_converter(metadata.model, request.prompt, request) + ) + + # Prepare and send responses back to client in OpenAI format + request_id = f"cmpl-{uuid.uuid1()}" + created = int(time.time()) + if request.stream: + return self._streaming_completion_iterator( + request_id, created, metadata.name, responses + ) + + # Response validation with decoupled models in mind + responses = [response async for response in responses] + _validate_triton_responses_non_streaming(responses) + response = responses[0] + text = _get_output(response) + + choice = Choice( + finish_reason=FinishReason.stop, + index=0, + logprobs=None, + text=text, + ) + return CreateCompletionResponse( + id=request_id, + choices=[choice], + system_fingerprint=None, + object=ObjectType.text_completion, + created=created, + model=metadata.name, + ) + + # TODO: This behavior should be tested further + def _get_first_response_role( + self, conversation: List[Dict], add_generation_prompt: bool, default_role: str + ) -> str: + if 
add_generation_prompt: + return default_role + + return conversation[-1]["role"] + + # TODO: Expose explicit flag to catch edge cases + def _determine_request_converter(self, backend: str): + # Allow manual override of backend request format if provided by user + if self.backend: + backend = self.backend + + # Request conversion from OpenAI format to backend-specific format + if backend == "vllm": + return _create_vllm_inference_request + + # Use TRT-LLM format as default for everything else. This could be + # an ensemble, a python or BLS model, a TRT-LLM backend model, etc. + return _create_trtllm_inference_request + + def _get_tokenizer(self, tokenizer_name: str): + tokenizer = None + if tokenizer_name: + tokenizer = get_tokenizer(tokenizer_name) + + return tokenizer + + def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]: + # One tokenizer and creation time shared for all loaded models for now. + model_metadata = {} + + # Read all triton models and store the necessary metadata for each + for name, _ in self.server.models().keys(): + model = self.server.model(name) + backend = model.config()["backend"] + print(f"Found model: {name=}, {backend=}") + + metadata = TritonModelMetadata( + name=name, + backend=backend, + model=model, + tokenizer=self.tokenizer, + create_time=self.create_time, + request_converter=self._determine_request_converter(backend), + ) + model_metadata[name] = metadata + + return model_metadata + + def _get_streaming_chat_response_chunk( + self, + choice: ChatCompletionStreamingResponseChoice, + request_id: str, + created: int, + model: str, + ) -> CreateChatCompletionStreamResponse: + return CreateChatCompletionStreamResponse( + id=request_id, + choices=[choice], + created=created, + model=model, + system_fingerprint=None, + object=ObjectType.chat_completion_chunk, + ) + + def _get_first_streaming_chat_response( + self, request_id: str, created: int, model: str, role: str + ) -> CreateChatCompletionStreamResponse: + # First chunk has no content and sets the role + choice = ChatCompletionStreamingResponseChoice( + index=0, + delta=ChatCompletionStreamResponseDelta( + role=role, content="", function_call=None + ), + logprobs=None, + finish_reason=None, + ) + chunk = self._get_streaming_chat_response_chunk( + choice, request_id, created, model + ) + return chunk + + def _get_nth_streaming_chat_response( + self, + request_id: str, + created: int, + model: str, + response: tritonserver.InferenceResponse, + ) -> CreateChatCompletionStreamResponse: + text = _get_output(response) + choice = ChatCompletionStreamingResponseChoice( + index=0, + delta=ChatCompletionStreamResponseDelta( + role=None, content=text, function_call=None + ), + logprobs=None, + finish_reason=ChatCompletionFinishReason.stop if response.final else None, + ) + + chunk = self._get_streaming_chat_response_chunk( + choice, request_id, created, model + ) + return chunk + + async def _streaming_chat_iterator( + self, + request_id: str, + created: int, + model: str, + role: str, + responses: AsyncIterable, + ) -> AsyncIterator[str]: + chunk = self._get_first_streaming_chat_response( + request_id, created, model, role + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + + async for response in responses: + chunk = self._get_nth_streaming_chat_response( + request_id, created, model, response + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + + yield "data: [DONE]\n\n" + + def _validate_chat_request( + self, request: CreateChatCompletionRequest, metadata: 
TritonModelMetadata + ): + """ + Validates a chat request to align with currently supported features. + """ + + # Reject missing internal information needed to do inference + if not metadata: + raise Exception(f"Unknown model: {request.model}") + + if not metadata.tokenizer: + raise Exception("Unknown tokenizer") + + if not metadata.backend: + raise Exception("Unknown backend") + + if not metadata.request_converter: + raise Exception(f"Unknown request format for model: {request.model}") + + # Reject unsupported features if requested + if request.n and request.n > 1: + raise Exception( + f"Received n={request.n}, but only single choice (n=1) is currently supported" + ) + + if request.logit_bias is not None or request.logprobs: + raise Exception("logit bias and log probs not currently supported") + + async def _streaming_completion_iterator( + self, request_id: str, created: int, model: str, responses: AsyncIterable + ) -> AsyncIterator[str]: + async for response in responses: + text = _get_output(response) + choice = Choice( + finish_reason=FinishReason.stop if response.final else None, + index=0, + logprobs=None, + text=text, + ) + chunk = CreateCompletionResponse( + id=request_id, + choices=[choice], + system_fingerprint=None, + object=ObjectType.text_completion, + created=created, + model=model, + ) + + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + + yield "data: [DONE]\n\n" + + def _validate_completion_request( + self, request: CreateCompletionRequest, metadata: TritonModelMetadata + ): + """ + Validates a completions request to align with currently supported features. + """ + # Reject missing internal information needed to do inference + if not metadata: + raise Exception(f"Unknown model: {request.model}") + + if not metadata.backend: + raise Exception("Unknown backend") + + if not metadata.request_converter: + raise Exception(f"Unknown request format for model: {request.model}") + + # Reject unsupported features if requested + if request.suffix is not None: + raise Exception("suffix is not currently supported") + + if not request.prompt: + raise Exception("prompt must be non-empty") + + # Currently only support single string as input + if not isinstance(request.prompt, str): + raise Exception("only single string input is supported") + + if request.n and request.n > 1: + raise Exception( + f"Received n={request.n}, but only single choice (n=1) is currently supported" + ) + + if request.best_of and request.best_of > 1: + raise Exception( + f"Received best_of={request.best_of}, but only single choice (best_of=1) is currently supported" + ) + + if request.logit_bias is not None or request.logprobs is not None: + raise Exception("logit bias and log probs not supported") diff --git a/python/openai/openai_frontend/engine/utils/__init__.py b/python/openai/openai_frontend/engine/utils/__init__.py new file mode 100644 index 0000000000..dc1c939c66 --- /dev/null +++ b/python/openai/openai_frontend/engine/utils/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/openai_frontend/engine/utils/tokenizer.py b/python/openai/openai_frontend/engine/utils/tokenizer.py new file mode 100644 index 0000000000..982e553cea --- /dev/null +++ b/python/openai/openai_frontend/engine/utils/tokenizer.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer.py +# Copyright 2024 The vLLM team. + +from typing import Optional, Union + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + + +def get_cached_tokenizer( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. 
+ + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_len = len(tokenizer) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + +def get_tokenizer( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tokenizer_revision: Optional[str] = None, + download_dir: Optional[str] = None, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface/modelscope.""" + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + **kwargs, + ) + except ValueError as e: + raise e + except AttributeError as e: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + print( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead." + ) + return get_cached_tokenizer(tokenizer) diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py new file mode 100644 index 0000000000..2ec8cce7d5 --- /dev/null +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -0,0 +1,175 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import ctypes +from typing import Iterable, List + +import numpy as np +import tritonserver +from schemas.openai import CreateChatCompletionRequest, CreateCompletionRequest + + +def _create_vllm_inference_request( + model, prompt, request: CreateChatCompletionRequest | CreateCompletionRequest +): + inputs = {} + # Exclude non-sampling parameters so they aren't passed to vLLM + excludes = { + "model", + "stream", + "messages", + "prompt", + "echo", + "store", + "metadata", + "response_format", + "service_tier", + "stream_options", + "tools", + "tool_choice", + "parallel_tool_calls", + "user", + "function_call", + "functions", + "suffix", + } + + # NOTE: The exclude_none is important, as internals may not support + # values of NoneType at this time. + sampling_parameters = request.model_dump_json( + exclude=excludes, + exclude_none=True, + ) + + exclude_input_in_output = True + echo = getattr(request, "echo", None) + if echo is not None: + exclude_input_in_output = not echo + + inputs["text_input"] = [prompt] + inputs["stream"] = np.bool_([request.stream]) + inputs["exclude_input_in_output"] = np.bool_([exclude_input_in_output]) + # Pass sampling_parameters as serialized JSON string input to support List + # fields like 'stop' that aren't supported by TRITONSERVER_Parameters yet. + inputs["sampling_parameters"] = [sampling_parameters] + return model.create_request(inputs=inputs) + + +def _create_trtllm_inference_request( + model, prompt, request: CreateChatCompletionRequest | CreateCompletionRequest +): + inputs = {} + inputs["text_input"] = [[prompt]] + inputs["stream"] = np.bool_([[request.stream]]) + if request.max_tokens: + inputs["max_tokens"] = np.int32([[request.max_tokens]]) + if request.stop: + if isinstance(request.stop, str): + request.stop = [request.stop] + inputs["stop_words"] = [request.stop] + # Check "is not None" specifically, because values of zero are valid. + if request.top_p is not None: + inputs["top_p"] = np.float32([[request.top_p]]) + if request.frequency_penalty is not None: + inputs["frequency_penalty"] = np.float32([[request.frequency_penalty]]) + if request.presence_penalty is not None: + inputs["presence_penalty"] = np.float32([[request.presence_penalty]]) + if request.seed is not None: + inputs["random_seed"] = np.uint64([[request.seed]]) + if request.temperature is not None: + inputs["temperature"] = np.float32([[request.temperature]]) + # FIXME: TRT-LLM doesn't currently support runtime changes of 'echo' and it + # is configured at model load time, so we don't handle it here for now. 
+ return model.create_request(inputs=inputs) + + +def _construct_string_from_pointer(pointer: int, size: int) -> str: + """Constructs a Python string from a C pointer and size.""" + + # Create a ctypes string buffer + string_buffer = ctypes.create_string_buffer(size + 1) # +1 for null terminator + + # Copy the data from the pointer to the buffer + ctypes.memmove(string_buffer, pointer, size) + + # Convert the buffer to a Python string + return string_buffer.value.decode("utf-8") # Adjust encoding if needed + + +def _get_volume(shape: Iterable[int]) -> int: + volume = 1 + for dim in shape: + volume *= dim + + return volume + + +def _to_string(tensor: tritonserver.Tensor) -> str: + # FIXME: This could be a bit more robust by reading byte size from first + # 4 bytes and then just reading the first string, rather than assuming + # single string, assuming it's of similar performance to do so. + + # The following optimization to read string directly from buffer assumes + # there is only a single string, so enforce it to avoid obscure errors. + volume = _get_volume(tensor.shape) + if volume != 1: + raise Exception( + f"Expected to find 1 string in the output, found {volume} instead." + ) + if tensor.size < 4: + raise Exception( + f"Expected string buffer to contain its serialized byte size, but found size of {tensor.size}." + ) + + # NOTE: +/- 4 accounts for serialized byte string length in first 4 bytes of buffer + return _construct_string_from_pointer(tensor.data_ptr + 4, tensor.size - 4) + + +# TODO: Use tritonserver.InferenceResponse when support is published +def _get_output(response: tritonserver._api._response.InferenceResponse) -> str: + if "text_output" in response.outputs: + tensor = response.outputs["text_output"] + + # Alternative method, creates the same string, but goes through + # deserialization, numpy, and dlpack overhead: + # return tensor.to_bytes_array()[0].decode("utf-8") + + # Optimized method + return _to_string(tensor) + + return "" + + +def _validate_triton_responses_non_streaming( + responses: List[tritonserver._api._response.InferenceResponse], +): + num_responses = len(responses) + if num_responses == 1 and responses[0].final != True: + raise Exception("Unexpected internal error with incorrect response flags") + if num_responses == 2 and responses[-1].final != True: + raise Exception("Unexpected internal error with incorrect response flags") + if num_responses > 2: + raise Exception(f"Unexpected number of responses: {num_responses}, expected 1.") diff --git a/python/openai/openai_frontend/frontend/__init__.py b/python/openai/openai_frontend/frontend/__init__.py new file mode 100644 index 0000000000..77855ae979 --- /dev/null +++ b/python/openai/openai_frontend/frontend/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/openai_frontend/frontend/fastapi/__init__py b/python/openai/openai_frontend/frontend/fastapi/__init__py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/__init__.py b/python/openai/openai_frontend/frontend/fastapi/routers/__init__.py new file mode 100644 index 0000000000..dc1c939c66 --- /dev/null +++ b/python/openai/openai_frontend/frontend/fastapi/routers/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/chat.py b/python/openai/openai_frontend/frontend/fastapi/routers/chat.py new file mode 100644 index 0000000000..0f72047a5e --- /dev/null +++ b/python/openai/openai_frontend/frontend/fastapi/routers/chat.py @@ -0,0 +1,53 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from schemas.openai import CreateChatCompletionRequest, CreateChatCompletionResponse + +router = APIRouter() + + +@router.post( + "/v1/chat/completions", response_model=CreateChatCompletionResponse, tags=["Chat"] +) +async def create_chat_completion( + request: CreateChatCompletionRequest, + raw_request: Request, +) -> CreateChatCompletionResponse | StreamingResponse: + """ + Creates a chat completion for the provided messages and parameters. + """ + if not raw_request.app.engine: + raise HTTPException(status_code=500, detail="No attached inference engine") + + try: + response = await raw_request.app.engine.chat(request) + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + return response + except Exception as e: + raise HTTPException(status_code=400, detail=f"{e}") diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/completions.py b/python/openai/openai_frontend/frontend/fastapi/routers/completions.py new file mode 100644 index 0000000000..ade89a47cc --- /dev/null +++ b/python/openai/openai_frontend/frontend/fastapi/routers/completions.py @@ -0,0 +1,52 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from schemas.openai import CreateCompletionRequest, CreateCompletionResponse + +router = APIRouter() + + +@router.post( + "/v1/completions", response_model=CreateCompletionResponse, tags=["Completions"] +) +async def create_completion( + request: CreateCompletionRequest, raw_request: Request +) -> CreateCompletionResponse | StreamingResponse: + """ + Creates a completion for the provided prompt and parameters. + """ + if not raw_request.app.engine: + raise HTTPException(status_code=500, detail="No attached inference engine") + + try: + response = await raw_request.app.engine.completion(request) + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + return response + except Exception as e: + raise HTTPException(status_code=400, detail=f"{e}") diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/models.py b/python/openai/openai_frontend/frontend/fastapi/routers/models.py new file mode 100644 index 0000000000..ac2fa7fdc0 --- /dev/null +++ b/python/openai/openai_frontend/frontend/fastapi/routers/models.py @@ -0,0 +1,63 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from typing import List + +from fastapi import APIRouter, HTTPException, Request +from schemas.openai import ListModelsResponse, Model, ObjectType + +router = APIRouter() + +OWNED_BY = "Triton Inference Server" + + +@router.get("/v1/models", response_model=ListModelsResponse, tags=["Models"]) +def list_models(request: Request) -> ListModelsResponse: + """ + Lists the currently available models, and provides basic information about each one such as the owner and availability. + """ + if not request.app.engine: + raise HTTPException(status_code=500, detail="No attached inference engine") + + models: List[Model] = request.app.engine.models() + return ListModelsResponse(object=ObjectType.list, data=models) + + +@router.get("/v1/models/{model_name}", response_model=Model, tags=["Models"]) +def retrieve_model(request: Request, model_name: str) -> Model: + """ + Retrieves a model instance, providing basic information about the model such as the owner and permissioning. + """ + if not request.app.engine: + raise HTTPException(status_code=500, detail="No attached inference engine") + + # TODO: Return model directly from engine instead of searching models + models: List[Model] = request.app.engine.models() + for model in models: + if model.id == model_name: + return model + + raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}") diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/observability.py b/python/openai/openai_frontend/frontend/fastapi/routers/observability.py new file mode 100644 index 0000000000..b8040f56b7 --- /dev/null +++ b/python/openai/openai_frontend/frontend/fastapi/routers/observability.py @@ -0,0 +1,49 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import PlainTextResponse, Response + +router = APIRouter() + + +@router.get("/metrics", response_class=PlainTextResponse, tags=["Utilities"]) +def metrics(request: Request) -> PlainTextResponse: + return request.app.engine.metrics() + + +@router.get("/health/ready", tags=["Utilities"]) +def ready(request: Request) -> Response: + if not request.app.engine: + raise HTTPException(status_code=500, detail="No attached inference engine") + + if not request.app.engine.ready(): + raise HTTPException( + status_code=400, + detail="Attached inference engine is not ready for inference requests.", + ) + + return Response(status_code=200) diff --git a/python/openai/openai_frontend/frontend/fastapi_frontend.py b/python/openai/openai_frontend/frontend/fastapi_frontend.py new file mode 100644 index 0000000000..adee4cbab3 --- /dev/null +++ b/python/openai/openai_frontend/frontend/fastapi_frontend.py @@ -0,0 +1,109 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import uvicorn +from engine.triton_engine import TritonLLMEngine +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from frontend.fastapi.routers import chat, completions, models, observability +from frontend.frontend import OpenAIFrontend + + +class FastApiFrontend(OpenAIFrontend): + def __init__( + self, + engine: TritonLLMEngine, + host: str = "localhost", + port: int = 8000, + log_level: str = "info", + ): + self.host: str = host + self.port: int = port + self.log_level: str = log_level + self.stopped: bool = False + + self.app = self._create_app() + # Attach the inference engine to the FastAPI app + self.app.engine = engine + + def __del__(self): + self.stop() + + def start(self): + config = uvicorn.Config( + app=self.app, + host=self.host, + port=self.port, + log_level=self.log_level, + timeout_keep_alive=5, + ) + server = uvicorn.Server(config) + server.run() + + def stop(self): + # NOTE: If the frontend owned the engine, it could do cleanup here. + pass + + def _create_app(self): + app = FastAPI( + title="OpenAI API", + description="The OpenAI REST API. Please see https://platform.openai.com/docs/api-reference for more details.", + version="2.0.0", + termsOfService="https://openai.com/policies/terms-of-use", + contact={"name": "OpenAI Support", "url": "https://help.openai.com/"}, + license={ + "name": "MIT", + "url": "https://github.com/openai/openai-openapi/blob/master/LICENSE", + }, + ) + + app.include_router(observability.router) + app.include_router(models.router) + app.include_router(completions.router) + app.include_router(chat.router) + + # NOTE: For debugging purposes, should generally be restricted or removed + self._add_cors_middleware(app) + + return app + + def _add_cors_middleware(self, app: FastAPI): + # Allow API calls through browser /docs route for debug purposes + origins = [ + "http://localhost", + ] + + # TODO: Move towards logger instead of printing + print(f"[WARNING] Adding CORS for the following origins: {origins}") + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) diff --git a/python/openai/openai_frontend/frontend/frontend.py b/python/openai/openai_frontend/frontend/frontend.py new file mode 100644 index 0000000000..2e311a9aac --- /dev/null +++ b/python/openai/openai_frontend/frontend/frontend.py @@ -0,0 +1,43 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
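As a rough sketch of how `FastApiFrontend` above is meant to be embedded, mirroring the wiring that `main.py` performs later in this diff (the model repository path is a placeholder):

```python
# Sketch only: mirrors openai_frontend/main.py. Assumes it is run from the
# openai_frontend/ directory so the package-relative imports resolve; the
# model repository path is a placeholder for your own.
import tritonserver
from engine.triton_engine import TritonLLMEngine
from frontend.fastapi_frontend import FastApiFrontend

server = tritonserver.Server(
    model_repository="/path/to/model_repository",
    log_info=True,
).start(wait_until_ready=True)

# tokenizer=None here; pass a HuggingFace tokenizer id to enable chat templates.
engine = TritonLLMEngine(server=server, tokenizer=None, backend=None)

frontend = FastApiFrontend(engine=engine, host="0.0.0.0", port=9000)
frontend.start()  # blocking call until the uvicorn server exits
```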
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Protocol + + +class OpenAIFrontend(Protocol): + def start(self) -> None: + """ + Starts the OpenAI-compatible service. + """ + pass + + def stop(self) -> None: + """ + Stops the OpenAI-compatible service. + """ + pass diff --git a/python/openai/openai_frontend/main.py b/python/openai/openai_frontend/main.py new file mode 100755 index 0000000000..29fc684354 --- /dev/null +++ b/python/openai/openai_frontend/main.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 + +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
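`OpenAIFrontend` above is a `typing.Protocol`, so conformance is structural rather than inheritance based: any object exposing `start()` and `stop()` satisfies it for a type checker. A minimal sketch (the `NoopFrontend` name is purely illustrative):

```python
# Sketch only: "NoopFrontend" is an illustrative stand-in, not part of this PR.
from frontend.frontend import OpenAIFrontend


class NoopFrontend:
    """Satisfies OpenAIFrontend structurally: it defines start() and stop()."""

    def start(self) -> None:
        print("starting")

    def stop(self) -> None:
        print("stopping")


def run(frontend: OpenAIFrontend) -> None:
    # A static type checker accepts NoopFrontend here because Protocol
    # conformance is structural; no subclassing of OpenAIFrontend is required.
    frontend.start()
    frontend.stop()


run(NoopFrontend())
```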
+ +import argparse +import signal +from functools import partial + +import tritonserver +from engine.triton_engine import TritonLLMEngine +from frontend.fastapi_frontend import FastApiFrontend + + +def signal_handler( + server, openai_frontend, kserve_http_frontend, kserve_grpc_frontend, signal, frame +): + print(f"Received {signal=}, {frame=}") + # Graceful Shutdown + shutdown(server, openai_frontend, kserve_http_frontend, kserve_grpc_frontend) + + +def shutdown(server, openai_frontend, kserve_http, kserve_grpc): + print("Shutting down Triton OpenAI-Compatible Frontend...") + openai_frontend.stop() + + if kserve_http: + print("Shutting down Triton KServe HTTP Frontend...") + kserve_http.stop() + + if kserve_grpc: + print("Shutting down Triton KServe GRPC Frontend...") + kserve_grpc.stop() + + print("Shutting down Triton Inference Server...") + server.stop() + + +def start_kserve_frontends(server, args): + http_service, grpc_service = None, None + try: + from tritonfrontend import KServeGrpc, KServeHttp + + http_options = KServeHttp.Options(address=args.host, port=args.kserve_http_port) + http_service = KServeHttp.Server(server, http_options) + http_service.start() + + grpc_options = KServeGrpc.Options(address=args.host, port=args.kserve_grpc_port) + grpc_service = KServeGrpc.Server(server, grpc_options) + grpc_service.start() + + except ModuleNotFoundError: + # FIXME: Raise error instead of warning if kserve frontends are opt-in + print( + "[WARNING] The 'tritonfrontend' package was not found. " + "KServe frontends won't be available through this application without it. " + "Check /opt/tritonserver/python for tritonfrontend*.whl and pip install it if present." + ) + return http_service, grpc_service + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Triton Inference Server with OpenAI-Compatible RESTful API server." 
+ ) + + # Triton Inference Server + triton_group = parser.add_argument_group("Triton Inference Server") + triton_group.add_argument( + "--model-repository", + type=str, + required=True, + help="Path to the Triton model repository holding the models to be served", + ) + triton_group.add_argument( + "--tokenizer", + type=str, + default=None, + help="HuggingFace ID or local folder path of the Tokenizer to use for chat templates", + ) + triton_group.add_argument( + "--backend", + type=str, + default=None, + choices=["vllm", "tensorrtllm"], + help="Manual override of Triton backend request format (inputs/output names) to use for inference", + ) + triton_group.add_argument( + "--tritonserver-log-verbose-level", + type=int, + default=0, + help="The tritonserver log verbosity level", + ) + triton_group.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Address/host of frontends (default: '0.0.0.0')", + ) + + # OpenAI-Compatible Frontend (FastAPI) + openai_group = parser.add_argument_group("Triton OpenAI-Compatible Frontend") + openai_group.add_argument( + "--openai-port", type=int, default=9000, help="OpenAI HTTP port (default: 9000)" + ) + openai_group.add_argument( + "--uvicorn-log-level", + type=str, + default="info", + choices=["debug", "info", "warning", "error", "critical", "trace"], + help="log level for uvicorn", + ) + + # KServe Predict v2 Frontend + kserve_group = parser.add_argument_group("Triton KServe Frontend") + kserve_group.add_argument( + "--enable-kserve-frontends", + action="store_true", + help="Enable KServe Predict v2 HTTP/GRPC frontends (disabled by default)", + ) + kserve_group.add_argument( + "--kserve-http-port", + type=int, + default=8000, + help="KServe Predict v2 HTTP port (default: 8000)", + ) + kserve_group.add_argument( + "--kserve-grpc-port", + type=int, + default=8001, + help="KServe Predict v2 GRPC port (default: 8001)", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Initialize a Triton Inference Server pointing at LLM models + server: tritonserver.Server = tritonserver.Server( + model_repository=args.model_repository, + log_verbose=args.tritonserver_log_verbose_level, + log_info=True, + log_warn=True, + log_error=True, + ).start(wait_until_ready=True) + + # Wrap Triton Inference Server in an interface-conforming "LLMEngine" + engine: TritonLLMEngine = TritonLLMEngine( + server=server, tokenizer=args.tokenizer, backend=args.backend + ) + + # Attach TritonLLMEngine as the backbone for inference and model management + openai_frontend: FastApiFrontend = FastApiFrontend( + engine=engine, + host=args.host, + port=args.openai_port, + log_level=args.uvicorn_log_level, + ) + + # Optionally expose Triton KServe HTTP/GRPC Frontends + kserve_http, kserve_grpc = None, None + if args.enable_kserve_frontends: + kserve_http, kserve_grpc = start_kserve_frontends(server, args) + + # Gracefully shutdown when receiving signals for testing and interactive use + signal.signal( + signal.SIGINT, + partial(signal_handler, server, openai_frontend, kserve_http, kserve_grpc), + ) + signal.signal( + signal.SIGTERM, + partial(signal_handler, server, openai_frontend, kserve_http, kserve_grpc), + ) + + # Blocking call until killed or interrupted with SIGINT + openai_frontend.start() + + +if __name__ == "__main__": + main() diff --git a/python/openai/openai_frontend/schemas/__init__.py b/python/openai/openai_frontend/schemas/__init__.py new file mode 100644 index 0000000000..dc1c939c66 --- /dev/null +++ 
b/python/openai/openai_frontend/schemas/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/openai_frontend/schemas/openai.py b/python/openai/openai_frontend/schemas/openai.py new file mode 100644 index 0000000000..9a562729a7 --- /dev/null +++ b/python/openai/openai_frontend/schemas/openai.py @@ -0,0 +1,918 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
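The generated request models that follow (for example `CreateCompletionRequest`) set `ConfigDict(extra="forbid")`, so unknown request fields produce a validation error instead of being silently dropped. A small sketch with made-up payloads:

```python
# Sketch only: demonstrates the extra="forbid" behaviour of the generated
# request models below; the payloads here are made up for illustration.
from pydantic import ValidationError
from schemas.openai import CreateCompletionRequest

ok = CreateCompletionRequest(model="llama-3.1-8b-instruct", prompt="Hello")
print(ok.max_tokens)  # defaults to 16 per the schema definition

try:
    CreateCompletionRequest(
        model="llama-3.1-8b-instruct",
        prompt="Hello",
        not_a_real_field=True,  # unknown field -> ValidationError
    )
except ValidationError as e:
    print(e)
```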
+ +# generated by fastapi-codegen: +# filename: api-spec/openai_trimmed.yml +# timestamp: 2024-05-05T21:52:36+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel, confloat, conint + + +class Error(BaseModel): + code: str + message: str + param: str + type: str + + +class ErrorResponse(BaseModel): + error: Error + + +class Object(Enum): + list = "list" + + +class DeleteModelResponse(BaseModel): + id: str + deleted: bool + object: str + + +class Model1(Enum): + gpt_3_5_turbo_instruct = "gpt-3.5-turbo-instruct" + davinci_002 = "davinci-002" + babbage_002 = "babbage-002" + + +class PromptItem(RootModel): + root: List[Any] + + +class CreateCompletionRequest(BaseModel): + # Explicitly return errors for unknown fields. + model_config: ConfigDict = ConfigDict(extra="forbid") + + model: Union[str, Model1] = Field( + ..., + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ) + prompt: Union[str, List[str], List[int], List[PromptItem]] = Field( + ..., + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + ) + best_of: Optional[conint(ge=0, le=20)] = Field( + 1, + description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to the completion\n" + ) + frequency_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ) + # TODO: Extension, flesh out description and defaults + ignore_eos: Optional[bool] = Field( + False, description="Ignore end-of-sequence tokens during generation\n" + ) + logit_bias: Optional[Dict[str, int]] = Field( + None, + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. 
The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n', + ) + logprobs: Optional[conint(ge=0, le=5)] = Field( + None, + description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", + ) + max_tokens: Optional[conint(ge=0)] = Field( + 16, + description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=[16], + ) + # TODO: Extension, flesh out description and defaults + min_tokens: Optional[conint(ge=0)] = Field( + None, + description="The minimum number of [tokens](/tokenizer) that should be generated in the completion.\n", + ) + n: Optional[conint(ge=1, le=128)] = Field( + 1, + description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", + examples=[1], + ) + presence_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ) + seed: Optional[conint(ge=-9223372036854775808, le=9223372036854775807)] = Field( + None, + description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ) + stop: Optional[Union[str, List[str]]] = Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + ) + stream: Optional[bool] = Field( + False, + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + suffix: Optional[str] = Field( + None, + description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", + examples=["test."], + ) + temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( + 1, + description="What sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ) + top_p: Optional[confloat(ge=0.0, le=1.0)] = Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ) + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ) + + +class FinishReason(Enum): + stop = "stop" + length = "length" + content_filter = "content_filter" + + +class Logprobs(BaseModel): + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +class Choice(BaseModel): + finish_reason: FinishReason | None = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n", + ) + index: int + logprobs: Logprobs | None + text: str + + +class Object1(Enum): + text_completion = "text_completion" + + +class Type(Enum): + image_url = "image_url" + + +class Detail(Enum): + auto = "auto" + low = "low" + high = "high" + + +class ImageUrl(BaseModel): + url: AnyUrl = Field( + ..., description="Either a URL of the image or the base64 encoded image data." + ) + detail: Optional[Detail] = Field( + "auto", + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", + ) + + +class ChatCompletionRequestMessageContentPartImage(BaseModel): + type: Type = Field(..., description="The type of the content part.") + image_url: ImageUrl + + +class Type1(Enum): + text = "text" + + +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Type1 = Field(..., description="The type of the content part.") + text: str = Field(..., description="The text content.") + + +class Role(Enum): + system = "system" + + def __str__(self): + return self.name + + +class ChatCompletionRequestSystemMessage(BaseModel): + content: str = Field(..., description="The contents of the system message.") + role: Role = Field( + ..., description="The role of the messages author, in this case `system`." + ) + name: Optional[str] = Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ) + + +class Role1(Enum): + user = "user" + + def __str__(self): + return self.name + + +class Role2(Enum): + assistant = "assistant" + + def __str__(self): + return self.name + + +class FunctionCall(BaseModel): + arguments: str = Field( + ..., + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. 
Validate the arguments in your code before calling your function.", + ) + name: str = Field(..., description="The name of the function to call.") + + +class Role3(Enum): + tool = "tool" + + def __str__(self): + return self.name + + +class ChatCompletionRequestToolMessage(BaseModel): + role: Role3 = Field( + ..., description="The role of the messages author, in this case `tool`." + ) + content: str = Field(..., description="The contents of the tool message.") + tool_call_id: str = Field( + ..., description="Tool call that this message is responding to." + ) + + +class Role4(Enum): + function = "function" + + def __str__(self): + return self.name + + +class ChatCompletionRequestFunctionMessage(BaseModel): + role: Role4 = Field( + ..., description="The role of the messages author, in this case `function`." + ) + content: str = Field(..., description="The contents of the function message.") + name: str = Field(..., description="The name of the function to call.") + + +class FunctionParameters(BaseModel): + model_config = ConfigDict(extra="allow") + + +class ChatCompletionFunctions(BaseModel): + description: Optional[str] = Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ) + name: str = Field( + ..., + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.", + ) + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionFunctionCallOption(BaseModel): + name: str = Field(..., description="The name of the function to call.") + + +class Type2(Enum): + function = "function" + + +class FunctionObject(BaseModel): + description: Optional[str] = Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ) + name: str = Field( + ..., + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.", + ) + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionToolChoiceOption1(Enum): + none = "none" + auto = "auto" + required = "required" + + +class Function(BaseModel): + name: str = Field(..., description="The name of the function to call.") + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Type2 = Field( + ..., + description="The type of the tool. Currently, only `function` is supported.", + ) + function: Function + + +class Function1(BaseModel): + name: str = Field(..., description="The name of the function to call.") + arguments: str = Field( + ..., + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + + +class ChatCompletionMessageToolCall(BaseModel): + id: str = Field(..., description="The ID of the tool call.") + type: Type2 = Field( + ..., + description="The type of the tool. Currently, only `function` is supported.", + ) + function: Function1 = Field(..., description="The function that the model called.") + + +class Function2(BaseModel): + name: Optional[str] = Field(None, description="The name of the function to call.") + arguments: Optional[str] = Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. 
Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + + +class ChatCompletionMessageToolCallChunk(BaseModel): + index: int + id: Optional[str] = Field(None, description="The ID of the tool call.") + type: Optional[Type2] = Field( + None, + description="The type of the tool. Currently, only `function` is supported.", + ) + function: Optional[Function2] = None + + +class ChatCompletionRole(Enum): + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + function = "function" + + +class Role5(Enum): + assistant = "assistant" + + def __str__(self): + return self.name + + +class FunctionCall2(BaseModel): + arguments: Optional[str] = Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + name: Optional[str] = Field(None, description="The name of the function to call.") + + +class Role6(Enum): + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + + def __str__(self): + return self.name + + +class ChatCompletionStreamResponseDelta(BaseModel): + content: Optional[str] = Field( + None, description="The contents of the chunk message." + ) + function_call: Optional[FunctionCall2] = Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ) + tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None + role: Optional[str] = Field( + None, description="The role of the author of this message." + ) + + +class Model2(Enum): + gpt_4_turbo = "gpt-4-turbo" + gpt_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" + gpt_4_0125_preview = "gpt-4-0125-preview" + gpt_4_turbo_preview = "gpt-4-turbo-preview" + gpt_4_1106_preview = "gpt-4-1106-preview" + gpt_4_vision_preview = "gpt-4-vision-preview" + gpt_4 = "gpt-4" + gpt_4_0314 = "gpt-4-0314" + gpt_4_0613 = "gpt-4-0613" + gpt_4_32k = "gpt-4-32k" + gpt_4_32k_0314 = "gpt-4-32k-0314" + gpt_4_32k_0613 = "gpt-4-32k-0613" + gpt_3_5_turbo = "gpt-3.5-turbo" + gpt_3_5_turbo_16k = "gpt-3.5-turbo-16k" + gpt_3_5_turbo_0301 = "gpt-3.5-turbo-0301" + gpt_3_5_turbo_0613 = "gpt-3.5-turbo-0613" + gpt_3_5_turbo_1106 = "gpt-3.5-turbo-1106" + gpt_3_5_turbo_0125 = "gpt-3.5-turbo-0125" + gpt_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" + + +class Type6(Enum): + text = "text" + json_object = "json_object" + + +class ResponseFormat(BaseModel): + type: Optional[Type6] = Field( + "text", + description="Must be one of `text` or `json_object`.", + examples=["json_object"], + ) + + +class FunctionCall3(Enum): + none = "none" + auto = "auto" + + +class ChatCompletionFinishReason(Enum): + stop = "stop" + length = "length" + tool_calls = "tool_calls" + content_filter = "content_filter" + function_call = "function_call" + + +class Object2(Enum): + chat_completion = "chat.completion" + + +class FinishReason2(Enum): + stop = "stop" + length = "length" + function_call = "function_call" + content_filter = "content_filter" + + +class TopLogprob(BaseModel): + token: str = Field(..., description="The token.") + logprob: float = Field( + ..., + description="The log probability of this token, if it is within the top 20 most likely tokens. 
Otherwise, the value `-9999.0` is used to signify that the token is very unlikely.", + ) + bytes: List[int] = Field( + ..., + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token.", + ) + + +class ChatCompletionTokenLogprob(BaseModel): + token: str = Field(..., description="The token.") + logprob: float = Field( + ..., + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely.", + ) + bytes: List[int] = Field( + ..., + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token.", + ) + top_logprobs: List[TopLogprob] = Field( + ..., + description="List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned.", + ) + + +class Logprobs2(BaseModel): + content: List[ChatCompletionTokenLogprob] = Field( + ..., + description="A list of message content tokens with log probability information.", + ) + + +class ChatCompletionFinishReason(Enum): + stop = "stop" + length = "length" + tool_calls = "tool_calls" + content_filter = "content_filter" + function_call = "function_call" + + +class ChatCompletionStreamingResponseChoice(BaseModel): + delta: ChatCompletionStreamResponseDelta + logprobs: Optional[Logprobs2] = Field( + None, description="Log probability information for the choice." + ) + finish_reason: ChatCompletionFinishReason | None = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + + +class Object4(Enum): + chat_completion_chunk = "chat.completion.chunk" + + +class CreateChatCompletionStreamResponse(BaseModel): + id: str = Field( + ..., + description="A unique identifier for the chat completion. Each chunk has the same ID.", + ) + choices: List[ChatCompletionStreamingResponseChoice] = Field( + ..., + description="A list of chat completion choices. Can be more than one if `n` is greater than 1.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the chat completion was created. 
Each chunk has the same timestamp.", + ) + model: str = Field(..., description="The model to generate the completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object4 = Field( + ..., description="The object type, which is always `chat.completion.chunk`." + ) + + +class CreateChatCompletionImageResponse(BaseModel): + pass + + +class Object5(Enum): + model = "model" + + +class Model(BaseModel): + id: str = Field( + ..., + description="The model identifier, which can be referenced in the API endpoints.", + ) + created: int = Field( + ..., description="The Unix timestamp (in seconds) when the model was created." + ) + object: Object5 = Field( + ..., description='The object type, which is always "model".' + ) + owned_by: str = Field(..., description="The organization that owns the model.") + + +class CompletionUsage(BaseModel): + completion_tokens: int = Field( + ..., description="Number of tokens in the generated completion." + ) + prompt_tokens: int = Field(..., description="Number of tokens in the prompt.") + total_tokens: int = Field( + ..., + description="Total number of tokens used in the request (prompt + completion).", + ) + + +class Event(Enum): + error = "error" + + +class ErrorEvent(BaseModel): + event: Event + data: Error + + +class Event1(Enum): + done = "done" + + +class Data(Enum): + field_DONE_ = "[DONE]" + + +class DoneEvent(BaseModel): + event: Event1 + data: Data + + +class ListModelsResponse(BaseModel): + object: Object + data: List[Model] + + +class CreateCompletionResponse(BaseModel): + id: str = Field(..., description="A unique identifier for the completion.") + choices: List[Choice] = Field( + ..., + description="The list of completion choices the model generated for the input prompt.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the completion was created.", + ) + model: str = Field(..., description="The model used for completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object1 = Field( + ..., description='The object type, which is always "text_completion"' + ) + usage: Optional[CompletionUsage] = None + + +class ChatCompletionRequestMessageContentPart(RootModel): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Union[str, List[ChatCompletionRequestMessageContentPart]] = Field( + ..., description="The contents of the user message.\n" + ) + role: Role1 = Field( + ..., description="The role of the messages author, in this case `user`." + ) + name: Optional[str] = Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ) + + +class ChatCompletionTool(BaseModel): + type: Type2 = Field( + ..., + description="The type of the tool. 
Currently, only `function` is supported.", + ) + function: FunctionObject + + +class ChatCompletionToolChoiceOption(RootModel): + root: Union[ChatCompletionToolChoiceOption1, ChatCompletionNamedToolChoice] = Field( + ..., + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n', + ) + + +class ChatCompletionMessageToolCalls(RootModel): + root: List[ChatCompletionMessageToolCall] = Field( + ..., + description="The tool calls generated by the model, such as function calls.", + ) + + +class ChatCompletionResponseMessage(BaseModel): + content: str = Field(..., description="The contents of the message.") + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + role: str = Field(..., description="The role of the author of this message.") + function_call: Optional[FunctionCall] = Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ) + + +class ChatCompletionChoice(BaseModel): + finish_reason: ChatCompletionFinishReason = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + message: ChatCompletionResponseMessage + logprobs: Logprobs2 | None = Field( + ..., description="Log probability information for the choice." + ) + + +class CreateChatCompletionResponse(BaseModel): + id: str = Field(..., description="A unique identifier for the chat completion.") + choices: List[ChatCompletionChoice] = Field( + ..., + description="A list of chat completion choices. Can be more than one if `n` is greater than 1.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + model: str = Field(..., description="The model used for the chat completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object2 = Field( + ..., description="The object type, which is always `chat.completion`." + ) + usage: Optional[CompletionUsage] = None + + +class Choice2(BaseModel): + finish_reason: FinishReason2 = Field( + ..., + description="The reason the model stopped generating tokens. 
This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + message: ChatCompletionResponseMessage + + +class CreateChatCompletionFunctionResponse(BaseModel): + id: str = Field(..., description="A unique identifier for the chat completion.") + choices: List[Choice2] = Field( + ..., + description="A list of chat completion choices. Can be more than one if `n` is greater than 1.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + model: str = Field(..., description="The model used for the chat completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object2 = Field( + ..., description="The object type, which is always `chat.completion`." + ) + usage: Optional[CompletionUsage] = None + + +class ChatCompletionRequestAssistantMessage(BaseModel): + content: Optional[str] = Field( + None, + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", + ) + role: Role2 = Field( + ..., description="The role of the messages author, in this case `assistant`." + ) + name: Optional[str] = Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ) + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + function_call: Optional[FunctionCall] = Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ) + + +class ChatCompletionRequestMessage(RootModel): + root: Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + + @property + def role(self): + return self.root.role + + @property + def content(self): + return self.root.content + + +class CreateChatCompletionRequest(BaseModel): + # Explicitly return errors for unknown fields. + model_config: ConfigDict = ConfigDict(extra="forbid") + + messages: List[ChatCompletionRequestMessage] = Field( + ..., + description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", + min_length=1, + ) + model: Union[str, Model2] = Field( + ..., + description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", + examples=["gpt-4-turbo"], + ) + frequency_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ) + # TODO: Extension, flesh out description and defaults + ignore_eos: Optional[bool] = Field( + False, description="Ignore end-of-sequence tokens during generation\n" + ) + logit_bias: Optional[Dict[str, int]] = Field( + None, + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n", + ) + logprobs: Optional[bool] = Field( + False, + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.", + ) + top_logprobs: Optional[conint(ge=0, le=20)] = Field( + None, + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", + ) + # TODO: Consider new max_completion_tokens field in the future: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_completion_tokens + max_tokens: Optional[conint(ge=0)] = Field( + 16, + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + ) + # TODO: Extension, flesh out description and defaults + min_tokens: Optional[conint(ge=0)] = Field( + None, + description="The minimum number of [tokens](/tokenizer) that should be generated in the chat completion.\n", + ) + n: Optional[conint(ge=1, le=128)] = Field( + 1, + description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", + examples=[1], + ) + presence_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ) + response_format: Optional[ResponseFormat] = Field( + None, + description='An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. 
Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + ) + seed: Optional[conint(ge=-9223372036854775808, le=9223372036854775807)] = Field( + None, + description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ) + stop: Optional[Union[str, List[str]]] = Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens.\n", + ) + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( + 0.7, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ) + top_p: Optional[confloat(ge=0.0, le=1.0)] = Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ) + tools: Optional[List[ChatCompletionTool]] = Field( + None, + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n", + ) + tool_choice: Optional[ChatCompletionToolChoiceOption] = None + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ) + function_call: Optional[ + Union[FunctionCall3, ChatCompletionFunctionCallOption] + ] = Field( + None, + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. 
`auto` is the default if functions are present.\n', + ) + functions: Optional[List[ChatCompletionFunctions]] = Field( + None, + description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", + max_length=128, + min_length=1, + ) + + +# Additional Aliases for Convenience + + +class ObjectType: + model = Object5.model + list = Object.list + text_completion = Object1.text_completion + chat_completion_chunk = Object4.chat_completion_chunk + chat_completion = Object2.chat_completion diff --git a/python/openai/requirements-test.txt b/python/openai/requirements-test.txt new file mode 100644 index 0000000000..08c098811b --- /dev/null +++ b/python/openai/requirements-test.txt @@ -0,0 +1,3 @@ +# Testing +pytest==8.1.1 +pytest-asyncio==0.23.8 diff --git a/python/openai/requirements.txt b/python/openai/requirements.txt new file mode 100644 index 0000000000..d87feaa6f2 --- /dev/null +++ b/python/openai/requirements.txt @@ -0,0 +1,3 @@ +# FastAPI Application +fastapi==0.111.1 +openai==1.40.6 diff --git a/python/openai/tests/__init__.py b/python/openai/tests/__init__.py new file mode 100644 index 0000000000..dc1c939c66 --- /dev/null +++ b/python/openai/tests/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/openai/tests/conftest.py b/python/openai/tests/conftest.py new file mode 100644 index 0000000000..9ea9a5634e --- /dev/null +++ b/python/openai/tests/conftest.py @@ -0,0 +1,147 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient +from tests.utils import OpenAIServer, setup_fastapi_app, setup_server + + +### TEST ENVIRONMENT SETUP ### +def infer_test_environment(): + # Infer the test environment for simplicity in local dev/testing. + try: + import vllm as _ + + backend = "vllm" + model = "llama-3.1-8b-instruct" + return backend, model + except ImportError: + print("No vllm installation found.") + + try: + import tensorrt_llm as _ + + backend = "tensorrtllm" + model = "tensorrt_llm_bls" + return backend, model + except ImportError: + print("No tensorrt_llm installation found.") + + raise Exception("Unknown test environment") + + +def infer_test_model_repository(backend): + model_repository = str(Path(__file__).parent / f"{backend}_models") + return model_repository + + +# TODO: Refactor away from global variables +TEST_MODEL = os.environ.get("TEST_MODEL") +TEST_BACKEND = os.environ.get("TEST_BACKEND") +TEST_MODEL_REPOSITORY = os.environ.get("TEST_MODEL_REPOSITORY") + +TEST_TOKENIZER = os.environ.get( + "TEST_TOKENIZER", "meta-llama/Meta-Llama-3.1-8B-Instruct" +) +TEST_PROMPT = "What is machine learning?" +TEST_MESSAGES = [{"role": "user", "content": TEST_PROMPT}] + +if not TEST_BACKEND or not TEST_MODEL: + TEST_BACKEND, TEST_MODEL = infer_test_environment() + +if not TEST_MODEL_REPOSITORY: + TEST_MODEL_REPOSITORY = infer_test_model_repository(TEST_BACKEND) + + +# NOTE: OpenAI client requires actual server running, and won't work +# with the FastAPI TestClient. Run the server at module scope to run +# only once for all the tests below. 
+@pytest.fixture(scope="module") +def server(): + args = [ + "--model-repository", + TEST_MODEL_REPOSITORY, + "--tokenizer", + TEST_TOKENIZER, + "--backend", + TEST_BACKEND, + ] + # TODO: Incorporate kserve frontend binding smoke tests to catch any + # breakage with default values or slight cli arg variations + extra_args = ["--enable-kserve-frontends"] + args += extra_args + + with OpenAIServer(args) as openai_server: + yield openai_server + + +# NOTE: The FastAPI TestClient acts like a server and triggers the FastAPI app +# lifespan startup/shutdown, but does not actually expose the network port to interact +# with arbitrary clients - you must use the TestClient returned to interact with +# the "server" when "starting the server" via TestClient. +@pytest.fixture(scope="class") +def fastapi_client_class_scope(): + server = setup_server(model_repository=TEST_MODEL_REPOSITORY) + app = setup_fastapi_app( + tokenizer=TEST_TOKENIZER, server=server, backend=TEST_BACKEND + ) + with TestClient(app) as test_client: + yield test_client + + server.stop() + + +@pytest.fixture(scope="module") +def model_repository(): + return TEST_MODEL_REPOSITORY + + +@pytest.fixture(scope="module") +def model(): + return TEST_MODEL + + +@pytest.fixture(scope="module") +def backend(): + return TEST_BACKEND + + +@pytest.fixture(scope="module") +def tokenizer_model(): + return TEST_TOKENIZER + + +@pytest.fixture(scope="module") +def prompt(): + return TEST_PROMPT + + +@pytest.fixture(scope="module") +def messages(): + return TEST_MESSAGES diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py new file mode 100644 index 0000000000..347f5939f1 --- /dev/null +++ b/python/openai/tests/test_chat_completions.py @@ -0,0 +1,584 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import copy +import subprocess +from pathlib import Path +from typing import List + +import pytest +import tritonserver +from fastapi.testclient import TestClient +from tests.utils import setup_fastapi_app, setup_server + + +class TestChatCompletions: + @pytest.fixture(scope="class") + def client(self, fastapi_client_class_scope): + yield fastapi_client_class_scope + + def test_chat_completions_defaults(self, client, model: str, messages: List[dict]): + response = client.post( + "/v1/chat/completions", + json={"model": model, "messages": messages}, + ) + + assert response.status_code == 200 + message = response.json()["choices"][0]["message"] + assert message["content"].strip() + assert message["role"] == "assistant" + # "usage" currently not supported + assert not response.json()["usage"] + + def test_chat_completions_system_prompt(self, client, model: str): + # NOTE: Currently just sanity check that there are no issues when a + # system role is provided. There is no test logic to measure the quality + # of the response yet. + messages = [ + {"role": "system", "content": "You are a Triton Inference Server expert."}, + {"role": "user", "content": "What is machine learning?"}, + ] + + response = client.post( + "/v1/chat/completions", json={"model": model, "messages": messages} + ) + + assert response.status_code == 200 + message = response.json()["choices"][0]["message"] + assert message["content"].strip() + assert message["role"] == "assistant" + + def test_chat_completions_system_prompt_only(self, client, model: str): + # No user prompt provided + messages = [ + {"role": "system", "content": "You are a Triton Inference Server expert."} + ] + + response = client.post( + "/v1/chat/completions", json={"model": model, "messages": messages} + ) + + assert response.status_code == 200 + message = response.json()["choices"][0]["message"] + assert message["content"].strip() + assert message["role"] == "assistant" + + @pytest.mark.parametrize( + "param_key, param_value", + [ + ("temperature", 0.7), + ("max_tokens", 10), + ("top_p", 0.9), + ("frequency_penalty", 0.5), + ("presence_penalty", 0.2), + ("n", 1), + # Single stop word as a string + ("stop", "."), + # List of stop words + ("stop", []), + ("stop", [".", ","]), + # logprobs is a boolean for chat completions + ("logprobs", True), + ("logit_bias", {"0": 0}), + # NOTE: Extensions to the spec + ("min_tokens", 16), + ("ignore_eos", True), + ], + ) + def test_chat_completions_sampling_parameters( + self, client, param_key, param_value, model: str, messages: List[dict] + ): + response = client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": messages, + param_key: param_value, + }, + ) + + # FIXME: Add support and remove this check + unsupported_parameters = ["logprobs", "logit_bias"] + if param_key in unsupported_parameters: + assert response.status_code == 400 + assert ( + response.json()["detail"] + == "logit bias and log probs not currently supported" + ) + return + + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] + assert response.json()["choices"][0]["message"]["role"] == "assistant" + + @pytest.mark.parametrize( + "param_key, param_value", + [ + ("temperature", 2.1), + ("temperature", -0.1), + ("max_tokens", -1), + ("top_p", 1.1), + ("frequency_penalty", 3), + ("frequency_penalty", -3), + ("presence_penalty", 2.1), + ("presence_penalty", -2.1), + # NOTE: Extensions to the spec + ("min_tokens", -1), + ("ignore_eos", 123), + ], + ) + def 
test_chat_completions_invalid_sampling_parameters( + self, client, param_key, param_value, model: str, messages: List[dict] + ): + response = client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": messages, + param_key: param_value, + }, + ) + print("Response:", response.json()) + + # Assert schema validation error + assert response.status_code == 422 + + # Simple tests to verify max_tokens roughly behaves as expected + def test_chat_completions_max_tokens( + self, client, model: str, messages: List[dict] + ): + responses = [] + payload = {"model": model, "messages": messages, "max_tokens": 1} + + # Send two requests with max_tokens = 1 to check their similarity + payload["max_tokens"] = 1 + responses.append( + client.post( + "/v1/chat/completions", + json=payload, + ) + ) + responses.append( + client.post( + "/v1/chat/completions", + json=payload, + ) + ) + # Send one requests with larger max_tokens to check its dis-similarity + payload["max_tokens"] = 100 + responses.append( + client.post( + "/v1/chat/completions", + json=payload, + ) + ) + + for response in responses: + print("Response:", response.json()) + assert response.status_code == 200 + + response1_text = ( + responses[0].json()["choices"][0]["message"]["content"].strip().split() + ) + response2_text = ( + responses[1].json()["choices"][0]["message"]["content"].strip().split() + ) + response3_text = ( + responses[2].json()["choices"][0]["message"]["content"].strip().split() + ) + # Simplification: One token shouldn't be more than one space-delimited word + assert len(response1_text) == len(response2_text) == 1 + assert len(response3_text) > len(response1_text) + + @pytest.mark.parametrize( + "temperature", + [0.0, 1.0], + ) + # Simple tests to verify temperature roughly behaves as expected + def test_chat_completions_temperature_vllm( + self, client, temperature, backend: str, model: str, messages: List[dict] + ): + if backend != "vllm": + pytest.skip(reason="Only used to test vLLM-specific temperature behavior") + + responses = [] + payload = { + "model": model, + "messages": messages, + "max_tokens": 256, + "temperature": temperature, + } + + responses.append( + client.post( + "/v1/chat/completions", + json=payload, + ) + ) + responses.append( + client.post( + "/v1/chat/completions", + json=payload, + ) + ) + + for response in responses: + print("Response:", response.json()) + assert response.status_code == 200 + + response1_text = ( + responses[0].json()["choices"][0]["message"]["content"].strip().split() + ) + response2_text = ( + responses[1].json()["choices"][0]["message"]["content"].strip().split() + ) + + # Temperature of 0.0 indicates greedy sampling, so check + # that two equivalent requests produce the same response. + if temperature == 0.0: + # NOTE: This check may be ambitious to get an exact match in all + # cases depending on how other parameter defaults are set, so + # it can probably be removed if it introduces flakiness. + assert response1_text == response2_text + # Temperature of 1.0 indicates maximum randomness, so check + # that two equivalent requests produce different responses. 
+ elif temperature == 1.0: + assert response1_text != response2_text + # Don't bother checking values other than the extremes + else: + raise ValueError(f"Unexpected {temperature=} for this test.") + + # Remove xfail when fix is released and this test returns xpass status + @pytest.mark.xfail( + reason="TRT-LLM BLS model will ignore temperature until a later release" + ) + # Simple tests to verify temperature roughly behaves as expected + def test_chat_completions_temperature_tensorrtllm( + self, client, backend: str, model: str, messages: List[dict] + ): + if backend != "tensorrtllm": + pytest.skip( + reason="Only used to test TRT-LLM-specific temperature behavior" + ) + + responses = [] + payload1 = { + "model": model, + "messages": messages, + # Increase token length to allow more room for variability + "max_tokens": 200, + "temperature": 0.0, + # TRT-LLM requires certain settings of `top_k` / `top_p` to + # respect changes in `temperature` + "top_p": 0.5, + } + + payload2 = copy.deepcopy(payload1) + payload2["temperature"] = 1.0 + + # First 2 responses should be the same in TRT-LLM with identical payload + responses.append( + client.post( + "/v1/chat/completions", + json=payload1, + ) + ) + responses.append( + client.post( + "/v1/chat/completions", + json=payload1, + ) + ) + # Third response should differ with different temperature in payload + responses.append( + client.post( + "/v1/chat/completions", + json=payload2, + ) + ) + + for response in responses: + print("Response:", response.json()) + assert response.status_code == 200 + + response1_text = ( + responses[0].json()["choices"][0]["message"]["content"].strip().split() + ) + response2_text = ( + responses[1].json()["choices"][0]["message"]["content"].strip().split() + ) + response3_text = ( + responses[2].json()["choices"][0]["message"]["content"].strip().split() + ) + + assert response1_text == response2_text + assert response1_text != response3_text + + # Simple tests to verify random seed roughly behaves as expected + def test_chat_completions_seed(self, client, model: str, messages: List[dict]): + responses = [] + payload1 = { + "model": model, + "messages": messages, + # Increase token length to allow more room for variability + "max_tokens": 200, + "seed": 1, + } + payload2 = copy.deepcopy(payload1) + payload2["seed"] = 2 + + # First 2 responses should be the same in both vLLM and TRT-LLM with identical seed + responses.append( + client.post( + "/v1/chat/completions", + json=payload1, + ) + ) + responses.append( + client.post( + "/v1/chat/completions", + json=payload1, + ) + ) + # Third response should differ with different seed in payload + responses.append( + client.post( + "/v1/chat/completions", + json=payload2, + ) + ) + + for response in responses: + print("Response:", response.json()) + assert response.status_code == 200 + + response1_text = ( + responses[0].json()["choices"][0]["message"]["content"].strip().split() + ) + response2_text = ( + responses[1].json()["choices"][0]["message"]["content"].strip().split() + ) + response3_text = ( + responses[2].json()["choices"][0]["message"]["content"].strip().split() + ) + + assert response1_text == response2_text + assert response1_text != response3_text + + def test_chat_completions_no_message( + self, client, model: str, messages: List[dict] + ): + # Message validation requires min_length of 1 + messages = [] + response = client.post( + "/v1/chat/completions", json={"model": model, "messages": messages} + ) + assert response.status_code == 422 + assert ( + 
response.json()["detail"][0]["msg"] + == "List should have at least 1 item after validation, not 0" + ) + + def test_chat_completions_empty_message( + self, client, model: str, messages: List[dict] + ): + # Message validation requires min_length of 1 + messages = [{}] + response = client.post( + "/v1/chat/completions", json={"model": model, "messages": messages} + ) + assert response.status_code == 422 + assert response.json()["detail"][0]["msg"] == "Field required" + + def test_chat_completions_multiple_choices( + self, client, model: str, messages: List[dict] + ): + response = client.post( + "/v1/chat/completions", + json={"model": model, "messages": messages, "n": 2}, + ) + + assert response.status_code == 400 + assert "only single choice" in response.json()["detail"] + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_chat_completions_streaming(self, client): + pass + + def test_chat_completions_no_streaming( + self, client, model: str, messages: List[dict] + ): + response = client.post( + "/v1/chat/completions", + json={"model": model, "messages": messages, "stream": False}, + ) + + assert response.status_code == 200 + message = response.json()["choices"][0]["message"] + assert message["content"].strip() + assert message["role"] == "assistant" + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_function_calling(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_lora(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_multi_lora(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_request_n_choices(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_request_logprobs(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_request_logit_bias(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_usage_response(self): + pass + + +# For tests that won't use the same pytest fixture for server startup across +# the whole class test suite. +class TestChatCompletionsTokenizers: + # Re-use a single Triton server for different frontend configurations + @pytest.fixture(scope="class") + def server(self, model_repository: str): + server = setup_server(model_repository) + yield server + server.stop() + + # A tokenizer must be known for /chat/completions endpoint in order to + # apply chat templates, and for simplicity in determination, users should + # define the tokenizer. So, explicitly raise an error if none is provided. + def test_chat_completions_no_tokenizer( + self, + server: tritonserver.Server, + backend: str, + model: str, + messages: List[dict], + ): + app = setup_fastapi_app(tokenizer="", server=server, backend=backend) + with TestClient(app) as client: + response = client.post( + "/v1/chat/completions", + json={"model": model, "messages": messages}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Unknown tokenizer" + + def test_chat_completions_custom_tokenizer( + self, + server: tritonserver.Server, + backend: str, + tokenizer_model: str, + model: str, + messages: List[dict], + ): + # Tokenizers can be provided by a local file path to a directory containing + # the relevant files such as tokenizer.json and tokenizer_config.json. 
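+        # For illustration only (this assumes a typical HF tokenizer layout and is
+        # not asserted by the test), such a directory can be loaded locally with:
+        #   from transformers import AutoTokenizer
+        #   AutoTokenizer.from_pretrained("path/to/custom_tokenizer")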
+ custom_tokenizer_path = str(Path(__file__).parent / "custom_tokenizer") + download_cmd = f"huggingface-cli download --local-dir {custom_tokenizer_path} {tokenizer_model} --include *.json" + print(f"Running download command: {download_cmd}") + subprocess.run(download_cmd.split(), check=True) + + # Compare the downloaded tokenizer response against remote HF equivalent + # to assert equivalent functionality in responses and chat template. + app_local = setup_fastapi_app( + tokenizer=custom_tokenizer_path, server=server, backend=backend + ) + app_hf = setup_fastapi_app( + tokenizer=tokenizer_model, server=server, backend=backend + ) + + responses = [] + with TestClient(app_local) as client_local, TestClient(app_hf) as client_hf: + payload = {"model": model, "messages": messages, "temperature": 0} + responses.append(client_local.post("/v1/chat/completions", json=payload)) + responses.append(client_hf.post("/v1/chat/completions", json=payload)) + + for response in responses: + assert response.status_code == 200 + message = response.json()["choices"][0]["message"] + assert message["content"].strip() + assert message["role"] == "assistant" + + def equal_dicts(d1, d2, ignore_keys): + d1_filtered = {k: v for k, v in d1.items() if k not in ignore_keys} + d2_filtered = {k: v for k, v in d2.items() if k not in ignore_keys} + return d1_filtered == d2_filtered + + ignore_keys = ["id", "created"] + assert equal_dicts( + responses[0].json(), responses[1].json(), ignore_keys=ignore_keys + ) + + def test_chat_completions_invalid_chat_tokenizer( + self, + server: tritonserver.Server, + backend: str, + model: str, + messages: List[dict], + ): + # NOTE: Use of apply_chat_template on a tokenizer that doesn't support it + # is a warning prior to transformers 4.44, and an error afterwards. + # NOTE: Can remove after both TRT-LLM and VLLM containers have this version. + import transformers + + print(f"{transformers.__version__=}") + if transformers.__version__ < "4.44.0": + pytest.xfail() + + # Pick a tokenizer with no chat template defined + invalid_chat_tokenizer = "gpt2" + app = setup_fastapi_app( + tokenizer=invalid_chat_tokenizer, server=server, backend=backend + ) + with TestClient(app) as client: + response = client.post( + "/v1/chat/completions", + json={"model": model, "messages": messages}, + ) + + assert response.status_code == 400 + # Error may vary based on transformers version + expected_errors = [ + "cannot use apply_chat_template()", + "cannot use chat template", + ] + assert any( + error in response.json()["detail"].lower() for error in expected_errors + ) diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py new file mode 100644 index 0000000000..d89ff4701e --- /dev/null +++ b/python/openai/tests/test_completions.py @@ -0,0 +1,371 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import copy + +import pytest + + +class TestCompletions: + @pytest.fixture(scope="class") + def client(self, fastapi_client_class_scope): + yield fastapi_client_class_scope + + def test_completions_defaults(self, client, model: str, prompt: str): + response = client.post( + "/v1/completions", + json={"model": model, "prompt": prompt}, + ) + + print("Response:", response.json()) + assert response.status_code == 200 + # NOTE: Could be improved to look for certain quality of response, + # or tested with dummy identity model. + assert response.json()["choices"][0]["text"].strip() + # "usage" currently not supported + assert not response.json()["usage"] + + @pytest.mark.parametrize( + "sampling_parameter, value", + [ + ("temperature", 0.7), + ("max_tokens", 10), + ("top_p", 0.9), + ("frequency_penalty", 0.5), + ("presence_penalty", 0.2), + ("best_of", 1), + ("n", 1), + # logprobs is an integer for completions + ("logprobs", 5), + ("logit_bias", {"0": 0}), + # NOTE: Extensions to the spec + ("min_tokens", 16), + ("ignore_eos", True), + ], + ) + def test_completions_sampling_parameters( + self, client, sampling_parameter, value, model: str, prompt: str + ): + response = client.post( + "/v1/completions", + json={ + "model": model, + "prompt": prompt, + sampling_parameter: value, + }, + ) + print("Response:", response.json()) + + # FIXME: Add support and remove this check + unsupported_parameters = ["logprobs", "logit_bias"] + if sampling_parameter in unsupported_parameters: + assert response.status_code == 400 + assert response.json()["detail"] == "logit bias and log probs not supported" + return + + assert response.status_code == 200 + assert response.json()["choices"][0]["text"].strip() + + # Simple tests to verify max_tokens roughly behaves as expected + def test_completions_max_tokens(self, client, model: str, prompt: str): + responses = [] + payload = {"model": model, "prompt": prompt, "max_tokens": 1} + + # Send two requests with max_tokens = 1 to check their similarity + payload["max_tokens"] = 1 + responses.append( + client.post( + "/v1/completions", + json=payload, + ) + ) + responses.append( + client.post( + "/v1/completions", + json=payload, + ) + ) + # Send one requests with larger max_tokens to check its dis-similarity + payload["max_tokens"] = 100 + responses.append( + client.post( + "/v1/completions", + json=payload, + ) + ) + + for response in responses: + print("Response:", response.json()) + assert response.status_code == 200 + + response1_text = responses[0].json()["choices"][0]["text"].strip().split() + response2_text = 
responses[1].json()["choices"][0]["text"].strip().split()
+        response3_text = responses[2].json()["choices"][0]["text"].strip().split()
+        # Simplification: One token shouldn't be more than one space-delimited word
+        assert len(response1_text) == len(response2_text) == 1
+        assert len(response3_text) > len(response1_text)
+
+    @pytest.mark.parametrize(
+        "temperature",
+        [0.0, 1.0],
+    )
+    # Simple tests to verify temperature roughly behaves as expected
+    def test_completions_temperature_vllm(
+        self, client, temperature, backend: str, model: str, prompt: str
+    ):
+        if backend != "vllm":
+            pytest.skip(reason="Only used to test vLLM-specific temperature behavior")
+
+        responses = []
+        payload = {
+            "model": model,
+            "prompt": prompt,
+            "temperature": temperature,
+        }
+
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload,
+            )
+        )
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload,
+            )
+        )
+
+        for response in responses:
+            print("Response:", response.json())
+            assert response.status_code == 200
+
+        response1_text = responses[0].json()["choices"][0]["text"].strip().split()
+        response2_text = responses[1].json()["choices"][0]["text"].strip().split()
+
+        # Temperature of 0.0 indicates greedy sampling, so check
+        # that two equivalent requests produce the same response.
+        if temperature == 0.0:
+            # NOTE: This check may be ambitious to get an exact match in all
+            # frameworks depending on how other parameter defaults are set, so
+            # it can probably be removed if it introduces flakiness.
+            print(f"Comparing '{response1_text}' == '{response2_text}'")
+            assert response1_text == response2_text
+        # Temperature of 1.0 indicates maximum randomness, so check
+        # that two equivalent requests produce different responses.
+        elif temperature == 1.0:
+            print(f"Comparing '{response1_text}' != '{response2_text}'")
+            assert response1_text != response2_text
+        # Don't bother checking values other than the extremes
+        else:
+            raise ValueError(f"Unexpected {temperature=} for this test.")
+
+    # Remove xfail when fix is released and this test returns xpass status
+    @pytest.mark.xfail(
+        reason="TRT-LLM BLS model will ignore temperature until a later release"
+    )
+    # Simple tests to verify temperature roughly behaves as expected
+    def test_completions_temperature_tensorrtllm(
+        self, client, backend: str, model: str, prompt: str
+    ):
+        if backend != "tensorrtllm":
+            pytest.skip(
+                reason="Only used to test TRT-LLM-specific temperature behavior"
+            )
+
+        responses = []
+        payload1 = {
+            "model": model,
+            "prompt": prompt,
+            "temperature": 0.0,
+            # TRT-LLM requires certain settings of `top_k` / `top_p` to
+            # respect changes in `temperature`
+            "top_p": 0.5,
+        }
+        payload2 = copy.deepcopy(payload1)
+        payload2["temperature"] = 1.0
+
+        # First 2 responses should be the same in TRT-LLM with identical payload
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload1,
+            )
+        )
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload1,
+            )
+        )
+        # Third response should differ with different temperature in payload
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload2,
+            )
+        )
+
+        for response in responses:
+            print("Response:", response.json())
+            assert response.status_code == 200
+
+        response1_text = responses[0].json()["choices"][0]["text"].strip().split()
+        response2_text = responses[1].json()["choices"][0]["text"].strip().split()
+        response3_text = responses[2].json()["choices"][0]["text"].strip().split()
+
+        assert response1_text == response2_text
+        assert response1_text != response3_text
+
+    # Simple tests to verify seed roughly behaves as expected
+    def test_completions_seed(self, client, model: str, prompt: str):
+        responses = []
+        payload1 = {"model": model, "prompt": prompt, "seed": 1}
+        payload2 = copy.deepcopy(payload1)
+        payload2["seed"] = 2
+
+        # First 2 responses should be the same in both vLLM and TRT-LLM with identical seed
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload1,
+            )
+        )
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload1,
+            )
+        )
+        # Third response should differ with different seed in payload
+        responses.append(
+            client.post(
+                "/v1/completions",
+                json=payload2,
+            )
+        )
+
+        for response in responses:
+            print("Response:", response.json())
+            assert response.status_code == 200
+
+        response1_text = responses[0].json()["choices"][0]["text"].strip().split()
+        response2_text = responses[1].json()["choices"][0]["text"].strip().split()
+        response3_text = responses[2].json()["choices"][0]["text"].strip().split()
+
+        assert response1_text == response2_text
+        assert response1_text != response3_text
+
+    @pytest.mark.parametrize(
+        "sampling_parameter, value",
+        [
+            ("temperature", 2.1),
+            ("temperature", -0.1),
+            ("max_tokens", -1),
+            ("top_p", 1.1),
+            ("frequency_penalty", 3),
+            ("frequency_penalty", -3),
+            ("presence_penalty", 2.1),
+            ("presence_penalty", -2.1),
+            # NOTE: Extensions to the spec
+            ("min_tokens", -1),
+            ("ignore_eos", 123),
+        ],
+    )
+    def test_completions_invalid_sampling_parameters(
+        self, client, sampling_parameter, value, model: str, prompt: str
+    ):
+        response = client.post(
+            "/v1/completions",
+            json={
+                "model": model,
+                "prompt": prompt,
+                sampling_parameter: value,
+            },
+        )
+
+        print("Response:", response.json())
+        assert response.status_code == 422
+
+    def test_completions_empty_request(self, client):
+        response = client.post("/v1/completions", json={})
+        assert response.status_code == 422
+
+    def test_completions_no_model(self, client, prompt: str):
+        response = client.post("/v1/completions", json={"prompt": prompt})
+        assert response.status_code == 422
+
+    def test_completions_no_prompt(self, client, model: str):
+        response = client.post("/v1/completions", json={"model": model})
+        assert response.status_code == 422
+
+    def test_completions_empty_prompt(self, client, model: str):
+        response = client.post("/v1/completions", json={"model": model, "prompt": ""})
+
+        # NOTE: Should this be validated in schema instead?
+ # 400 Error returned in route handler + assert response.status_code == 400 + + def test_no_prompt(self, client, model: str): + response = client.post("/v1/completions", json={"model": model}) + + # 422 Error returned by schema validation + assert response.status_code == 422 + + @pytest.mark.parametrize( + "sampling_parameter_dict", + [ + # Each individual parameter should fail for > 1 for now + {"n": 2}, + {"best_of": 2}, + {"n": 2, "best_of": 2}, + # When individual params > 1 are supported, best_of < n should fail + {"n": 2, "best_of": 1}, + ], + ) + def test_completions_multiple_choices( + self, client, sampling_parameter_dict: dict, model: str, prompt: str + ): + response = client.post( + "/v1/completions", + json={"model": model, "prompt": prompt, **sampling_parameter_dict}, + ) + print("Response:", response.json()) + + # FIXME: Add support and test for success + # Expected to fail when n or best_of > 1, only single choice supported for now + assert response.status_code == 400 + assert "only single choice" in response.json()["detail"] + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_lora(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_multi_lora(self): + pass + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_usage_response(self): + pass diff --git a/python/openai/tests/test_models/identity_py/1/model.py b/python/openai/tests/test_models/identity_py/1/model.py new file mode 100644 index 0000000000..7bbe4bf991 --- /dev/null +++ b/python/openai/tests/test_models/identity_py/1/model.py @@ -0,0 +1,40 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def execute(self, requests): + """ + Identity model in Python backend. 
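+
+        Copies the "INPUT0" tensor of each request to the "OUTPUT0" tensor of its
+        response unchanged; used as a lightweight smoke-test model.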
+ """ + responses = [] + for request in requests: + input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0") + out_tensor = pb_utils.Tensor("OUTPUT0", input_tensor.as_numpy()) + responses.append(pb_utils.InferenceResponse([out_tensor])) + return responses diff --git a/python/openai/tests/test_models/identity_py/config.pbtxt b/python/openai/tests/test_models/identity_py/config.pbtxt new file mode 100644 index 0000000000..3926c830cb --- /dev/null +++ b/python/openai/tests/test_models/identity_py/config.pbtxt @@ -0,0 +1,51 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +backend: "python" +max_batch_size: 64 + +input [ + { + name: "INPUT0" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +output [ + { + name: "OUTPUT0" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] diff --git a/python/openai/tests/test_models/mock_llm/1/model.py b/python/openai/tests/test_models/mock_llm/1/model.py new file mode 100644 index 0000000000..0fc9053cd3 --- /dev/null +++ b/python/openai/tests/test_models/mock_llm/1/model.py @@ -0,0 +1,108 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import time
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.model_config = json.loads(args["model_config"])
+        self.decoupled = self.model_config.get("model_transaction_policy", {}).get(
+            "decoupled"
+        )
+
+    def execute(self, requests):
+        if self.decoupled:
+            return self.exec_decoupled(requests)
+        else:
+            return self.exec(requests)
+
+    def exec(self, requests):
+        responses = []
+        for request in requests:
+            params = json.loads(request.parameters())
+            rep_count = params["REPETITION"] if "REPETITION" in params else 1
+
+            input_np = pb_utils.get_input_tensor_by_name(
+                request, "text_input"
+            ).as_numpy()
+            stream_np = pb_utils.get_input_tensor_by_name(request, "stream").as_numpy()
+            stream = stream_np.flatten()[0]
+            if stream:
+                responses.append(
+                    pb_utils.InferenceResponse(
+                        error=pb_utils.TritonError(
+                            "STREAM only supported in decoupled mode"
+                        )
+                    )
+                )
+            else:
+                out_tensor = pb_utils.Tensor(
+                    "text_output", np.repeat(input_np, rep_count, axis=1)
+                )
+                responses.append(pb_utils.InferenceResponse([out_tensor]))
+        return responses
+
+    def exec_decoupled(self, requests):
+        for request in requests:
+            params = json.loads(request.parameters())
+            rep_count = params["REPETITION"] if "REPETITION" in params else 1
+            fail_last = params["FAIL_LAST"] if "FAIL_LAST" in params else False
+            delay = params["DELAY"] if "DELAY" in params else None
+
+            sender = request.get_response_sender()
+            input_np = pb_utils.get_input_tensor_by_name(
+                request, "text_input"
+            ).as_numpy()
+            stream_np = pb_utils.get_input_tensor_by_name(request, "stream").as_numpy()
+            out_tensor = pb_utils.Tensor("text_output", input_np)
+            response = pb_utils.InferenceResponse([out_tensor])
+            # If stream enabled, just send multiple copies of response
+            # FIXME: Could split up response string into tokens, but this is simpler for now.
+            stream = stream_np.flatten()[0]
+            if stream:
+                for _ in range(rep_count):
+                    if delay is not None:
+                        time.sleep(delay)
+                    sender.send(response)
+                sender.send(
+                    None
+                    if not fail_last
+                    else pb_utils.InferenceResponse(
+                        error=pb_utils.TritonError("An Error Occurred")
+                    ),
+                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
+                )
+            # If stream disabled, just send one response
+            else:
+                sender.send(
+                    response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                )
+        return None
diff --git a/python/openai/tests/test_models/mock_llm/config.pbtxt b/python/openai/tests/test_models/mock_llm/config.pbtxt
new file mode 100644
index 0000000000..5f665ff543
--- /dev/null
+++ b/python/openai/tests/test_models/mock_llm/config.pbtxt
@@ -0,0 +1,60 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +backend: "python" + +max_batch_size: 0 + +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ 1, 1 ] + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1, 1 ] + } +] + +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ 1, -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_MODEL + } +] diff --git a/python/openai/tests/test_observability.py b/python/openai/tests/test_observability.py new file mode 100644 index 0000000000..b15b64e735 --- /dev/null +++ b/python/openai/tests/test_observability.py @@ -0,0 +1,90 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient +from tests.utils import setup_fastapi_app, setup_server + + +# Override conftest.py default model +@pytest.fixture +def model(): + return "mock_llm" + + +class TestObservability: + @pytest.fixture(scope="class") + def client(self): + # TODO: Cleanup, mock server/engine, etc. + model_repository = Path(__file__).parent / "test_models" + server = setup_server(str(model_repository)) + app = setup_fastapi_app(tokenizer="", server=server, backend=None) + with TestClient(app) as test_client: + yield test_client + + server.stop() + + ### General Error Handling ### + def test_not_found(self, client): + response = client.get("/does-not-exist") + assert response.status_code == 404 + + ### Startup / Health ### + def test_startup_success(self, client): + response = client.get("/health/ready") + assert response.status_code == 200 + + ### Metrics ### + def test_startup_metrics(self, client): + response = client.get("/metrics") + assert response.status_code == 200 + # TODO: Flesh out metrics tests further + assert "nv_cpu_utilization" in response.text + + ### Models ### + def test_models_list(self, client): + response = client.get("/v1/models") + assert response.status_code == 200 + models = response.json()["data"] + # Two models are in test_models specifically to verify that all models + # are listed by this endpoint. This can be removed if the behavior changes. + assert len(models) == 2 + for model in models: + assert model["id"] + assert model["object"] == "model" + assert model["created"] > 0 + assert model["owned_by"] == "Triton Inference Server" + + def test_models_get(self, client, model): + response = client.get(f"/v1/models/{model}") + assert response.status_code == 200 + model_resp = response.json() + assert model_resp["id"] == model + assert model_resp["object"] == "model" + assert model_resp["created"] > 0 + assert model_resp["owned_by"] == "Triton Inference Server" diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py new file mode 100644 index 0000000000..6f1b456ab4 --- /dev/null +++ b/python/openai/tests/test_openai_client.py @@ -0,0 +1,250 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import List + +import openai +import pytest + + +class TestOpenAIClient: + @pytest.fixture(scope="class") + def client(self, server): + return server.get_client() + + def test_openai_client_models(self, client: openai.OpenAI, backend: str): + models = list(client.models.list()) + print(f"Models: {models}") + if backend == "tensorrtllm": + # tensorrt_llm_bls + + # preprocess -> tensorrt_llm -> postprocess + assert len(models) == 4 + elif backend == "vllm": + assert len(models) == 1 + else: + raise Exception(f"Unexpected backend {backend=}") + + def test_openai_client_completion( + self, client: openai.OpenAI, model: str, prompt: str + ): + completion = client.completions.create( + prompt=prompt, + model=model, + ) + + print(f"Completion results: {completion}") + assert completion.choices[0].text + assert completion.choices[0].finish_reason == "stop" + + def test_openai_client_chat_completion( + self, client: openai.OpenAI, model: str, messages: List[dict] + ): + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + ) + + print(f"Chat completion results: {chat_completion}") + assert chat_completion.choices[0].message.content + assert chat_completion.choices[0].finish_reason == "stop" + + @pytest.mark.parametrize("echo", [False, True]) + def test_openai_client_completion_echo( + self, client: openai.OpenAI, echo: bool, backend: str, model: str, prompt: str + ): + if backend == "tensorrtllm": + pytest.skip( + reason="TRT-LLM backend currently only supports setting this parameter at model load time", + ) + + completion = client.completions.create(prompt=prompt, model=model, echo=echo) + + print(f"Completion results: {completion}") + response = completion.choices[0].text + if echo: + assert prompt in response + else: + assert prompt not in response + + @pytest.mark.skip(reason="Not Implemented Yet") + def test_openai_client_function_calling(self): + pass + + +class TestAsyncOpenAIClient: + @pytest.fixture(scope="class") + def client(self, server): + return server.get_async_client() + + @pytest.mark.asyncio + async def test_openai_client_models(self, client: openai.AsyncOpenAI, backend: str): + async_models = await client.models.list() + models = [model async for model in async_models] + print(f"Models: {models}") + if backend == "tensorrtllm": + # tensorrt_llm_bls + + # preprocess -> tensorrt_llm -> postprocess + assert len(models) == 4 + elif backend == "vllm": + assert len(models) == 1 + else: + raise Exception(f"Unexpected backend {backend=}") + + @pytest.mark.asyncio + async def test_openai_client_completion( + self, client: openai.AsyncOpenAI, model: str, prompt: str + ): + completion = await client.completions.create( + prompt=prompt, + model=model, + ) + + print(f"Completion 
results: {completion}") + assert completion.choices[0].text + assert completion.choices[0].finish_reason == "stop" + + @pytest.mark.asyncio + async def test_openai_client_chat_completion( + self, client: openai.AsyncOpenAI, model: str, messages: List[dict] + ): + chat_completion = await client.chat.completions.create( + messages=messages, + model=model, + ) + + assert chat_completion.choices[0].message.content + assert chat_completion.choices[0].finish_reason == "stop" + print(f"Chat completion results: {chat_completion}") + + @pytest.mark.asyncio + async def test_completion_streaming( + self, client: openai.AsyncOpenAI, model: str, prompt: str + ): + # Test single completion for comparison + chat_completion = await client.completions.create( + model=model, + prompt=prompt, + max_tokens=10, + temperature=0.0, + stream=False, + ) + output = chat_completion.choices[0].text + stop_reason = chat_completion.choices[0].finish_reason + + # Test streaming + stream = await client.completions.create( + model=model, + prompt=prompt, + max_tokens=10, + temperature=0.0, + stream=True, + ) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0] + if delta.text: + chunks.append(delta.text) + if delta.finish_reason is not None: + finish_reason_count += 1 + + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert "".join(chunks) == output + + @pytest.mark.parametrize( + "sampling_parameter_dict", + [ + {}, + # Verify that stop words work with streaming outputs + {"stop": "is"}, + {"stop": ["is"]}, + {"stop": ["is", ".", ","]}, + ], + ) + @pytest.mark.asyncio + async def test_chat_streaming( + self, + client: openai.AsyncOpenAI, + model: str, + messages: List[dict], + sampling_parameter_dict: dict, + ): + # Fixed seed and temperature for comparing reproducible responses + seed = 0 + temperature = 0.0 + # Generate enough tokens to easily identify stop words are working. 
+ max_tokens = 64 + + # Test single chat completion for comparison + chat_completion = await client.chat.completions.create( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, + stream=False, + **sampling_parameter_dict, + ) + output = chat_completion.choices[0].message.content + stop_reason = chat_completion.choices[0].finish_reason + + # Test streaming + stream = await client.chat.completions.create( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, + stream=True, + **sampling_parameter_dict, + ) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + + # Assert that streaming actually returned multiple responses + # and that it is equivalent to the non-streamed output + assert len(chunks) > 1 + streamed_output = "".join(chunks) + assert streamed_output == output + + @pytest.mark.skip(reason="Not Implemented Yet") + @pytest.mark.asyncio + async def test_openai_client_function_calling(self): + pass diff --git a/python/openai/tests/utils.py b/python/openai/tests/utils.py new file mode 100644 index 0000000000..fdffcc5ea9 --- /dev/null +++ b/python/openai/tests/utils.py @@ -0,0 +1,139 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional + +import openai +import requests +import tritonserver + +sys.path.append(os.path.join(Path(__file__).resolve().parent, "..", "openai_frontend")) +from engine.triton_engine import TritonLLMEngine +from frontend.fastapi_frontend import FastApiFrontend + + +# TODO: Cleanup, refactor, mock, etc. 
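+# Rough usage sketch for the helpers below (illustrative only; this mirrors how
+# conftest.py and the observability tests wire them together, using
+# fastapi.testclient.TestClient and a hypothetical model repository path):
+#
+#   server = setup_server("/path/to/model_repository")
+#   app = setup_fastapi_app(tokenizer="", server=server, backend=None)
+#   with TestClient(app) as client:
+#       client.get("/v1/models")
+#   server.stop()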
+def setup_server(model_repository: str): + server: tritonserver.Server = tritonserver.Server( + model_repository=model_repository, + log_verbose=0, + log_info=True, + log_warn=True, + log_error=True, + ).start(wait_until_ready=True) + return server + + +def setup_fastapi_app(tokenizer: str, server: tritonserver.Server, backend: str): + engine: TritonLLMEngine = TritonLLMEngine( + server=server, tokenizer=tokenizer, backend=backend + ) + frontend: FastApiFrontend = FastApiFrontend(engine=engine) + return frontend.app + + +# Heavily inspired by vLLM's test infrastructure +class OpenAIServer: + API_KEY = "EMPTY" # Triton's OpenAI server does not need API key + START_TIMEOUT = 120 # wait for server to start for up to 120 seconds + + def __init__( + self, + cli_args: List[str], + *, + env_dict: Optional[Dict[str, str]] = None, + ) -> None: + # TODO: Incorporate caller's cli_args passed to this instance instead + self.host = "localhost" + self.port = 9000 + + env = os.environ.copy() + if env_dict is not None: + env.update(env_dict) + + this_dir = Path(__file__).resolve().parent + script_path = this_dir / ".." / "openai_frontend" / "main.py" + self.proc = subprocess.Popen( + ["python3", script_path] + cli_args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + # Wait until health endpoint is responsive + self._wait_for_server( + url=self.url_for("health", "ready"), timeout=self.START_TIMEOUT + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + try: + wait_secs = 30 + self.proc.wait(wait_secs) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() + + def _wait_for_server(self, *, url: str, timeout: float): + start = time.time() + while True: + try: + if requests.get(url).status_code == 200: + break + except Exception as err: + result = self.proc.poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError("Server failed to start in time.") from err + + @property + def url_root(self) -> str: + return f"http://{self.host}:{self.port}" + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self): + return openai.OpenAI( + base_url=self.url_for("v1"), + api_key=self.API_KEY, + ) + + def get_async_client(self): + return openai.AsyncOpenAI( + base_url=self.url_for("v1"), + api_key=self.API_KEY, + ) diff --git a/python/openai/tests/vllm_models/llama-3.1-8b-instruct/1/model.json b/python/openai/tests/vllm_models/llama-3.1-8b-instruct/1/model.json new file mode 100644 index 0000000000..cb9b14c765 --- /dev/null +++ b/python/openai/tests/vllm_models/llama-3.1-8b-instruct/1/model.json @@ -0,0 +1 @@ +{"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "disable_log_requests": true, "gpu_memory_utilization": 0.9} diff --git a/python/openai/tests/vllm_models/llama-3.1-8b-instruct/config.pbtxt b/python/openai/tests/vllm_models/llama-3.1-8b-instruct/config.pbtxt new file mode 100644 index 0000000000..4ad6534943 --- /dev/null +++ b/python/openai/tests/vllm_models/llama-3.1-8b-instruct/config.pbtxt @@ -0,0 +1,2 @@ +backend: "vllm" +instance_group [{kind: KIND_MODEL}] \ No newline at end of file diff --git a/qa/L0_openai/test.sh b/qa/L0_openai/test.sh new file mode 100755 index 0000000000..c910c204ac --- /dev/null +++ b/qa/L0_openai/test.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +### Helpers ### + +function install_deps() { + # Install python bindings for tritonserver and tritonfrontend + pip install /opt/tritonserver/python/triton*.whl + + # Install application/testing requirements + pushd openai/ + pip install -r requirements.txt + pip install -r requirements-test.txt + + if [ "${IMAGE_KIND}" == "TRTLLM" ]; then + prepare_tensorrtllm + else + prepare_vllm + fi + popd +} + +function prepare_vllm() { + echo "No prep needed for vllm currently" +} + +function prepare_tensorrtllm() { + MODEL="llama-3-8b-instruct" + MODEL_REPO="tests/tensorrtllm_models" + rm -rf ${MODEL_REPO} + + # FIXME: This will require an upgrade each release to match the TRT-LLM version + # Use Triton CLI to prepare model repository for testing + pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.10 + # NOTE: Could use ENGINE_DEST_PATH set to NFS mount for pre-built engines in future + triton import \ + --model ${MODEL} \ + --backend tensorrtllm \ + --model-repository "${MODEL_REPO}" + + # WAR for tests expecting default name of "tensorrt_llm_bls" + mv "${MODEL_REPO}/${MODEL}" "${MODEL_REPO}/tensorrt_llm_bls" +} + +function pre_test() { + # Cleanup + rm -rf openai/ + rm -f *.xml *.log + + # Prep test environment + cp -r ../../python/openai . + install_deps +} + +function run_test() { + pushd openai/ + TEST_LOG="test_openai.log" + + # Capture error code without exiting to allow log collection + set +e + pytest -s -v --junitxml=test_openai.xml tests/ 2>&1 > ${TEST_LOG} + if [ $? -ne 0 ]; then + cat ${TEST_LOG} + echo -e "\n***\n*** Test Failed\n***" + RET=1 + fi + set -e + + # Collect logs for error analysis when needed + cp *.xml *.log ../../../ + popd +} + +### Test ### + +RET=0 + +pre_test +run_test + +exit ${RET} diff --git a/tools/add_copyright.py b/tools/add_copyright.py index 34432bb0c6..cf6b2a8686 100644 --- a/tools/add_copyright.py +++ b/tools/add_copyright.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+#
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
 # are met:
@@ -25,7 +26,6 @@
 import argparse
 import os
 import re
-import subprocess
 import sys
 from datetime import datetime
 from typing import Callable, Dict, Optional, Sequence
@@ -257,7 +257,9 @@ def add_copyrights(paths):
             f"WARNING: No handler registered for file: {path}. Please add a new handler to {__file__}!"
         )
 
-    subprocess.run(["git", "add"] + paths)
+    # Don't automatically 'git add' the changed files for now; this makes it
+    # clearer which files changed and allows reviewing them with 'git diff'.
+    # subprocess.run(["git", "add"] + paths)
 
     print(f"Processed copyright headers for {len(paths)} file(s).")
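With the automatic `git add` disabled above, contributors are expected to review and stage the regenerated headers themselves. A minimal sketch of that manual flow, assuming `tools/add_copyright.py` accepts the target file paths as CLI arguments (it parses arguments with argparse and processes a list of `paths`); the exact invocation may differ:

```bash
# Hypothetical example: refresh headers on files you touched, review, then stage.
python tools/add_copyright.py python/openai/openai_frontend/main.py
git diff python/openai/openai_frontend/main.py   # inspect the inserted header
git add python/openai/openai_frontend/main.py    # stage explicitly once satisfied
```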