Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] MOE design in Torchtune #1902

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

acisseJZhong
Copy link

@acisseJZhong acisseJZhong commented Oct 24, 2024

[RFC] MOE design in Torchtune

Background

This RFC proposes adding the MOE support in Torchtune. We want to design in a general way so that components can be easily swapped when implementing different MOE models. An MOE layer directly replaces the dense FFN layer in the transformer decoder layer and has two main components: router and experts.

Expert

An expert is essentially an FFN layer similar to the original dense FFN layer in the transformer decoder layer. There are two kinds of experts: routed experts and shared experts. Each expert in the routed experts specializes in learning certain patterns/aspects, and only part of the routed experts will be activated. On the other hand, shared experts are always activated, aiming at capturing and consolidating common knowledge across varying contexts.

Here's the proposed Experts design in torchtune:

class Experts(nn.Module):
    def __init__(self, dim_in, dim_out, num_experts=1, swiglu=True, nonlinearity=None):
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            self.act_fn = F.silu()
        else:
            self.up_proj = None
            self.act_fn = nonlinearity

    def forward(self, x, num_local_tokens_per_expert=None):
        '''
        inputs:
            x: input tokens
                shape [bs*slen*experts_per_token, hidden_dim] for TC forward
                shape [num_experts*tokens_per_expert, hidden_dim] for EC forward
            num_local_tokens_per_expert: number of tokens for each expert, only used for TC forward
        outputs:
            out: output tokens
                shape [bs*slen*experts_per_token, hidden_dim] for TC forward
                shape [num_experts*tokens_per_expert, hidden_dim] for EC forward
        '''
        # TC forward
        if num_local_tokens_per_expert is not None:
            # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance
            # x shape [bs*slen*experts_per_token, hidden_dim]
            # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim]
            x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0)
            out_expert_splits = []
            for expert_index, x_expert_split in enumerate(x_expert_splits):
                gate_proj = self.gate_proj[expert_index]
                down_proj = self.down_proj[expert_index]
                up_proj = None
                if self.up_proj is not None:
                    up_proj = self.up_proj[expert_index]

                h = self.act_fn(torch.matmul(x_expert_split, gate_proj))
                if up_proj is not None:
                    h = h * torch.matmul(x_expert_split, up_proj)
                # [tokens_per_expert, hidden_dim]
                h = torch.matmul(h, down_proj)

                out_expert_splits.append(h)
            # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim]
            out = torch.cat(out_expert_splits, dim=0)
        # EC forward
        else:
            # x shape [num_experts, tokens_per_expert, hidden_dim]
            x = x.view(num_experts, -1, dim_in)
            h = self.act_fn(torch.bmm(x, self.gate_proj))
            if self.up_proj is not None:
                h = h * torch.bmm(x, self.up_proj)
            out = torch.bmm(h, self.down_proj).view(-1, dim_in)
        return out

# Expert builder for routed experts
def moe_experts(hidden_dim, model_dim, num_experts, swiglu=True, nonlinearity=None):
    return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity)

# Single expert / shared expert
def moe_expert(hidden_dim, model_dim, swiglu=True, nonlinearity=None):
    return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=1, swiglu=swiglu, nonlinearity=nonlinearity)

Router

Router is a gating network that calculates router scores and learns token-to-expert affinity. There are two types of routing: token choice routing and expert choice routing.

Mixtral uses token choice topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The router then defines how tokens select experts based on router scores.

Here's the proposed Token Choice Routing design in torchtune:

class TokenChoiceTopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, experts_per_token):
        self.gate = nn.Linear(hidden_dim, num_experts)
        self.experts_per_token = experts_per_token

    def forward(self, x, use_sigmoid=False):
        '''
        input:
            x: input tokens
                shape [bs*slen, hidden_dim]
        outputs:
            routed_input: tokens gather by selected experts
                shape [bs*slen*experts_per_token, hidden_dim]
            token_indices: token indices sorted by selected experts indices
            num_local_tokens_per_expert: number of tokens assigned to each expert
                shape [num_experts,]
        '''
        # scores shape [bs*slen, num_experts]
        scores = self.gate(x)
        if use_sigmoid:
            scores = torch.sigmoid(scores.to(sigmoid_dtype)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(softmax_dtype), dim=1).to(x.dtype)

        # TODO: implement load balancing auxiliary loss for token choice routing
        # https://github.com/NVIDIA/Megatron-LM/blob/f1f039224584f0bc6ba89c21ef4f491d7136e3ce/megatron/core/transformer/moe/router.py#L162

        # router scores/indices shape [bs*slen, experts_per_token]
        top_scores, selected_experts_indices = torch.topk(scores, k=self.experts_per_token, dim=1)
        top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype)

        # shape [num_experts,]: how many tokens for each expert
        num_local_tokens_per_expert = torch.histc(selected_expert_indices.view(-1), bins=num_experts, min=0, max=num_experts)
        # shape [bs*slen*experts_per_token,]
        token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True)
        # top_scores shape [bs*slen*experts_per_token,]
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]

        # token_indices shape [bs*slen*experts_per_token, hidden_dim]
        token_indices = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_dim)
        # routed_input shape [bs*slen*experts_per_token, hidden_dim]
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores

        return routed_input, token_indices, num_local_tokens_per_expert

However, token choice routing has several pitfalls according to the expert choice paper.

  1. Poor load balance. Experts can become under or over-specialized. Load imbalance can hurt step latency / inference time.
  2. Experts under specialization. Ideally the gating network will learn token-to-expert affinity such that similar or relevant tokens are routed to the same expert. However, a sub-optimal strategy can produce redundant experts and/or experts that are not sufficiently specialized.
  3. Same compute for each token. Token choice will allocate a fixed number of experts to each token regardless of the importance of different tokens. Ideally an MOE model should flexibly allocate compute resources based on the complexity of the input.

Compared to token choice, expert choice topK routing lets experts select its top-k tokens. The ExpertChoiceTopKRouter class routes input tokens to different experts based on the router scores.

Here's the proposed Expert Choice Routing design in torchtune:

class ExpertChoiceTopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts):
        self.gate = nn.Linear(hidden_dim, num_experts)
        self.tokens_per_expert = tokens_per_expert

    def forward(self, x, use_sigmoid=False):
        '''
        input:
            x: shape [bs*slen, hidden_dim]
        outputs:
            routed_input: selected tokens
                shape [num_experts*tokens_per_expert, hidden_dim]
            token_indices: selected token indices
            num_local_tokens_per_expert: None
        '''
        # scores shape [num_experts, bs*slen]
        scores = self.gate(x).transpose(0,1)
        if use_sigmoid:
            scores = torch.sigmoid(scores.to(sigmoid_dtype)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(softmax_dtype), dim=0).to(x.dtype)
        # router scores/indices shape [num_experts, tokens_per_expert]
        top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1)

        # apply the token preprocess function and then run experts forward
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, D)
        # routed input shape [num_experts*tokens_per_expert, hidden_dim]
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)
        return routed_input, token_indices, None,

Moe Layer

An MOE layer consists of experts and routers.

Here's the proposed MoeLayer design in torchtune:

class MoeLayer(nn.Module):
    def __init__(self, router="token_choice"):
        self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts)
        self.shared_expert = moe_expert(hidden_dim, model_dim)
        if router == "token_choice":
            self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token)
        elif router == "expert_choice":
            self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert)
        else:
            raise NotImplementedError("This router is not supported yet!")

    def forward(self, x, infernece=False):
        routed_input, token_indices, num_local_tokens_per_expert = self.router(x)
        # routed output shape [num_experts*tokens_per_expert, hidden_dim] for EC, [bs*slen*experts_per_token, hidden_dim] for TC
        routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert)

        # shared expert
        if use_shared_expert:
            out = self.shared_expert(x)
        else:
            out = torch.zeros_like(x)

        # add experts output
        out.data = scatter_add_(
            out.data,
            routed_output,
            selected_indices,
        )
        return out

