-
Notifications
You must be signed in to change notification settings - Fork 416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] MOE design in Torchtune #1902
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1902
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit ed424d8 with merge base dc0591c (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
3c65a3c
to
ecd8e5f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't have chance to finish this yet, but I left some early comments/questions. Thanks for the detailed RFC!
RFC.md
Outdated
return top_scores, top_indices | ||
|
||
|
||
class ExpertChoiceMoeLayer(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a comment on naming, we use layer at the decoder level. So I think this should just be something like MoEForward or MixtureOfExperts.
Do you think it would be possible to reuse this layer for both Token and Expert routing modules? The experts, shared_expert, and router are passed in inside init anyway. Or is there not enough in common between them?
If this module can't be generalized I wonder if it would make more sense to combine it with the router logic instead of having them be separate classes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
I think MoeLayer / MoeBlock is a more common name, as I saw almost all MOE papers/OSS code call it this way. It also replaces the original FF layer / MLP layer.
-
at least in my mind it's hard to reuse the same MoeLayer for both token and expert routings, since the MoeLayer's forward function is where we route tokens into different experts based on either token choice or expert choice routing.
For token choice, we are looping over each experts, and calculate hidden states for each selected expert.
For expert choice, not all tokens will be chosen, we are gathering only the selected tokens and run experts forward on them.
Please let me know if you have a better way in mind to combine these two! -
We can combine router with MoeLayer(in fact it is combined during my first draft), but I think decoupling router, experts, MoeLayer would also be nice.
```python | ||
class Experts(nn.Module): | ||
def __init__(self, dim_in, dim_out, nonlinearity, num_experts=1, swiglu=True): | ||
self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the implications of defining your experts all as a single parameter? Noob thoughts here; I think it makes sense that it would be faster for training since all experts are updated each step, but at inference time wouldn't you want to have separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(not an expert, but)I think the reason of combining all the experts is to use torch.bmm() and groupedGEMM instead of torch.matmul() to optimize performance.
Good question, for inference we are always doing TC even if training uses EC. And for TC forward, we are essentially looping over the experts one at a time.
|
||
**Here's the proposed Experts design in torchtune:** | ||
```python | ||
class Experts(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming, I think this should be something like MoELinear
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious what's the reason? I was thinking having Experts, Routers, and MoELayers is more clear in terms of what each class is doing. Otherwise, there's no Experts class and this MoELinear
is essentially Experts? I don't have any strong preferences, but just want to know why you think MoELinear is better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It matches our own convention of defining FeedForward (singular) rather than a group of things like Experts.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But here we are defining a group of experts/feedforward. Wouldn't MoELinear
be confusing? It's more like GroupedMoELinear
?
RFC.md
Outdated
# expert_mask shape [num_experts, experts_per_token, bs*slen] | ||
expert_mask = torch.nn.functional.one_hot(selected_experts_indices, num_class=num_experts).permute(2,1,0) | ||
out = torch.zeros((batch_size * seq_len, hidden_dim)) | ||
for i in range(num_experts): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this not be done in the broadcast format that Expert Choice is?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for raising this. After some thoughts, I think it can be done in a way similar to EP. But it then requires special TC Forward(defined separately in the Expert forward function). I updated the RFC to have both approaches!
Thanks for the awesome RFC, really appreciate the detail in the code implementations. One concern I have is that token choice vs expert choice is conceptually spread across both the router and the experts classes, making them pretty tightly coupled, i.e., if I use TokenChoiceRouter I have to make sure I set I would try to either make the experts entirely routing agnostic (not sure if this is possible, based on your code it seems to affect the forward quite significantly), make separate expert classes for token choice / expert choice, or just combine the expert forward logic and routing logic into one class.
This is a great point, wouldn't the inference behavior be entirely different anyway? How will we handle different inference logic and potential optimizations? |
Thanks for the suggestion. It makes sense! unfortunately I think expert choice and token choice needs to have different forward impl. This is because for expert choice, each expert has I think making them into separate expert classes is reasonable. So we will have
Yeah also thanks for raising this question. I didn't keep inference in mind during the first draft design. I am discussing/consulting with Jie more about MOE inference. |
if router == "token_choice": | ||
self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) | ||
elif router == "expert_choice": | ||
self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) | ||
else: | ||
raise NotImplementedError("This router is not supported yet!") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a minor point, but I wouldn't configure router type using strings. Just pass as an nn.Module as that's more generic (then anyone with their own custom router matching the same signature can use out of the box)
# - Final projection into the token space' | ||
token_embeddings = nn.Embedding(vocab_size, embed_dim) | ||
self_attn = MultiHeadAttention() | ||
moe_layer = MoeLayer(router="token_choice") # or MoeLayer(router="expert_choice") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would consider an moe_layer
builder function as well to improve ease of configurability
else: | ||
scores = F.softmax(scores.to(softmax_dtype), dim=0).to(x.dtype) | ||
# router scores/indices shape [num_experts, tokens_per_expert] | ||
top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Discussed offline a bit but we should make sure there isn't any information leakage happening here for decoder models. Specifically, this line seems a bit fishy to me. It seems like each expert is looking across all tokens and choosing its top tokens_per_expert based on the entire sequence, right?
[RFC] MOE design in Torchtune
Background
This RFC proposes adding the MOE support in Torchtune. We want to design in a general way so that components can be easily swapped when implementing different MOE models. An MOE layer directly replaces the dense FFN layer in the transformer decoder layer and has two main components: router and experts.
Expert
An expert is essentially an FFN layer similar to the original dense FFN layer in the transformer decoder layer. There are two kinds of experts: routed experts and shared experts. Each expert in the routed experts specializes in learning certain patterns/aspects, and only part of the routed experts will be activated. On the other hand, shared experts are always activated, aiming at capturing and consolidating common knowledge across varying contexts.
Here's the proposed Experts design in torchtune:
Router
Router is a gating network that calculates router scores and learns token-to-expert affinity. There are two types of routing: token choice routing and expert choice routing.
Mixtral uses token choice topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The router then defines how tokens select experts based on router scores.
Here's the proposed Token Choice Routing design in torchtune:
However, token choice routing has several pitfalls according to the expert choice paper.
Compared to token choice, expert choice topK routing lets experts select its top-k tokens. The ExpertChoiceTopKRouter class routes input tokens to different experts based on the router scores.
Here's the proposed Expert Choice Routing design in torchtune:
Moe Layer
An MOE layer consists of experts and routers.
Here's the proposed MoeLayer design in torchtune:
Model builder
Besides the above components: experts, routers, and MOE layers, we would need a model builder to pull all pieces together to form the Transformer decoder layer and then Transformer decoder:
Here's the proposed MOE model builder design in torchtune:
File changes for new modules/functions