How to implement paged attention in HF format? #616

Open
fahadh4ilyas opened this issue Sep 6, 2024 · 5 comments

@fahadh4ilyas
Contributor

So, I just created exllamav2 in HF format and it works well in batches. My code is in #606. Now I have a new problem: a bigger batch means more memory usage, and most of it goes to padding, especially when the sequences have different lengths. Could you explain to me how exllamav2's paged attention works in code? I checked the code in exllamav2/model.py; PagedParams is used, but I don't know what to fill into its parameters.

@turboderp
Owner

turboderp commented Sep 6, 2024

To use the paged mode (flash-attn only), you first need a cache initialized with a batch size of 1 and a length which is some multiple of the page size. The page size is always 256 with the current version of flash-attn. Essentially this cache won't have a shape, just a total capacity.
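For instance, a minimal allocation sketch (hedged: the loading boilerplate and exact constructor arguments here are assumptions and may differ between exllamav2 versions; the point is just a batch_size-1 cache whose max_seq_len is a multiple of 256):

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer

config = ExLlamaV2Config("/path/to/model")       # hypothetical model directory
model = ExLlamaV2(config)

PAGE_SIZE = 256                                  # fixed by the current flash-attn version
num_pages = 13                                   # enough for the example below (pages 0..12)
cache = ExLlamaV2Cache(
    model,
    batch_size = 1,                              # always 1 for a flat, paged cache
    max_seq_len = num_pages * PAGE_SIZE,         # total capacity, not a per-sequence shape
    lazy = True,                                 # allocate while loading with autosplit
)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)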

PagedParams is constructed like so:

params = ExLlamaV2Attention.PagedParams(
    batch_size = batch_size,
    block_index = block_index,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = 256,
    q_len = q_len,
)
  • batch_size here is the actual size of your batch, even though you're using a flat cache.
  • block_index is an int tensor of shape (batch_size, max_num_pages) which defines which pages in the cache to use for which sequences in the batch. It can be padded to the right with arbitrary values for pages you'll never get to.
  • cache_seqlens is an int tensor of shape (batch_size,) determining where in each sequence the input IDs to the forward pass belong.
  • q_len is the length of whatever you're sending through the forward pass, typically 1. input_ids to the model would therefore have shape (batch_size, q_len).

So say you have three sequences that are currently 10, 1025 and 320 tokens long, respectively, and you want room in the cache for each to grow by 500 tokens. You're forwarding a single token. That could look like:

batch_size: 
    3

block_index:
    [
        [  0, 1,  0,  0,  0, 0 ],   # positions 0:512 in the cache, and some padding
        [  2, 3,  4,  5,  6, 7 ],   # positions 512:2048
        [  8, 9, 10, 11, 12, 0 ]    # positions 2048:3328, plus padding
    ]
     
cache_seqlens:
    [ 10, 1025, 320 ]

page_size:
    256
    
q_len:
    1
 
input_ids:
    [ 
        [token_a],
        [token_b],
        [token_c]
    ]
    

So when the forward pass writes the keys/values for position 10, it only touches page 0 in the cache. At the same time it will write the second sequence's position 1025, which lands at flat cache position 512+1025 = 1537, i.e. in page 6, and so on. It's the cache_seqlens tensor that determines how long each past is and thereby which page to look up in the block index.
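As a rough code sketch of that decode step (hedged: token_a/b/c are placeholder next-token IDs, the tensors would live on the model's device in practice, and the forward call is only schematic because the exact way to pass attn_params may differ between versions; check exllamav2/model.py):

import torch
from exllamav2.attn import ExLlamaV2Attention

PAGE_SIZE = 256

# One decode step for the three sequences above (10, 1025 and 320 past tokens).
# In practice these tensors are int32 on the model's device, which is what flash-attn expects.
block_index = torch.tensor([
    [ 0, 1,  0,  0,  0, 0 ],    # seq a: pages 0-1, right-padded with arbitrary values
    [ 2, 3,  4,  5,  6, 7 ],    # seq b: pages 2-7
    [ 8, 9, 10, 11, 12, 0 ],    # seq c: pages 8-12, right-padded
], dtype = torch.int32)

cache_seqlens = torch.tensor([10, 1025, 320], dtype = torch.int32)

params = ExLlamaV2Attention.PagedParams(
    batch_size = 3,
    block_index = block_index,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = PAGE_SIZE,
    q_len = 1,
)

input_ids = torch.tensor([[token_a], [token_b], [token_c]], dtype = torch.long)  # shape (3, 1)

# Schematic forward call; check model.py for the exact signature in your version:
# logits = model.forward(input_ids, cache = cache, attn_params = params)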

Now, there are some choices you could make about how to get to the above point in the first place. input_ids is still always a rectangular tensor, so to prefill the initial 10, 1025 and 320 tokens you'd need to do three forward passes to avoid padding.

You could do one with a shape of (3, 10), then another with shape (2, 310) and finally (1, 705).

Or you could just do each sequence in the batch as a bsz-1 forward pass. This is what the dynamic generator does, and it simplifies things a lot, especially for continuous batching. I.e.:

prompt a:
    batch_size: 1
    block_index: [[0]]
    cache_seqlens: [0]
    q_len: 10
    input_ids: tokenizer.encode(prompt_a)

prompt b:
    batch_size: 1
    block_index: [[2, 3, 4, 5, 6]]
    cache_seqlens: [0]
    q_len: 1025
    input_ids: tokenizer.encode(prompt_b)

prompt c:
    batch_size: 1
    block_index: [[8, 9]]
    cache_seqlens: [0]
    q_len: 320
    input_ids: tokenizer.encode(prompt_c)

There's a bunch of fun details about paged attention, such as the fact that the page indices don't need to be contiguous. Also they don't need to be unique, as long as you're not updating the same page twice in a forward pass. The dynamic generator uses both of those details for deduplication and continuous batching, respectively.

If you didn't want a predefined max_new_tokens length, you could allocate pages dynamically during inference. There's nothing that prevents you from adding page 13 after page 1 in the first sequence, or growing the block_index tensor by one column to add page 14 after page 7.
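A hedged sketch of that kind of on-the-fly growth (next_free_page and pages_allocated are stand-ins for whatever free-page bookkeeping your generator keeps; they are not exllamav2 API):

import torch

PAGE_SIZE = 256

def grow_if_needed(block_index, cache_seqlens, pages_allocated, seq, next_free_page):
    # block_index: (batch_size, max_num_pages) int32 tensor
    # pages_allocated: list with the number of real (non-padding) pages per sequence
    # Returns a possibly grown block_index and the next free page counter.
    pages_needed = (cache_seqlens[seq].item() + 1 + PAGE_SIZE - 1) // PAGE_SIZE  # after writing one more token
    if pages_needed <= pages_allocated[seq]:
        return block_index, next_free_page
    if pages_needed > block_index.shape[1]:
        # Grow by one column; the padding values in the other rows are never read.
        pad = torch.zeros((block_index.shape[0], 1), dtype = block_index.dtype)
        block_index = torch.cat([block_index, pad], dim = 1)
    block_index[seq, pages_allocated[seq]] = next_free_page
    pages_allocated[seq] += 1
    return block_index, next_free_page + 1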

It does of course require some bookkeeping in your generator, and I'm not sure how well that plays together with HF and pipelines and whatnot.

@fahadh4ilyas
Contributor Author

Okay, I kind of get the concept. I think I want to forward each sequence as a bsz-1 forward pass. Does this mean we have to loop over each sequence instead of doing one big batched forward pass? What about the cache instance? Should I make one for each sequence or just one for all? And how does the cache know which sequence is being forwarded with it?

@turboderp
Owner

You use one cache for everything, and it's the block_index tensor that says which pages in the cache are used for each sequence, whether you're doing them one at a time or batching.

One way to go about it would be to start by tokenizing all the prompts in a batch, then constructing the block index based on how many pages each sequence is going to need, including both the prompt and the completion:

block_index_batch:
    [
        [  0, 1,  0,  0,  0, 0 ],  # 10+500 tokens needs 2 pages
        [  2, 3,  4,  5,  6, 7 ],  # 1025+500 tokens -> 6 pages
        [  8, 9, 10, 11, 0, 0 ]  # 320+500 -> 4 pages
    ]

Then you run the three individual forward passes to prefill:

seq a: block_index = block_index_batch[0:1, :]
seq b: block_index = block_index_batch[1:2, :]
seq c: block_index = block_index_batch[2:3, :]

It doesn't matter if the block index has extra padding on the right, since it's indexed from the left. And then for each token you pass block_index_batch so you can index into all three sequences at once.
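Tying it together as a sketch (hedged: seqs is a hypothetical list of prompt token-ID tensors of shape (1, prompt_len), next_tokens would come from your own sampling step, and the commented forward calls are schematic, same caveat as above about the exact signature):

import torch
from exllamav2.attn import ExLlamaV2Attention

PAGE_SIZE = 256
# block_index_batch: the (3, 6) int32 tensor above, one row of pages per sequence
# seqs: list of prompt token-ID tensors, each of shape (1, prompt_len)

# 1) Prefill each sequence as its own bsz-1 forward pass, so no padding is needed.
for i, ids in enumerate(seqs):
    params = ExLlamaV2Attention.PagedParams(
        batch_size = 1,
        block_index = block_index_batch[i:i + 1, :],   # this sequence's pages only
        cache_seqlens = torch.tensor([0], dtype = torch.int32),
        max_cache_seqlen = 0,
        page_size = PAGE_SIZE,
        q_len = ids.shape[-1],
    )
    # model.forward(ids, cache = cache, attn_params = params)   # schematic

# 2) Then decode one token at a time for the whole batch with the full block index.
cache_seqlens = torch.tensor([ids.shape[-1] for ids in seqs], dtype = torch.int32)
params = ExLlamaV2Attention.PagedParams(
    batch_size = len(seqs),
    block_index = block_index_batch,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = PAGE_SIZE,
    q_len = 1,
)
# logits = model.forward(next_tokens, cache = cache, attn_params = params)   # next_tokens: shape (3, 1)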

@fahadh4ilyas
Contributor Author

I understand. But I have another doubt: what about the input mask and position offsets? The input mask is probably taken care of, since the masking is done inside flash attention. But what about position offsets?

@turboderp
Owner

You wouldn't use masking or position offsets in paged mode, only a list of sequence lengths, and then the flash-attn kernel handles the rest. This allows all sequences to start at position zero (as long as that corresponds to a page boundary in the cache, as determined by block_index) and have variable lengths as determined by cache_seqlens.
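A conceptual illustration of that (not exllamav2 API, just the idea):

import torch

# For a decode step with q_len = 1, the position of the new token in sequence b
# is simply cache_seqlens[b]; the flash-attn kv-cache kernel only attends within
# each sequence's own pages up to that length, so no explicit attention mask or
# position_offsets tensor is needed.
cache_seqlens = torch.tensor([10, 1025, 320], dtype = torch.int32)
new_token_positions = cache_seqlens   # positions 10, 1025 and 320 respectively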
