[Paper]
- [2024.08.01]:🚀 Palu ver. 1 is released.
Palu is a KV-Cache compression framework that utilizes low-rank projection to compress the hidden dimension of KV-Cache, thereby reducing memory footprint and increasing speed.
Palu is a pioneering KV-Cache compression framework that reduces the hidden dimension of the KV-Cache via low-rank projection. Unlike MLA in DeepSeek-V2, which requires large-scale training from scratch, Palu works with existing LLMs such as Llama3 and Mistral in a post-training manner. To achieve this, Palu decomposes the linear layers into low-rank matrices, caches the smaller intermediate states, and reconstructs the full keys and values on the fly. To improve accuracy, compression rate, and efficiency, Palu further encompasses (1) a medium-grained low-rank decomposition scheme, (2) an efficient rank search algorithm, (3) matrix fusion for quantization friendliness, and (4) co-designed GPU kernels.
Our extensive experiments with popular LLMs show that Palu can compress the KV-Cache by more than 91.25% while maintaining significantly better accuracy (up to 1.19 lower perplexity) than state-of-the-art KV-Cache quantization methods at similar or even higher memory usage. For more details, please refer to our paper.
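The low-rank caching idea described above can be sketched in a few lines of NumPy. This is a simplified, single-projection illustration with assumed toy shapes, not Palu's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, rank, seq_len = 64, 16, 8  # assumed toy sizes

# A pretrained key-projection weight W_k, decomposed offline into
# A (d_model x rank) and B (rank x d_model) via truncated SVD, so A @ B ≈ W_k.
W_k = rng.standard_normal((d_model, d_model))
U, S, Vt = np.linalg.svd(W_k, full_matrices=False)
A = U[:, :rank] * S[:rank]
B = Vt[:rank, :]

# Instead of caching full keys x @ W_k (seq_len x d_model), cache the smaller
# intermediate state h = x @ A (seq_len x rank): a rank/d_model memory footprint.
x = rng.standard_normal((seq_len, d_model))
h = x @ A

# Reconstruct (approximate) full keys on the fly at attention time.
k_approx = h @ B
```

Here the cached state is 4x smaller than the full keys (rank 16 vs. hidden dimension 64), at the cost of a small reconstruction matmul during attention.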
- Upgrade `transformers>=4.43.3` for Llama3.1 support.
- Update reconstruction kernel, with quantization integrated.
- Support FlashAttention or FlashInfer to enhance compatibility.
- Clone the repository (make sure you have Git and Conda installed on your system):
```bash
git clone --recurse-submodules https://github.com/shadowpa0327/Palu.git
cd Palu
```
- Prepare the environment:
```bash
conda create -n Palu python=3.10
conda activate Palu
pip install -r requirements.txt
```
- Install third-party libraries:
```bash
pip install -e 3rdparty/lm-evaluation-harness
pip install -e 3rdparty/fast-hadamard-transform
```
We provide a script, `compress.py`, that performs the rank search and low-rank decomposition to generate the low-rank projection matrices for compressing the KV-Cache. As an example, here we perform the decomposition with the proposed G-LRD method with a group size of 4.
```bash
python compress.py \
    --model_id="meta-llama/Llama-2-7b-hf" \
    --calib_dataset wikitext2 \
    --param_ratio_target 0.7 \
    --search_method fisher_uniform \
    --head_group_size 4 \
    --dump_huggingface_model \
    --use_cache
```
After executing the above command, a compressed model with decomposed low-rank projection matrices will be dumped into the `Llama-2-7b-hf_ratio-0.7_gs-4-fisher_uniform` directory. The dumped model is stored in the Hugging Face transformers format.
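Conceptually, the G-LRD (group-wise low-rank decomposition) applied here sits between whole-layer and per-head decomposition: the weights of a group of heads are decomposed jointly. A simplified NumPy sketch with assumed toy sizes and a hypothetical per-group rank budget, not the repository's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, head_dim = 64, 8, 8   # assumed toy sizes
group_size, rank_per_group = 4, 12      # hypothetical rank budget

# Key-projection weight, with each head owning head_dim output columns.
W_k = rng.standard_normal((d_model, n_heads * head_dim))

low_rank_factors = []
for g in range(0, n_heads, group_size):
    # G-LRD: jointly decompose the concatenated weights of one group of heads,
    # a middle ground between whole-layer (coarse) and per-head (fine) SVD.
    W_group = W_k[:, g * head_dim:(g + group_size) * head_dim]
    U, S, Vt = np.linalg.svd(W_group, full_matrices=False)
    low_rank_factors.append((U[:, :rank_per_group] * S[:rank_per_group],
                             Vt[:rank_per_group, :]))
```

Larger groups amortize one shared low-rank cache over more heads (better compression), while smaller groups track each head's weights more accurately.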
With the compressed model dumped, we can evaluate its performance on various tasks. We provide scripts for perplexity, zero-shot, and LongBench evaluations. By default, the compressed KV-Cache is kept in fp16.
To evaluate the perplexity of the compressed model on the `wikitext2` dataset with sequence length 2048, use the `run_ppl_eval.py` script:
```bash
python run_ppl_eval.py \
    --model_name_or_path /Path/To/Palu/Model \
    --datasets wikitext2 \
    --seqlen 2048
```
To reproduce the evaluation of c4 perplexity, simply change the `--datasets` argument to `c4`.
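For reference, the fixed-window perplexity that such evaluations report can be sketched as follows. This is a generic illustration of the standard metric, not the repository's exact implementation:

```python
import math

def chunked_perplexity(token_nlls, seqlen=2048):
    """Split a stream of per-token negative log-likelihoods (in nats) into
    non-overlapping windows of `seqlen`, average them, and exponentiate."""
    total_nll, total_tokens = 0.0, 0
    for start in range(0, len(token_nlls) - seqlen + 1, seqlen):
        total_nll += sum(token_nlls[start:start + seqlen])
        total_tokens += seqlen
    return math.exp(total_nll / total_tokens)

# A model that assigns every token probability 1/2 has perplexity 2.
print(chunked_perplexity([math.log(2.0)] * 4096))
```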
To evaluate the performance with quantization integrated, pass `--lt_bits {num_bits}` and `--lt_hadamard` in the arguments to enable our low-rank-aware quantization.
For example, to evaluate Palu with 3-bit low-rank-aware quantization, run:
```bash
python run_ppl_eval.py \
    --model_name_or_path /Path/To/Palu/Model \
    --datasets wikitext2 \
    --seqlen 4096 \
    --lt_bits 3 \
    --lt_hadamard
```
Note: `run_ppl_eval.py` does not support multi-GPU evaluation. If your machine has multiple GPUs, set `CUDA_VISIBLE_DEVICES` to the desired GPU id.
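The Hadamard flag follows the common outlier-smoothing recipe: rotate the low-rank states with an orthogonal Hadamard matrix before low-bit quantization, then rotate back after dequantization. A minimal NumPy sketch of that idea, using a naive round-to-nearest quantizer and assumed sizes (not Palu's fused kernel):

```python
import numpy as np

def hadamard(n):
    # Sylvester construction; n must be a power of two.
    # Normalized by sqrt(n) so that H is orthogonal (H @ H.T == I).
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def fake_quantize(x, bits=3):
    # Symmetric per-tensor uniform quantization (quantize + dequantize).
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(x).max() / qmax
    return np.round(x / scale) * scale

rng = np.random.default_rng(0)
h = rng.standard_normal((8, 64))   # toy low-rank KV states
H = hadamard(64)

# Rotate, quantize at 3 bits, rotate back; the rotation spreads outliers
# across channels, flattening the distribution the quantizer sees.
h_deq = fake_quantize(h @ H, bits=3) @ H.T
rel_err = np.linalg.norm(h_deq - h) / np.linalg.norm(h)
```

Because the rotation is orthogonal, it is lossless in exact arithmetic; the only error comes from the quantizer, which now sees a flatter, more quantization-friendly distribution.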
To run zero-shot evaluations, use the `run_lm_eval.py` script, which implements a wrapper around the lm-evaluation-harness library. Before starting, please make sure the `lm-eval==0.4.2` library is installed.
To reproduce the results in our paper, simply execute:
```bash
CUDA_VISIBLE_DEVICES=0 python run_lm_eval.py \
    --model_name_or_path "./Meta-Llama-3-8b-Instruct_ratio-0.7_gs-4-fisher_uniform" \
    --tasks "openbookqa,hellaswag,piqa,arc_easy,arc_challenge,winogrande"
```
We also provide a script to run the LongBench evaluation on the compressed model.
```bash
CUDA_VISIBLE_DEVICES=0 python run_long_bench.py \
    --model_name_or_path /Path/To/Palu/Model
```
The script will evaluate on the "triviaqa", "qasper", "trec", "samsum", "lcc", "repobench-p", "qmsum", and "multi_news" datasets by default. Users may also use the `--datasets` argument to specify which tasks to evaluate. For example, add `--datasets "triviaqa,qasper"` to evaluate only on the "triviaqa" and "qasper" datasets.
We provide a script to evaluate the latency of the compressed attention module under different settings. Below is an example demonstrating how to use this script to measure the latency of the Palu attention module.
```bash
CUDA_VISIBLE_DEVICES=0 python run_latency_attention.py \
    --rank_k 1024 --rank_v 3072 --group_size 4 \
    --prompt_len 65536 --palu
```
The above command runs the latency evaluation with the following arguments:
- `--rank_k`: set the rank of the key matrix.
- `--rank_v`: set the rank of the value matrix.
- `--group_size`: set the group size used in the low-rank decomposition.
- `--prompt_len`: set the prompt length.
- `--palu`: enable the Palu compression technique.
We also provide a script to evaluate our reconstruction kernel latency. Below is an example demonstrating how to use this script to evaluate the latency of Palu reconstruction kernel.
```bash
CUDA_VISIBLE_DEVICES=0 python run_latency_kernel.py \
    --total_rank 1024 --group_size 4
```
If you find this work useful, please consider citing our paper:
```bibtex
@misc{chang2024palucompressingkvcachelowrank,
      title={Palu: Compressing KV-Cache with Low-Rank Projection},
      author={Chi-Chih Chang and Wei-Cheng Lin and Chien-Yu Lin and Chong-Yan Chen and Yu-Fang Hu and Pei-Shuo Wang and Ning-Chi Huang and Luis Ceze and Kai-Chiang Wu},
      year={2024},
      eprint={2407.21118},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2407.21118},
}
```