Skip to content

Commit

Permalink
Move prompt templating to tokenizer (#1347)
Browse files (browse the repository at this point in the history)
  • Loading branch information
RdoubleA authored Aug 16, 2024
1 parent 67f6a06 commit 367e9ab
Show file tree
Hide file tree
Showing 40 changed files with 1,037 additions and 261 deletions.
2 changes: 2 additions & 0 deletions tests/recipes/test_eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
limit=10 \
dtype=fp32 \
device=cpu \
Expand Down Expand Up @@ -95,6 +96,7 @@ def test_eval_recipe_errors_without_lm_eval(self, caplog, monkeypatch, tmpdir):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
limit=10 \
dtype=fp32 \
device=cpu \
Expand Down
1 change: 1 addition & 0 deletions tests/recipes/test_full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_loss(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
metric_logger.filename={log_file} \
""".split()
if fsdp_sharding_strategy:
Expand Down
4 changes: 4 additions & 0 deletions tests/recipes/test_full_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
metric_logger.filename={log_file} \
compile={compile} \
""".split()
Expand Down Expand Up @@ -134,6 +135,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2"]
Expand All @@ -155,6 +157,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
""".split()
Expand All @@ -180,6 +183,7 @@ def _get_test_config_overrides(self):
"dtype=fp32",
"enable_activation_checkpointing=False",
"tokenizer.path=/tmp/test-artifacts/tokenizer.model",
"tokenizer.prompt_template=null",
"dataset=tests.recipes.utils.DummyDataset",
"dataset.train_on_input=False",
"seed=9",
Expand Down
4 changes: 4 additions & 0 deletions tests/recipes/test_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_loss(self, fsdp_sharding_strategy, tmpdir, monkeypatch):
checkpointer.model_type=LLAMA2 \
metric_logger.filename={log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()
if fsdp_sharding_strategy:
cmd.append(f"fsdp_sharding_strategy={fsdp_sharding_strategy}")
Expand Down Expand Up @@ -137,6 +138,7 @@ def test_training_state_on_resume(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
Expand All @@ -158,6 +160,7 @@ def test_training_state_on_resume(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
""".split()
Expand Down Expand Up @@ -201,6 +204,7 @@ def test_save_and_load_merged_weights(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
Expand Down
4 changes: 4 additions & 0 deletions tests/recipes/test_lora_finetune_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_loss(self, tmpdir, monkeypatch):
checkpointer.model_type=LLAMA2 \
metric_logger.filename={log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
Expand Down Expand Up @@ -135,6 +136,7 @@ def test_training_state_on_resume(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
Expand All @@ -156,6 +158,7 @@ def test_training_state_on_resume(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
""".split()
Expand Down Expand Up @@ -202,6 +205,7 @@ def test_save_and_load_merged_weights(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
Expand Down
5 changes: 5 additions & 0 deletions tests/recipes/test_lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
metric_logger.filename={log_file} \
compile={compile} \
""".split()
Expand Down Expand Up @@ -133,6 +134,7 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
checkpointer.model_type=LLAMA2 \
metric_logger.filename={log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
compile={compile} \
""".split()

Expand Down Expand Up @@ -180,6 +182,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
Expand All @@ -204,6 +207,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd_2)
Expand Down Expand Up @@ -234,6 +238,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
Expand Down
1 change: 1 addition & 0 deletions tests/recipes/test_ppo_full_tunetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _get_test_config_overrides(self):
"enable_activation_checkpointing=False",
"tokenizer.path=/tmp/test-artifacts/tokenizer.model",
"tokenizer._component_=torchtune.models.llama2.llama2_tokenizer",
"tokenizer.prompt_template=null",
"seed=9",
"optimizer=torch.optim.AdamW",
"optimizer.lr=2e-5",
Expand Down
1 change: 1 addition & 0 deletions tests/recipes/test_qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch):
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
metric_logger.filename={log_file} \
""".split()
model_config = MODEL_TEST_CONFIGS[model_type]
Expand Down
18 changes: 18 additions & 0 deletions tests/torchtune/config/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
from omegaconf import OmegaConf
from torchtune.config._utils import (
_get_component_from_path,
_get_prompt_template,
_merge_yaml_and_cli_args,
_remove_key_by_dotpath,
InstantiationError,
log_config,
)
from torchtune.data._prompt_templates import PromptTemplate
from torchtune.models.llama2 import Llama2ChatTemplate
from torchtune.utils.argparse import TuneRecipeArgumentParser

_CONFIG = {
Expand Down Expand Up @@ -181,3 +184,18 @@ def test_remove_key_by_dotpath(self):
cfg = copy.deepcopy(_CONFIG)
with pytest.raises(KeyError, match="'g'"):
_remove_key_by_dotpath(cfg, "g")

def test_get_prompt_template(self):
    """Check how ``_get_prompt_template`` treats each kind of input.

    A dotpath string is expected to produce an instance of the named
    template class, a dictionary is expected to produce a ``PromptTemplate``
    exposing the same per-role entries, and a list is expected to be
    rejected with a ``ValueError``.
    """
    # Case 1: dotpath string naming a template class.
    from_dotpath = _get_prompt_template("torchtune.models.llama2.Llama2ChatTemplate")
    assert isinstance(from_dotpath, Llama2ChatTemplate)

    # Case 2: dictionary describing a custom template.
    custom_spec = {"user": ("1", "2"), "assistant": ("3", "4")}
    from_dict = _get_prompt_template(custom_spec)
    assert isinstance(from_dict, PromptTemplate)
    assert from_dict.template["user"] == ("1", "2")
    assert from_dict.template["assistant"] == ("3", "4")

    # Case 3: a list is neither a dotpath string nor a dictionary.
    expected_error = (
        "Prompt template must be a dotpath string or dictionary with custom template"
    )
    with pytest.raises(ValueError, match=expected_error):
        _get_prompt_template(["user", "assistant"])
12 changes: 11 additions & 1 deletion tests/torchtune/data/test_prompt_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,24 @@ class TestChatMLTemplate:
role="assistant",
content="<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an "
"icy road and recorded the event. The child appeared to be handling the situation well, "
"showcasing their driving skills despite the challenging conditions.<|im_end|>",
"showcasing their driving skills despite the challenging conditions.<|im_end|>\n",
),
]

def test_format(self):
    """Format ``MESSAGE_SAMPLE`` with ``ChatMLTemplate`` and compare it to ``expected_dialogue``."""
    formatted = ChatMLTemplate()(MESSAGE_SAMPLE)
    assert_dialogue_equal(formatted, self.expected_dialogue)

def test_format_generation(self):
    """Check formatting when the final assistant message has empty content.

    The first two sample messages are followed by an assistant message with
    no content. The expected result reuses the first two expected messages
    and ends with an assistant message holding only the opening assistant
    tag and a newline.
    """
    empty_assistant = Message(role="assistant", content="")
    assistant_prefix_only = Message(role="assistant", content="<|im_start|>assistant\n")

    formatted = ChatMLTemplate()(MESSAGE_SAMPLE[:2] + [empty_assistant])
    assert_dialogue_equal(
        formatted, self.expected_dialogue[:2] + [assistant_prefix_only]
    )


class TestGrammarErrorCorrectionTemplate:
samples = [
Expand Down
62 changes: 3 additions & 59 deletions tests/torchtune/datasets/test_grammar_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,35 +39,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True)
input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"]

assert input == [
0,
7,
4,
2,
8,
8,
7,
2,
3,
6,
4,
8,
5,
8,
5,
3,
10,
7,
4,
3,
6,
4,
8,
9,
2,
9,
-1,
]
assert input == [0, 7, 2, 3, 6, 4, 8, 5, 8, 5, 7, 4, 3, 6, 4, 8, 9, 2, 9, -1]
assert labels == input

@patch("torchtune.datasets._sft.load_dataset")
Expand All @@ -91,34 +63,6 @@ def test_label_masking(self, load_dataset, tokenizer):
# Generate the input and labels
input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"]

assert input == [
0,
7,
4,
2,
8,
8,
7,
2,
3,
6,
4,
8,
5,
8,
5,
3,
10,
7,
4,
3,
6,
4,
8,
9,
2,
9,
-1,
]
assert input == [0, 7, 2, 3, 6, 4, 8, 5, 8, 5, 7, 4, 3, 6, 4, 8, 9, 2, 9, -1]
# Check that the input is masked
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 17
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 10
6 changes: 1 addition & 5 deletions tests/torchtune/datasets/test_preference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from unittest import mock

import pytest
from tests.test_utils import DummyPromptTemplate, DummyTokenizer
from tests.test_utils import DummyTokenizer
from torchtune.data import Message
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets._preference import PreferenceDataset
Expand Down Expand Up @@ -64,21 +64,18 @@ def expected(self):
return {
"prompt": [
0,
5,
4,
2,
4,
],
"chosen": [
10,
3,
6,
2,
2,
-1,
],
"rejected": [
10,
3,
6,
2,
Expand All @@ -103,7 +100,6 @@ def test_get_item(self, mock_load_dataset, dialogue, expected):
source="iam/agoofy/goober",
message_transform=ToDummyPreferenceMessages(),
tokenizer=DummyTokenizer(),
prompt_template=DummyPromptTemplate(),
)
assert len(ds) == 1
mock_load_dataset.assert_called_once()
Expand Down
12 changes: 1 addition & 11 deletions tests/torchtune/datasets/test_samsum_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ def test_label_no_masking(self, load_dataset, tokenizer):

assert input == [
0,
9,
4,
9,
7,
1,
5,
Expand All @@ -61,8 +58,6 @@ def test_label_no_masking(self, load_dataset, tokenizer):
3,
8,
3,
3,
8,
6,
5,
7,
Expand Down Expand Up @@ -100,9 +95,6 @@ def test_label_masking(self, load_dataset, tokenizer):

assert input == [
0,
9,
4,
9,
7,
1,
5,
Expand All @@ -119,8 +111,6 @@ def test_label_masking(self, load_dataset, tokenizer):
3,
8,
3,
3,
8,
6,
5,
7,
Expand All @@ -132,4 +122,4 @@ def test_label_masking(self, load_dataset, tokenizer):
9,
-1,
]
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 22
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 17
Loading

0 comments on commit 367e9ab

Please sign in to comment.