Deep Fusion Modules (#1338)
pbontrager authored Aug 16, 2024
1 parent b74f4b4 commit 67f6a06
Showing 9 changed files with 839 additions and 0 deletions.
15 changes: 15 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -68,6 +68,21 @@ PEFT Components
peft.validate_state_dict_for_lora
peft.disable_adapter


Fusion Components
-----------------
Components for building models that are a fusion of two or more pre-trained models.

.. autosummary::
:toctree: generated/
:nosignatures:

model_fusion.DeepFusionModel
model_fusion.FusionLayer
model_fusion.FusionEmbedding
model_fusion.register_fusion_module
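
As a quick illustration of how these components compose (an editor's sketch, not part of this commit; the nn.Linear modules stand in for real transformer layers, and the printed key names are inferred from the tests in this PR):

# Hedged sketch of composing the new fusion components.
import torch
from torch import nn
from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer

dim, vocab_size, fusion_vocab_size = 64, 100, 8

# Extend a pre-trained vocabulary with extra trainable "fusion" tokens.
embeddings = FusionEmbedding(
    vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim
)
tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10))
assert embeddings(tokens).shape == (2, 10, dim)

# Wrap an original (frozen) layer together with a new trainable layer.
original_layer = nn.Linear(dim, dim)  # stand-in for a pre-trained decoder layer
new_layer = nn.Linear(dim, dim)       # stand-in for an added fusion layer
fused = FusionLayer(original_layer, new_layer)  # pass fusion_first=False to run it after

# Only the added layer's parameters are reported as fusion params,
# so they can be trained while the original weights stay frozen.
print(fused.fusion_params())  # e.g. ["fusion_layer.weight", "fusion_layer.bias"]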


Module Utilities
------------------
These are utilities that are common to and can be used by all modules.
5 changes: 5 additions & 0 deletions tests/torchtune/modules/model_fusion/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
90 changes: 90 additions & 0 deletions tests/torchtune/modules/model_fusion/test_fusion_embed.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected, fixed_init_model
from torchtune.modules.model_fusion import FusionEmbedding
from torchtune.utils.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(1)


class TestFusionEmbedding:
"""
Class for testing our FusionEmbedding.
"""

@pytest.fixture
def dim(self) -> int:
return 2

@pytest.fixture
def vocab_size(self) -> int:
return 10

@pytest.fixture
def fusion_vocab_size(self) -> int:
return 5

@pytest.fixture
def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding:
embeds = FusionEmbedding(
vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim
)
fixed_init_model(embeds.embedding, min_val=0, max_val=0.5)
fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1)
return embeds

@torch.no_grad()
def test_forward(self, embed, vocab_size, fusion_vocab_size, dim):
"""
Test that the forward pass of the FusionEmbedding works as expected.
"""
tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10))
out = embed(tokens)

assert out.shape == (2, 10, dim)
assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3)

# Only new tokens, embeddings should be > 0.5
tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10))
out = embed(tokens)

assert out.min() > 0.5

# Only old tokens, embeddings should be < 0.5
tokens = torch.randint(0, vocab_size, (2, 10))
out = embed(tokens)

assert out.max() < 0.5

def test_fusion_params(self, embed):
"""
Test that the correct fusion params are returned.
"""
fusion_params = set(embed.fusion_params())

assert fusion_params == {"fusion_embedding.weight"}

def test_get_and_load_state_dict(self, embed):
"""
Test that the state dict hooks remove the "embedding" prefix from the keys.
"""
state_dict = embed.state_dict()
state_keys = set(state_dict.keys())

assert state_keys == {
"weight",
"fusion_embedding.weight",
}

# Check that the state_dict can be loaded back into the model
embed.load_state_dict(state_dict)
121 changes: 121 additions & 0 deletions tests/torchtune/modules/model_fusion/test_fusion_layer.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected, fixed_init_model
from torch import nn
from torchtune.modules.model_fusion import FusionLayer
from torchtune.utils.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(1)


class DummyLayer(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)
self.cache_enabled = False

def setup_cache(self, batch_size, dtype):
self.cache_enabled = True

def reset_cache(self):
self.cache_enabled = False

def forward(self, x):
return self.linear(x)


class TestFusionLayer:
"""
Class for testing our FusionLayer wrapper.
"""

@pytest.fixture
def dim(self) -> int:
return 2

@pytest.fixture
def layer(self, dim) -> nn.Module:
layer = DummyLayer(dim)
fixed_init_model(layer, min_val=-0.1, max_val=0.1)
return layer

@pytest.fixture
def fusion_layer(self, dim) -> nn.Module:
layer = DummyLayer(dim)
fixed_init_model(layer, min_val=-0.2, max_val=0.2)
return layer

@pytest.fixture
def fused_layer(self, layer, fusion_layer) -> FusionLayer:
return FusionLayer(layer, fusion_layer)

@torch.no_grad()
def test_forward(self, fused_layer, dim):
"""
Test that the forward pass of the FusionLayer works as expected.
"""
x = torch.rand((1, dim))
out = fused_layer(x)

assert out.shape == (1, dim)
assert_expected(out.mean(), torch.tensor(-0.0316), atol=1e-3, rtol=1e-3)

@torch.no_grad()
def test_fusion_last_forward(self, layer, fusion_layer, dim):
"""
Test the forward method with fusion_first=False.
"""
fused_layer = FusionLayer(layer, fusion_layer, fusion_first=False)

x = torch.rand((1, dim))
out = fused_layer(x)

assert out.shape == (1, dim)
assert_expected(out.mean(), torch.tensor(-0.0816), atol=1e-3, rtol=1e-3)

def test_get_and_load_state_dict(self, fused_layer):
"""
Test that the state dict hooks remove the "layer" prefix from the keys.
"""
state_dict = fused_layer.state_dict()
state_keys = set(state_dict.keys())

assert state_keys == {
"linear.weight",
"linear.bias",
"fusion_layer.linear.weight",
"fusion_layer.linear.bias",
}

# Check that the state_dict can be loaded back into the model
fused_layer.load_state_dict(state_dict)

def test_fusion_params(self, fused_layer):
"""
Test that the correct fusion params are returned.
"""
fusion_params = set(fused_layer.fusion_params())

assert fusion_params == {
"fusion_layer.linear.weight",
"fusion_layer.linear.bias",
}

def test_setup_cache(self, fused_layer):
"""
Test that the cache methods work as expected.
"""
fused_layer.setup_cache(2, torch.float32)
assert fused_layer.cache_enabled
fused_layer.reset_cache()
assert not fused_layer.cache_enabled
147 changes: 147 additions & 0 deletions tests/torchtune/modules/model_fusion/test_fusion_models.py
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected, fixed_init_model
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.utils.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(1)


class DummyModel(nn.Module):
def __init__(self, dim, vocab_size):
super().__init__()
self.cache_enabled = False
self.embed = nn.Embedding(vocab_size, dim)
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.output = nn.Linear(dim, vocab_size)

def setup_caches(self, batch_size, dtype):
self.cache_enabled = True

def caches_are_enabled(self):
return self.cache_enabled

def reset_caches(self):
self.cache_enabled = False

def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos):
x = self.embed(tokens)
if encoder_input is not None:
q = self.q(x)
k = self.k(encoder_input)
v = self.v(encoder_input)
x += nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=encoder_mask
)
x = self.output(x)
return x


class TestDeepFusionModel:
"""
Class for testing our DeepFusionModel wrapper.
"""

@pytest.fixture
def vocab_size(self) -> int:
return 100

@pytest.fixture
def dim(self) -> int:
return 64

@pytest.fixture
def encoder(self, dim, vocab_size) -> nn.Module:
encoder = nn.Embedding(vocab_size, dim)
fixed_init_model(encoder)
return encoder

@pytest.fixture
def decoder(self, dim, vocab_size) -> nn.Module:
decoder = DummyModel(dim, vocab_size)
fixed_init_model(decoder, max_val=0.1)
return decoder

@pytest.fixture
def fused_model(self, encoder, decoder) -> DeepFusionModel:
model = DeepFusionModel(
encoder=encoder,
decoder=decoder,
)
return model

@pytest.fixture
def inputs(self, dim, vocab_size):
batch_size = 2
seq_len = 10
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
encoder_input = {"input": torch.randint(0, vocab_size, (batch_size, seq_len))}
encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool()
input_pos = torch.Tensor([1]).int()
return tokens, encoder_input, encoder_mask, input_pos

@torch.no_grad()
def test_forward(self, fused_model, inputs, vocab_size):
"""
Test that the forward pass of the DeepFusionModel works as expected.
"""
tokens, encoder_input, encoder_mask, _ = inputs
batch_size, seq_len = tokens.shape
out = fused_model(
tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
)

assert out.shape == (batch_size, seq_len, vocab_size)
assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3)

@torch.no_grad()
def test_forward_no_encoding(self, fused_model, inputs, vocab_size):
"""
Test the forward pass of the DeepFusionModel with no encoder input.
"""
tokens, *_ = inputs
batch_size, seq_len = tokens.shape
out = fused_model(tokens)

assert out.shape == (batch_size, seq_len, vocab_size)
assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3)

@torch.no_grad()
def test_decoding_forward(self, fused_model, inputs, vocab_size):
"""
Test that the forward pass of the DeepFusionModel works during decoding.
"""
tokens, encoder_input, encoder_mask, input_pos = inputs
tokens = tokens[:, input_pos]
batch_size, seq_len = tokens.shape
out = fused_model(
tokens,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)

assert out.shape == (batch_size, seq_len, vocab_size)
assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3)

def test_setup_cache(self, fused_model):
"""
Test that the cache methods work as expected.
"""
fused_model.setup_caches(2, torch.float32)
assert fused_model.caches_are_enabled()
fused_model.reset_caches()
assert not fused_model.caches_are_enabled()
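
For reference, here is a minimal decoding sketch (an editor's illustration, not part of this commit) that drives the wrapper the same way test_decoding_forward does; it reuses the DummyModel class defined above and the keyword-dict convention for encoder_input from the fixtures:

import torch
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel

dim, vocab_size = 64, 100
encoder = nn.Embedding(vocab_size, dim)  # stand-in for a pre-trained encoder
decoder = DummyModel(dim, vocab_size)    # DummyModel as defined in the test above
model = DeepFusionModel(encoder=encoder, decoder=decoder)

# Enable the decoder's caches, then decode a single position.
model.setup_caches(2, torch.float32)
tokens = torch.randint(0, vocab_size, (2, 10))
encoder_input = {"input": torch.randint(0, vocab_size, (2, 10))}
encoder_mask = torch.randint(0, 2, (2, 10, 10)).bool()
input_pos = torch.Tensor([1]).int()

out = model(
    tokens[:, input_pos],
    encoder_input=encoder_input,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
assert out.shape == (2, 1, vocab_size)  # logits for the single decoded position
model.reset_caches()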
