Model transform docs (#1665)
RdoubleA authored Sep 25, 2024
1 parent 18efc81 commit c3ff864
Showing 3 changed files with 172 additions and 1 deletion.
170 changes: 170 additions & 0 deletions docs/source/basics/model_transforms.rst
@@ -0,0 +1,170 @@
.. _model_transform_usage_label:

=====================
Multimodal Transforms
=====================

Multimodal model transforms apply model-specific data transforms to each modality and prepare :class:`~torchtune.data.Message`
objects to be input into the model. torchtune currently supports text + image model transforms.
These are intended to be drop-in replacements for tokenizers in multimodal datasets and support the standard
``encode``, ``decode``, and ``tokenize_messages`` methods.

.. code-block:: python

    # torchtune.models.flamingo.FlamingoTransform
    class FlamingoTransform(ModelTokenizer, Transform):
        def __init__(...):
            # Text transform - standard tokenization
            self.tokenizer = llama3_tokenizer(...)
            # Image transforms
            self.transform_image = CLIPImageTransform(...)
            self.xattn_mask = VisionCrossAttentionMask(...)

.. code-block:: python

    from torchtune.models.flamingo import FlamingoTransform
    from torchtune.data import Message
    from PIL import Image

    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "text", "content": "What is common in these two images?"},
                ],
            ),
            Message(
                role="assistant",
                content="A robot is in both images.",
            ),
        ],
    }
    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    tokenized_dict = transform(sample)

    print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
    # '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>'
    print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
    # torch.Size([4, 3, 224, 224])
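
Because the transform exposes the standard tokenizer methods, you can also call ``tokenize_messages``
on it directly, outside of a dataset. A minimal sketch, reusing the ``sample`` and ``transform`` defined above:

.. code-block:: python

    # Tokenize the message list directly - this returns the token IDs and the
    # per-token mask, just as a text-only model tokenizer would.
    tokens, mask = transform.tokenize_messages(sample["messages"])
    print(len(tokens) == len(mask))
    # True
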
Using model transforms
----------------------
You can pass a model transform into any multimodal dataset builder just as you would a model tokenizer.

.. code-block:: python

    from torchtune.datasets.multimodal import the_cauldron_dataset
    from torchtune.models.flamingo import FlamingoTransform

    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    ds = the_cauldron_dataset(
        model_transform=transform,
        subset="ai2d",
    )
    tokenized_dict = ds[0]

    print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
    # <|begin_of_text|><|start_header_id|>user<|end_header_id|>
    #
    # <|image|>Question: What do respiration and combustion give out
    # Choices:
    # A. Oxygen
    # B. Carbon dioxide
    # C. Nitrogen
    # D. Heat
    # Answer with the letter.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #
    # Answer: B<|eot_id|>
    print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
    # torch.Size([4, 3, 224, 224])
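
The returned dataset is a standard map-style dataset, so it can be wrapped in a PyTorch ``DataLoader``
like any other. The collate function below is only a placeholder for illustration (it passes the single
sample through unchanged), not torchtune's multimodal collate; a real recipe would pad token sequences
and batch image tiles:

.. code-block:: python

    from torch.utils.data import DataLoader

    # Placeholder collate for a batch size of 1: return the tokenized sample as-is.
    loader = DataLoader(ds, batch_size=1, collate_fn=lambda batch: batch[0])
    for tokenized_dict in loader:
        print(list(tokenized_dict.keys()))  # e.g. ['tokens', ..., 'encoder_input']
        break
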
Creating model transforms
-------------------------
Model transforms are expected to process both text and images in the sample dictionary.
Both should be contained in the ``"messages"`` field of the sample.

The following methods are required on the model transform:

- ``tokenize_messages``
- ``__call__``

.. code-block:: python

    from torchtune.modules.tokenizers import ModelTokenizer
    from torchtune.modules.transforms import Transform

    class MyMultimodalTransform(ModelTokenizer, Transform):
        def __init__(...):
            self.tokenizer = my_tokenizer_builder(...)
            self.transform_image = MyImageTransform(...)

        def tokenize_messages(
            self,
            messages: List[Message],
            add_eos: bool = True,
        ) -> Tuple[List[int], List[bool]]:
            # Any other custom logic here
            ...
            return self.tokenizer.tokenize_messages(
                messages=messages,
                add_eos=add_eos,
            )

        def __call__(
            self, sample: Mapping[str, Any], inference: bool = False
        ) -> Mapping[str, Any]:
            # Expected input parameters for vision encoder
            encoder_input = {"images": [], "aspect_ratio": []}
            messages = sample["messages"]

            # Transform all images in sample
            for message in messages:
                for image in message.get_media():
                    out = self.transform_image({"image": image}, inference=inference)
                    encoder_input["images"].append(out["image"])
                    encoder_input["aspect_ratio"].append(out["aspect_ratio"])
            sample["encoder_input"] = encoder_input

            # Transform all text - returns same dictionary with additional keys "tokens" and "mask"
            sample = self.tokenizer(sample, inference=inference)
            return sample

    transform = MyMultimodalTransform(...)
    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "text", "content": "What is common in these two images?"},
                ],
            ),
            Message(
                role="assistant",
                content="A robot is in both images.",
            ),
        ],
    }
    tokenized_dict = transform(sample)
    print(tokenized_dict)
    # {'encoder_input': {'images': ..., 'aspect_ratio': ...}, 'tokens': ..., 'mask': ...}
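
Once your transform implements ``tokenize_messages`` and ``__call__``, it can be passed to the multimodal
dataset builders exactly like :class:`~torchtune.models.flamingo.FlamingoTransform` above. A brief sketch
reusing the ``MyMultimodalTransform`` from this example:

.. code-block:: python

    from torchtune.datasets.multimodal import the_cauldron_dataset

    # The custom transform drops in wherever a model transform is expected.
    ds = the_cauldron_dataset(
        model_transform=MyMultimodalTransform(...),
        subset="ai2d",
    )
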
Example model transforms
------------------------

- Flamingo

  - :class:`~torchtune.models.flamingo.FlamingoTransform`
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -120,6 +120,7 @@ torchtune tutorials.
basics/tokenizers
basics/prompt_templates
basics/preference_datasets
basics/model_transforms

.. toctree::
:glob:
2 changes: 1 addition & 1 deletion torchtune/models/flamingo/_transform.py
@@ -27,7 +27,7 @@ class FlamingoTransform(ModelTokenizer, Transform):
Args:
path (str): Path to pretrained tiktoken tokenizer file.
-        tile_size (int): Size of the tiles to divide the image into. Default 224.
+        tile_size (int): Size of the tiles to divide the image into.
patch_size (int): Size of the patches used in the CLIP vision transformer model. This is
used to calculate the number of image embeddings per image.
max_num_tiles (int): Only used if possible_resolutions is NOT given.
