Skip to content

Commit

Permalink
Cast tokenizer where needed
Browse files Browse the repository at this point in the history
Remove unnecessary import
  • Loading branch information
dyastremsky committed Oct 25, 2024
1 parent eef6292 commit 9e25f4e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import random
from typing import Any, Dict
from typing import Any, Dict, cast

from genai_perf.exceptions import GenAIPerfException
from genai_perf.inputs.converters.base_converter import BaseConverter
Expand All @@ -36,6 +36,7 @@
)
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.retrievers.generic_dataset import GenericDataset
from genai_perf.tokenizer import Tokenizer


class TensorRTLLMEngineConverter(BaseConverter):
Expand All @@ -50,9 +51,10 @@ def convert(
) -> Dict[Any, Any]:
request_body: Dict[str, Any] = {"data": []}

tokenizer = cast(Tokenizer, config.tokenizer)
for file_data in generic_dataset.files_data.values():
for row in file_data.rows:
token_ids = config.tokenizer.encode(row.texts[0])
token_ids = tokenizer.encode(row.texts[0])
payload = {
"input_ids": {
"content": token_ids,
Expand Down
4 changes: 2 additions & 2 deletions genai-perf/genai_perf/inputs/inputs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
PromptSource,
)
from genai_perf.inputs.retrievers.synthetic_image_generator import ImageFormat
from genai_perf.tokenizer import DEFAULT_TOKENIZER, Tokenizer, get_tokenizer
from genai_perf.tokenizer import Tokenizer


@dataclass
Expand Down Expand Up @@ -141,4 +141,4 @@ class InputsConfig:
random_seed: int = DEFAULT_RANDOM_SEED

# The tokenizer to use when generating synthetic prompts
tokenizer: Tokenizer = get_tokenizer(DEFAULT_TOKENIZER)
tokenizer: Optional[Tokenizer] = None
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from typing import List
from typing import List, cast

from genai_perf.inputs.input_constants import DEFAULT_SYNTHETIC_FILENAME
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.retrievers.base_input_retriever import BaseInputRetriever
from genai_perf.inputs.retrievers.generic_dataset import (
DataRow,
Expand All @@ -41,6 +40,7 @@
from genai_perf.inputs.retrievers.synthetic_prompt_generator import (
SyntheticPromptGenerator,
)
from genai_perf.tokenizer import Tokenizer


class SyntheticDataRetriever(BaseInputRetriever):
Expand All @@ -52,13 +52,14 @@ def retrieve_data(self) -> GenericDataset:
files = self.config.synthetic_input_filenames or [DEFAULT_SYNTHETIC_FILENAME]
synthetic_dataset = GenericDataset(files_data={})

tokenizer = cast(Tokenizer, self.config.tokenizer)
for file in files:
data_rows: List[DataRow] = []

for _ in range(self.config.num_prompts):
row = DataRow(texts=[], images=[])
prompt = SyntheticPromptGenerator.create_synthetic_prompt(
self.config.tokenizer,
tokenizer,
self.config.prompt_tokens_mean,
self.config.prompt_tokens_stddev,
)
Expand Down
25 changes: 12 additions & 13 deletions genai-perf/genai_perf/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@

import contextlib
import io
from typing import List
from typing import TYPE_CHECKING, List

from genai_perf.exceptions import GenAIPerfException

# Silence tokenizer warning on import
with contextlib.redirect_stdout(io.StringIO()) as stdout, contextlib.redirect_stderr(
io.StringIO()
) as stderr:
from transformers import AutoTokenizer, BatchEncoding
from transformers import logging as token_logger
# Use TYPE_CHECKING to import BatchEncoding only during static type checks
if TYPE_CHECKING:
from transformers import BatchEncoding

token_logger.set_verbosity_error()
from genai_perf.exceptions import GenAIPerfException

DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer"
DEFAULT_TOKENIZER_REVISION = "main"
Expand All @@ -41,10 +36,14 @@ def __init__(self, name: str, trust_remote_code: bool, revision: str) -> None:
Initialize by downloading the tokenizer from Huggingface.co
"""
try:
# Silence tokenizer warning on first use
# Silence tokenizer warning on import and first use
with contextlib.redirect_stdout(
io.StringIO()
) as stdout, contextlib.redirect_stderr(io.StringIO()) as stderr:
) as stdout, contextlib.redirect_stderr(io.StringIO()):
from transformers import AutoTokenizer
from transformers import logging as token_logger

token_logger.set_verbosity_error()
tokenizer = AutoTokenizer.from_pretrained(
name, trust_remote_code=trust_remote_code, revision=revision
)
Expand All @@ -58,7 +57,7 @@ def __init__(self, name: str, trust_remote_code: bool, revision: str) -> None:
self._encode_args = {"add_special_tokens": False}
self._decode_args = {"skip_special_tokens": True}

def __call__(self, text, **kwargs) -> BatchEncoding:
def __call__(self, text, **kwargs) -> "BatchEncoding":
    """Tokenize *text* by delegating to the wrapped HuggingFace tokenizer.

    Any keyword arguments are merged into the persistent ``self._call_args``
    dict before the call, so options passed here persist across subsequent
    calls (NOTE(review): this mutation of shared call state may be
    intentional, but confirm callers expect kwargs to be sticky).

    The return annotation is a string forward reference because
    ``BatchEncoding`` is only imported under ``TYPE_CHECKING`` to keep the
    heavy ``transformers`` import out of module import time.
    """
    # Merge per-call options into the stored defaults (mutates shared state).
    self._call_args.update(kwargs)
    # Delegate to the underlying AutoTokenizer instance.
    return self._tokenizer(text, **self._call_args)

Expand Down

0 comments on commit 9e25f4e

Please sign in to comment.