[RFC] Early fusion multimodal models #1904
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1904
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures, 2 Cancelled Jobs as of commit 024bfc7 with merge base d3039da.
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for putting this up Rafi! I left some comments on the implementation, but I'll leave the state dict discussion to others as we've already chatted on this.
def __init__(
    self,
    decoder: TransformerDecoder,
    encoders: nn.ModuleDict,
nit: I think it would be nice if we allowed all of the encoder params to be a list/dict or single value input to make single encoder builders look much cleaner. Then we can package them as an iterable in the init.
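For reference, a minimal sketch of that normalization, assuming a hypothetical helper name (not code from this PR):

from typing import Dict, Union

import torch.nn as nn


def _wrap_encoders(
    encoders: Union[nn.Module, Dict[str, nn.Module]],
    default_key: str = "encoder",
) -> nn.ModuleDict:
    """Accept a single encoder or a dict of encoders and always return a ModuleDict."""
    if isinstance(encoders, nn.ModuleDict):
        # Already in the canonical form.
        return encoders
    if isinstance(encoders, nn.Module):
        # Single-encoder builders can pass the module directly.
        return nn.ModuleDict({default_key: encoders})
    return nn.ModuleDict(encoders)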
    encoders: nn.ModuleDict,
    encoder_tokens: Dict[str, int],
    decoder_trainable: bool,
    encoders_trainable: Dict[str, bool],
An error should be thrown if the different input dicts don't have the same keys
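Something along these lines in __init__ would catch that early (a sketch, not the PR's code; the helper name is made up):

from typing import Dict

import torch.nn as nn


def _validate_same_keys(
    encoders: nn.ModuleDict,
    encoder_tokens: Dict[str, int],
    encoders_trainable: Dict[str, bool],
) -> None:
    """Raise if the per-encoder dicts disagree on modality names."""
    key_sets = [
        set(encoders.keys()),
        set(encoder_tokens.keys()),
        set(encoders_trainable.keys()),
    ]
    if not all(keys == key_sets[0] for keys in key_sets):
        raise ValueError(
            "encoders, encoder_tokens, and encoders_trainable must share the same "
            f"keys, but got {[sorted(keys) for keys in key_sets]}"
        )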
if decoder_trainable:
    trainable_params |= {
        f"decoder.{n}" for n, p in self.decoder.named_parameters()
    }
This is missing the logic and the parameter to make fusion modules trainable/untrainable
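The kind of addition this points at might look like the sketch below, reusing the fusion_trainable flag and get_fusion_params helper that show up later in this PR (the flag name and import paths are assumptions on my part):

from torchtune.modules.model_fusion import get_fusion_params
from torchtune.modules.peft import set_trainable_params

# Sketch of the trainable-params logic inside __init__, extended to cover
# fusion modules in addition to the decoder and the encoders.
trainable_params = set()
if decoder_trainable:
    trainable_params |= {
        f"decoder.{n}" for n, p in self.decoder.named_parameters()
    }
for name, is_trainable in encoders_trainable.items():
    if is_trainable:
        trainable_params |= {
            f"encoders.{name}.{n}"
            for n, p in self.encoders[name].named_parameters()
        }
if fusion_trainable:
    trainable_params |= set(get_fusion_params(self))
else:
    trainable_params -= set(get_fusion_params(self))
set_trainable_params(self, trainable_params)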
been expanded to the number of tokens encoded for the given media. For example, if an image is tiled/patched
and tokenized to 100 tokens, we assume the text sequence already has 100 "image" tokens as placeholders.
"""
embeds = self.tok_embeddings(tokens)
You can't do this because the encoder tokens won't be in the tok_embeddings table. You need to first filter those out as in here https://www.internalfb.com/intern/paste/P1666298928/
Yeah, let's talk about this offline, because in the reference code I was using, the encoder tokens are part of the embedding table.
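For context, a rough sketch of the filtering idea from the first comment (the linked paste is internal, so this is only an approximation of the approach, not the code it contains):

import torch

# tokens: [batch, seq] token ids with special encoder tokens (e.g. an "image" id)
# mixed in. Those ids may not exist in the embedding table, so replace them with a
# valid placeholder id before the lookup; their positions get overwritten with
# encoder outputs afterwards.
is_encoder_token = torch.zeros_like(tokens, dtype=torch.bool)
for token_id in self.encoder_tokens.values():
    is_encoder_token |= tokens == token_id

safe_tokens = tokens.masked_fill(is_encoder_token, 0)
embeds = self.tok_embeddings(safe_tokens)  # [batch, seq, dim]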
>>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)}
>>>
>>> # EarlyFusionModel combines the encoder and decoder
>>> model = DeepFusionModel(decoder, encoders)
nit: EarlyFusionModel
Thanks for the RFC, you made the difference between early fusion and late fusion very clear! About the design choice, I personally prefer Option 2 for the same reason you mentioned. I think it's fine to "pollute" the decoder model forward a bit with some optional arguments for each modality. We might need something like...
# module into TransformerDecoder builder that does the
# merging there
self.tok_embeddings = decoder.tok_embeddings
decoder.tok_embeddings = nn.Identity()
If a user wants to use the decoder outside of the EarlyFusionModule in the same script, they will need to restore the original tok_embeddings module from nn.Identity.
Why is this a concern? We only set decoder.tok_embeddings to identity within EarlyFusionModule. If outside EarlyFusionModule, we just use the normal decoder, right?
Right, but if I pass the decoder into EarlyFusionModule and it modifies a layer here, then even if I use the decoder separately, its embedding layer will have been modified.
if fusion_trainable:
    trainable_params |= set(get_fusion_params(self))
else:
    trainable_params -= set(get_fusion_params(self))
Curious, why would we need this?
Not sure yet if we need this... once we have a better idea of the full architecture let's chat with @pbontrager to see which components need to be fusion modules
11th hour comment on the open design question: in my mind there are nonzero UX costs to either approach. If we patch the decoder embeddings to nn.Identity, we then need state dict hooks so that checkpoints still save and load with the expected keys.

Personally I really don't like state dict hooks for the reason I described above. As soon as something (inevitably) goes wrong, it will take a lot more debugging and head-banging-against-the-wall before the user realizes that things are being swapped out under the hood. So perhaps it's no surprise, but I vote for the simple and dumb thing: just add an extra parameter to TransformerDecoder forward. I know that may be controversial, but I like doing the obvious thing, and I like to think our users would appreciate that as well.
TODO: fix tests
Context
This is a focused RFC based on @pbontrager's excellent original RFC on multimodal fusion models #1283. Since that RFC, we have already landed the Deep Fusion model components. This PR discusses and implements the EarlyFusionModel component, along with testing and some lint updates.
Early fusion is simply a decoder with one or more extra encoders whose outputs are merged with the decoder's token embeddings. The challenge lies in how we merge the embeddings and pass them into the decoder.
Design
There is one design consideration I am seeking feedback on, and that is the EarlyFusionModel's usage of self.decoder.tok_embeddings. It accesses the decoder's token embedding table outside of the decoder forward because we need to merge the image encoder's (and any other modality encoder's) output embeddings with the text embeddings, in this case just combining them in the sequence dimension (see the sketch below).

Now, instead of token ids, we are passing the merged embeddings directly into the decoder. But since we already used the text-only tok_embeddings from the decoder, we need to skip it when passing in the merged embeddings for the final decoder output. There are two ways we can do this.
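One way to picture the merge, following the placeholder-token convention described in the docstring above (a sketch with made-up shapes and a single image encoder, not the exact code in this PR):

import torch

batch, seq, dim, num_img_tokens = 2, 8, 16, 3

# Text embeddings; positions 2..4 are "image" placeholder tokens in this toy example.
embeds = torch.randn(batch, seq, dim)
# Image encoder output: one embedding per placeholder position.
encoder_embeds = torch.randn(batch, num_img_tokens, dim)
# True wherever the sequence holds an image placeholder token.
encoder_mask = torch.zeros(batch, seq, dtype=torch.bool)
encoder_mask[:, 2 : 2 + num_img_tokens] = True

# Scatter the encoder embeddings into the placeholder slots; the merged tensor is
# what gets passed to the decoder instead of token ids.
merged = embeds.masked_scatter(
    encoder_mask.unsqueeze(-1).expand_as(embeds), encoder_embeds
)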
State dict surgery
In the current code changes, and as suggested by the original RFC, we can manually set self.decoder.tok_embeddings = nn.Identity() so that it becomes a no-op when you forward pass with the merged embeddings.
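To keep checkpoint save/load lined up with the original decoder keys, this approach would pair the identity swap with state dict hooks. The sketch below is my reading of how that could look using PyTorch's state dict hook methods; it is not the code in this PR, and the key remapping details are assumptions:

import torch.nn as nn


class EarlyFusionSketch(nn.Module):
    """Toy stand-in for EarlyFusionModel, only to illustrate the key remapping idea."""

    def __init__(self, decoder: nn.Module):
        super().__init__()
        # Take ownership of the embedding table and no-op the decoder's copy.
        self.tok_embeddings = decoder.tok_embeddings
        decoder.tok_embeddings = nn.Identity()
        self.decoder = decoder
        # Hooks so checkpoints keep the original "decoder.tok_embeddings.*" keys.
        self._register_state_dict_hook(EarlyFusionSketch._save_hook)
        self._register_load_state_dict_pre_hook(
            EarlyFusionSketch._load_hook, with_module=True
        )

    @staticmethod
    def _save_hook(module, state_dict, prefix, local_metadata):
        # Rename wrapper keys back to the decoder's original namespace on save.
        old = f"{prefix}tok_embeddings."
        new = f"{prefix}decoder.tok_embeddings."
        for key in list(state_dict.keys()):
            if key.startswith(old):
                state_dict[new + key[len(old):]] = state_dict.pop(key)

    @staticmethod
    def _load_hook(module, state_dict, *args):
        # Undo the renaming so checkpoints saved with decoder keys still load.
        for key in list(state_dict.keys()):
            if "decoder.tok_embeddings." in key:
                new_key = key.replace("decoder.tok_embeddings.", "tok_embeddings.")
                state_dict[new_key] = state_dict.pop(key)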
Additional input_embeds kwarg
We could add a new keyword argument in TransformerDecoder forward for input embeddings. If this is passed in, we automatically skip the token embeddings lookup (see the sketch below). This way we don't need any state dict hooks or decoder modifications. However, we are polluting the decoder model forward with more arguments.
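A minimal sketch of what that could look like; the parameter name input_embeds comes from this option's description, while the body is illustrative and omits the usual mask/input_pos plumbing:

from typing import Optional

import torch


def forward(
    self,
    tokens: Optional[torch.Tensor] = None,
    *,
    input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Skip the embedding lookup when the caller has already merged multimodal
    # embeddings; otherwise behave exactly like the text-only path.
    h = input_embeds if input_embeds is not None else self.tok_embeddings(tokens)
    for layer in self.layers:
        h = layer(h)
    return self.output(self.norm(h))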