Skip to content

Commit

Permalink
precompute collapsed linear transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
thwu1 committed Jul 14, 2024
1 parent 3a85b9e commit cc67b87
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
35 changes: 28 additions & 7 deletions routellm/routers/matrix_factorization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ def __init__(
text_dim,
num_classes,
use_proj,
collapse_linear=False,
):
super().__init__()
self._name = "TextMF"
self.use_proj = use_proj
self.collapse_linear = collapse_linear # collapse the linear transformations into a single linear layer
self.P = torch.nn.Embedding(num_models, dim)

self.embedding_model = "text-embedding-3-small"
Expand All @@ -104,19 +105,20 @@ def get_device(self):
return self.P.weight.device

def forward(self, model_id, prompt):
model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device())

model_embed = self.P(model_id)
model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1)

prompt_embed = (
OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model)
.data[0]
.embedding
)
prompt_embed = torch.tensor(prompt_embed, device=self.get_device())
prompt_embed = self.text_proj(prompt_embed)
model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device())

if self.collapse_linear:
upscaled_model_embed = self.precompute_upscaled_embedding(model_id)
return upscaled_model_embed @ prompt_embed.squeeze(-1)

model_embed = self.P(model_id)
prompt_embed = self.text_proj(prompt_embed)
return self.classifier(model_embed * prompt_embed).squeeze()

@torch.no_grad()
Expand All @@ -127,3 +129,22 @@ def pred_win_rate(self, model_a, model_b, prompt):

def load(self, path):
self.load_state_dict(torch.load(path))

def post_process_weight(self):
# since the current model consist of only linear transformations
# we can collapse the linear transformations into a single linear layer
# https://github.com/lm-sys/RouteLLM/issues/9
num_models = self.P.weight.shape[0]
text_dim = self.text_proj[0].weight.shape[1]

self.P.weight.data = torch.nn.functional.normalize(
self.P.weight.data, p=2, dim=1
)

if self.collapse_linear:
self.precompute_upscaled_embedding = torch.nn.Embedding(
num_models, text_dim
)
self.precompute_upscaled_embedding.weight.data = (
self.P.weight * self.classifier[0].weight.data
) @ self.text_proj[0].weight.data
1 change: 1 addition & 0 deletions routellm/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __init__(
num_classes=num_classes,
use_proj=use_proj,
)
self.model.post_process_weight()
self.model = self.model.eval().to(device)
self.strong_model_id = MODEL_IDS[strong_model]
self.weak_model_id = MODEL_IDS[weak_model]
Expand Down

0 comments on commit cc67b87

Please sign in to comment.