Model builder

Besides the above components: experts, routers, and MOE layers, we would need a model builder to pull all pieces together to form the Transformer decoder layer and then Transformer decoder:

Here's the proposed MOE model builder design in torchtune:

def moe(...) -> TransformerDecoder:
    # Build the decoder associated with the moe model. This includes
    # - Token embeddings
    # - num_layers number of TransfomerDecoderLayer block
    # - RMS Norm layer applied to the ouput of the transfomer
    # - Final projection into the token space'
    token_embeddings = nn.Embedding(vocab_size, embed_dim)
    self_attn = MultiHeadAttention()
    moe_layer = MoeLayer(router="token_choice") # or MoeLayer(router="expert_choice")
    norm = RMSNorm(dim=embed_dim)
    layer = TransformerSelfAttentionLayer(attn=self_attn, mlp=moe_layer, sa_norm=norm, mlp_norm=norm)
    output_proj = nn.Linear(embed_dim, vocab_size)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(dim=embed_dim),
        output=output_proj,
    )

File changes for new modules/functions

torchtune/
    modules/
        moe/
            moe_layers.py
                TokenChoiceTopKRouter()
                ExpertChoiceTopKRouter()
                MoeLayer()
            experts.py
                Experts()
    models/
        moe/
            _component_builders.py
                moe()
                moe_expert()
                moe_experts()

Copy link

pytorch-bot bot commented Oct 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1902

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ed424d8 with merge base dc0591c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 24, 2024
Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't have chance to finish this yet, but I left some early comments/questions. Thanks for the detailed RFC!

RFC.md Outdated
return top_scores, top_indices


