Hugging Face Transformer Deployment Tutorial (#49)
* Initial Commit

* Mount model repo so changes reflect, parameter tweaking, README file

* Image name error

* Incorporating review comments. Separate docker and model repo builds, add README, restructure repo

* Tutorial restructuring. Using static model configurations

* Bump triton container and update README

* Remove client script

* Incorporating review comments

* Modify WIP line in vLLM tutorial

* Remove trust_remote_code parameter from falcon model

* Removing Mistral

* Incorporating Feedback

* Change input/output names

* Pre-commit format

* Different perf_analyzer example, config file format fixes

* Deep dive changes to Triton tools section

* Remove unused variable
fpetrini15 authored Oct 24, 2023
1 parent af67595 commit de7da4a
Showing 7 changed files with 667 additions and 2 deletions.
27 changes: 27 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/Dockerfile
@@ -0,0 +1,27 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
FROM nvcr.io/nvidia/tritonserver:23.09-py3
RUN pip install transformers==4.34.0 protobuf==3.20.3 sentencepiece==0.1.99 accelerate==0.23.0 einops==0.6.1
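As a quick sanity check that the pinned dependencies resolved, one might run a short Python snippet inside the built image (an illustration only; the expected versions are simply the ones pinned above):

# Run inside the built container to confirm the pinned packages import cleanly.
import accelerate
import einops
import sentencepiece
import transformers

print(transformers.__version__)    # expected: 4.34.0
print(accelerate.__version__)      # expected: 0.23.0
print(einops.__version__)          # expected: 0.6.1
print(sentencepiece.__version__)   # expected: 0.1.99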
355 changes: 355 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/README.md

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/falcon7b/1/model.py
@@ -0,0 +1,109 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os

os.environ[
    "TRANSFORMERS_CACHE"
] = "/opt/tritonserver/model_repository/falcon7b/hf_cache"
import json

import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "tiiuae/falcon-7b"
        default_max_gen_length = "15"
        # Check for user-specified model name in model config parameters
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
        # Check for user-specified max length in model config parameters
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )

        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        # Assume the tokenizer is available for the same model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model)
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=hf_model,
            torch_dtype=torch.float16,
            tokenizer=self.tokenizer,
            device_map="auto",
        )
        self.pipeline.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def execute(self, requests):
        prompts = []
        for request in requests:
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            multi_dim = input_tensor.as_numpy().ndim > 1
            if not multi_dim:
                prompt = input_tensor.as_numpy()[0].decode("utf-8")
                self.logger.log_info(f"Generating sequences for text_input: {prompt}")
                prompts.append(prompt)
            else:
                # Implementation to accept dynamically batched inputs
                num_prompts = input_tensor.as_numpy().shape[0]
                for prompt_index in range(0, num_prompts):
                    prompt = input_tensor.as_numpy()[prompt_index][0].decode("utf-8")
                    prompts.append(prompt)

        batch_size = len(prompts)
        return self.generate(prompts, batch_size)

    def generate(self, prompts, batch_size):
        sequences = self.pipeline(
            prompts,
            max_length=self.max_output_length,
            pad_token_id=self.tokenizer.eos_token_id,
            batch_size=batch_size,
        )
        responses = []
        for seq in sequences:
            output_tensors = []
            # Each response carries only the text generated for its own prompt
            text = seq[0]["generated_text"]
            tensor = pb_utils.Tensor("text_output", np.array([text], dtype=np.object_))
            output_tensors.append(tensor)
            responses.append(pb_utils.InferenceResponse(output_tensors=output_tensors))

        return responses

    def finalize(self):
        print("Cleaning up...")
36 changes: 36 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/falcon7b/config.pbtxt
@@ -0,0 +1,36 @@
# Triton backend to use
backend: "python"

# Hugging Face model to load. Parameters must follow this
# key/value structure
parameters: {
    key: "huggingface_model",
    value: {string_value: "tiiuae/falcon-7b"}
}
# The maximum total sequence length (prompt plus generated
# tokens) to produce for each input
parameters: {
    key: "max_output_length",
    value: {string_value: "15"}
}
# Triton should expect as input a single string of fixed
# length named 'text_input'
input [
    {
        name: "text_input"
        data_type: TYPE_STRING
        dims: [ 1 ]
    }
]
# Triton should expect to respond with a single string
# output of variable length named 'text_output'
output [
    {
        name: "text_output"
        data_type: TYPE_STRING
        dims: [ -1 ]
    }
]
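With the model repository and container running, a minimal sketch of querying falcon7b from Triton's Python HTTP client (an illustration, not part of the commit: it assumes the server is reachable at localhost:8000, that tritonclient[http] is installed, and that the prompt is arbitrary):

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# TYPE_STRING tensors are sent as BYTES; shape [1] matches the input declared above.
prompt = np.array(["How do dynamic batchers work?"], dtype=np.object_)
text_input = httpclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(prompt)

result = client.infer(model_name="falcon7b", inputs=[text_input])
print(result.as_numpy("text_output"))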
103 changes: 103 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/persimmon8b/1/model.py
@@ -0,0 +1,103 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os

os.environ[
    "TRANSFORMERS_CACHE"
] = "/opt/tritonserver/model_repository/persimmon8b/hf_cache"

import json

import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "adept/persimmon-8b-base"
        default_max_gen_length = "15"
        # Check for user-specified model name in model config parameters
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
        # Check for user-specified max length in model config parameters
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )

        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        # Assume the tokenizer is available for the same model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model)
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=hf_model,
            torch_dtype=torch.float16,
            tokenizer=self.tokenizer,
            device_map="auto",
        )

    def execute(self, requests):
        responses = []
        for request in requests:
            # The input tensor is named "text_input", as declared in config.pbtxt
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            prompt = input_tensor.as_numpy()[0].decode("utf-8")

            self.logger.log_info(f"Generating sequences for text_input: {prompt}")
            response = self.generate(prompt)
            responses.append(response)

        return responses

    def generate(self, prompt):
        sequences = self.pipeline(
            prompt,
            max_length=self.max_output_length,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        output_tensors = []
        texts = []
        for i, seq in enumerate(sequences):
            text = seq["generated_text"]
            self.logger.log_info(f"Sequence {i+1}: {text}")
            texts.append(text)

        tensor = pb_utils.Tensor("text_output", np.array(texts, dtype=np.object_))
        output_tensors.append(tensor)
        response = pb_utils.InferenceResponse(output_tensors=output_tensors)
        return response

    def finalize(self):
        print("Cleaning up...")
36 changes: 36 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/persimmon8b/config.pbtxt
@@ -0,0 +1,36 @@
# Triton backend to use
backend: "python"

# Hugging Face model to load. Parameters must follow this
# key/value structure
parameters: {
    key: "huggingface_model",
    value: {string_value: "adept/persimmon-8b-base"}
}
# The maximum total sequence length (prompt plus generated
# tokens) to produce for each input
parameters: {
    key: "max_output_length",
    value: {string_value: "15"}
}
# Triton should expect as input a single string of fixed
# length named 'text_input'
input [
    {
        name: "text_input"
        data_type: TYPE_STRING
        dims: [ 1 ]
    }
]
# Triton should expect to respond with a single string
# output of variable length named 'text_output'
output [
    {
        name: "text_output"
        data_type: TYPE_STRING
        dims: [ -1 ]
    }
]
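The client sketch shown after the falcon7b configuration applies here unchanged apart from the model name; for instance, reusing the client and text_input objects from that sketch:

result = client.infer(model_name="persimmon8b", inputs=[text_input])
print(result.as_numpy("text_output"))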
3 changes: 1 addition & 2 deletions Quick_Deploy/vLLM/README.md
@@ -34,8 +34,7 @@ The following tutorial demonstrates how to deploy a simple
Triton Inference Server using Triton's [Python backend](https://github.com/triton-inference-server/python_backend) and the
[vLLM](https://github.com/vllm-project/vllm) library.

-*NOTE*: The tutorial is intended to be a reference example only. It is a work in progress with
-[known limitations](#limitations).
+*NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations).


## Step 1: Build a Triton Container Image with vLLM