class ExpertChoiceMoeLayer(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a comment on naming, we use layer at the decoder level. So I think this should just be something like MoEForward or MixtureOfExperts.

Do you think it would be possible to reuse this layer for both Token and Expert routing modules? The experts, shared_expert, and router are passed in inside init anyway. Or is there not enough in common between them?

If this module can't be generalized I wonder if it would make more sense to combine it with the router logic instead of having them be separate classes.

Copy link
Author

@acisseJZhong acisseJZhong Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I think MoeLayer / MoeBlock is a more common name, as I saw almost all MOE papers/OSS code call it this way. It also replaces the original FF layer / MLP layer.

  2. at least in my mind it's hard to reuse the same MoeLayer for both token and expert routings, since the MoeLayer's forward function is where we route tokens into different experts based on either token choice or expert choice routing.
    For token choice, we are looping over each experts, and calculate hidden states for each selected expert.
    For expert choice, not all tokens will be chosen, we are gathering only the selected tokens and run experts forward on them.
    Please let me know if you have a better way in mind to combine these two!

  3. We can combine router with MoeLayer(in fact it is combined during my first draft), but I think decoupling router, experts, MoeLayer would also be nice.

```python
class Experts(nn.Module):
def __init__(self, dim_in, dim_out, nonlinearity, num_experts=1, swiglu=True):
self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the implications of defining your experts all as a single parameter? Noob thoughts here; I think it makes sense that it would be faster for training since all experts are updated each step, but at inference time wouldn't you want to have separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not an expert, but)I think the reason of combining all the experts is to use torch.bmm() and groupedGEMM instead of torch.matmul() to optimize performance.

Good question, for inference we are always doing TC even if training uses EC. And for TC forward, we are essentially looping over the experts one at a time.


**Here's the proposed Experts design in torchtune:**
```python
class Experts(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming, I think this should be something like MoELinear

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious what's the reason? I was thinking having Experts, Routers, and MoELayers is more clear in terms of what each class is doing. Otherwise, there's no Experts class and this MoELinear is essentially Experts? I don't have any strong preferences, but just want to know why you think MoELinear is better?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It matches our own convention of defining FeedForward (singular) rather than a group of things like Experts.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But here we are defining a group of experts/feedforward. Wouldn't MoELinear be confusing? It's more like GroupedMoELinear?

RFC.md Outdated
# expert_mask shape [num_experts, experts_per_token, bs*slen]
expert_mask = torch.nn.functional.one_hot(selected_experts_indices, num_class=num_experts).permute(2,1,0)
out = torch.zeros((batch_size * seq_len, hidden_dim))
for i in range(num_experts):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this not be done in the broadcast format that Expert Choice is?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for raising this. After some thoughts, I think it can be done in a way similar to EP. But it then requires special TC Forward(defined separately in the Expert forward function). I updated the RFC to have both approaches!

@joecummings joecummings added the rfc Request for comments label Oct 25, 2024
@RdoubleA
Copy link
Contributor

Thanks for the awesome RFC, really appreciate the detail in the code implementations. One concern I have is that token choice vs expert choice is conceptually spread across both the router and the experts classes, making them pretty tightly coupled, i.e., if I use TokenChoiceRouter I have to make sure I set use_token_choice=False in the expert forward, otherwise I will get totally incorrect results. Also, if I ever want to use the Experts class with a new routing mechanism, I would need add a lot of if-else chunks and more parameters.

I would try to either make the experts entirely routing agnostic (not sure if this is possible, based on your code it seems to affect the forward quite significantly), make separate expert classes for token choice / expert choice, or just combine the expert forward logic and routing logic into one class.

but at inference time wouldn't you want to have separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts

This is a great point, wouldn't the inference behavior be entirely different anyway? How will we handle different inference logic and potential optimizations?

@acisseJZhong
Copy link
Author

acisseJZhong commented Oct 28, 2024

I would try to either make the experts entirely routing agnostic (not sure if this is possible, based on your code it seems to affect the forward quite significantly), make separate expert classes for token choice / expert choice, or just combine the expert forward logic and routing logic into one class.

Thanks for the suggestion. It makes sense! unfortunately I think expert choice and token choice needs to have different forward impl. This is because for expert choice, each expert has tokens_per_expert tokens and it is fixed. However, for token choice, tokens_per_expert is different for each expert. This is also why we passed num_local_tokens_per_expert into token choice forward function.

I think making them into separate expert classes is reasonable. So we will have TokenChoiceExperts and ExpertChoiceExperts. I am hesitant on combining the expert forward logic and routing logic, as this makes things even more complicated and hard to understand.

but at inference time wouldn't you want to have separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts

This is a great point, wouldn't the inference behavior be entirely different anyway? How will we handle different inference logic and potential optimizations?

Yeah also thanks for raising this question. I didn't keep inference in mind during the first draft design. I am discussing/consulting with Jie more about MOE inference.

Comment on lines +180 to +185
if router == "token_choice":
self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token)
elif router == "expert_choice":
self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert)
else:
raise NotImplementedError("This router is not supported yet!")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a minor point, but I wouldn't configure router type using strings. Just pass as an nn.Module as that's more generic (then anyone with their own custom router matching the same signature can use out of the box)

# - Final projection into the token space'
token_embeddings = nn.Embedding(vocab_size, embed_dim)
self_attn = MultiHeadAttention()
moe_layer = MoeLayer(router="token_choice") # or MoeLayer(router="expert_choice")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider an moe_layer builder function as well to improve ease of configurability

else:
scores = F.softmax(scores.to(softmax_dtype), dim=0).to(x.dtype)
# router scores/indices shape [num_experts, tokens_per_expert]
top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline a bit but we should make sure there isn't any information leakage happening here for decoder models. Specifically, this line seems a bit fishy to me. It seems like each expert is looking across all tokens and choosing its top tokens_per_expert based on the entire sequence, right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. rfc Request for comments
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants