Releases: HabanaAI/vllm-fork
v0.5.3.post1+Gaudi-1.18.0
vLLM with Intel® Gaudi® AI Accelerators - Gaudi Software Suite 1.18.0
Requirements and Installation
Please follow the instructions provided in the Gaudi Installation Guide to set up the environment. To achieve the best performance, please follow the methods outlined in the Optimizing Training Platform Guide.
Requirements
- OS: Ubuntu 22.04 LTS
- Python: 3.10
- Intel Gaudi accelerator
- Intel Gaudi software version 1.18.0
To verify that the Intel Gaudi software was correctly installed, run:
$ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
$ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed
$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
$ pip list | grep neural # verify that neural-compressor is installed
Refer to Intel Gaudi Software Stack Verification for more details.
Run Docker Image
It is highly recommended to use the latest Docker image from Intel Gaudi vault. Refer to the Intel Gaudi documentation for more details.
Use the following commands to run a Docker image:
$ docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.1:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.1:latest
Build and Install vLLM
Currently, the latest features and performance optimizations are developed in Gaudi's vLLM-fork, and we periodically upstream them to the vLLM main repo. To install the latest HabanaAI/vLLM-fork, run the following:
$ git clone https://github.com/HabanaAI/vllm-fork.git
$ cd vllm-fork
$ git checkout v0.5.3.post1+Gaudi-1.18.0
$ pip install -e .
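Once installed, a quick offline-inference sanity check can be run as sketched below. This is a minimal example, not part of the official instructions; the model, prompts, and sampling settings are assumptions and can be swapped for any entry under Supported Configurations below.

```python
# Minimal offline batched inference check (HPU is autodetected by vLLM).
from vllm import LLM, SamplingParams

# Model choice is an assumption; substitute any validated configuration.
llm = LLM(model="meta-llama/Llama-2-7b-hf", dtype="bfloat16")

prompts = [
    "Intel Gaudi accelerators are designed for",
    "vLLM is a library for",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

for output in llm.generate(prompts, sampling_params):
    print(output.prompt, "->", output.outputs[0].text)
```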
Supported Features
- Offline batched inference
- Online inference via OpenAI-Compatible Server (see the example after this list)
- HPU autodetection - no need to manually select device within vLLM
- Paged KV cache with algorithms enabled for Intel Gaudi accelerators
- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, prefill attention, Root Mean Square Layer Normalization, Rotary Positional Encoding
- Tensor parallelism support for multi-card inference
- Inference with HPU Graphs for accelerating low-batch latency and throughput
- Attention with Linear Biases (ALiBi)
- LoRA adapters
- Quantization with INC
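The OpenAI-compatible server mentioned above can be started and queried as sketched here; the model name, port, and request payload are assumptions, not part of the release notes.

$ python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --dtype bfloat16 --port 8000
$ curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{"model": "meta-llama/Llama-2-7b-chat-hf", "prompt": "Gaudi accelerators are", "max_tokens": 32}'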
Unsupported Features
- Beam search
- Prefill chunking (mixed-batch inferencing)
Supported Configurations
The following configurations have been validated to function with Gaudi2 devices. Configurations that are not listed may or may not work. An example tensor-parallel launch command follows the list.
- meta-llama/Llama-2-7b on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-7b-chat-hf on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-8B on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-8B-Instruct on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-8B on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-8B-Instruct on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-70b with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-70b-chat-hf with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-70B with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-70B-Instruct with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-70B with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-70B-Instruct with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- mistralai/Mistral-7B-Instruct-v0.3 on single HPU or with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling
- mistralai/Mixtral-8x7B-Instruct-v0.1 with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling
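For the multi-card entries above, tensor parallelism is selected with the standard vLLM `--tensor-parallel-size` argument (or `tensor_parallel_size=` in the `LLM` constructor). A sketch for one of the 8x HPU configurations from the list:

$ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --dtype bfloat16 --tensor-parallel-size 8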
Performance Tuning
Execution modes
Currently, vLLM for HPU supports four execution modes, depending on the selected HPU PyTorch Bridge backend (via the `PT_HPU_LAZY_MODE` environment variable) and the `--enforce-eager` flag; an example follows the table below.
| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode |
|---|---|---|
| 0 | 0 | torch.compile |
| 0 | 1 | PyTorch eager mode |
| 1 | 0 | HPU Graphs |
| 1 | 1 | PyTorch lazy mode |
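For example, the same server invocation can be switched between these modes purely via the environment variable and the flag (the model name is an assumption):

$ PT_HPU_LAZY_MODE=1 python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf # HPU Graphs
$ PT_HPU_LAZY_MODE=1 python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --enforce-eager # PyTorch lazy mode
$ PT_HPU_LAZY_MODE=0 python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf # torch.compile (experimental, see warning below)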
Warning
In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should only be used for validating functional correctness. Their performance will be improved in upcoming releases. To obtain the best performance in 1.18.0, use HPU Graphs or PyTorch lazy mode.
Bucketing mechanism
Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. The Intel Gaudi Graph Compiler is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently, this is achieved by "bucketing" the model's forward pass across two dimensions: `batch_size` and `sequence_length`.
Note
Bucketing allows us to reduce the number of required graphs significantly, but it does not handle graph compilation or device code generation - these are done during the warmup and HPUGraph capture phase.
Bucketing ranges are determined with 3 parameters: `min`, `step` and `max`. They can be set separately for the prompt and decode phase, and for the batch size and sequence length dimensions. These parameters can be observed in logs during vLLM startup:
INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
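The logged configurations above translate into those bucket lists roughly as follows. This is an illustrative sketch only: the ramp-up rule (powers of two from `min` up to `step`, then multiples of `step` up to `max`) is inferred from the logged values, not copied from the vLLM-fork source.

```python
# Reproduce the bucket grids implied by the startup logs above (illustrative sketch).
import itertools

def warmup_range(min_val: int, step: int, max_val: int) -> list[int]:
    ramp, value = [], min_val
    while value < step and value <= max_val:   # exponential ramp below `step`
        ramp.append(value)
        value *= 2
    linear = range(step, max_val + 1, step)    # linear region from `step` to `max`
    return sorted(set(ramp) | set(linear))

def bucket_grid(bs_cfg, seq_cfg):
    return list(itertools.product(warmup_range(*bs_cfg), warmup_range(*seq_cfg)))

print(len(bucket_grid((1, 32, 4), (128, 128, 1024))))   # 24 prompt buckets, as logged
print(len(bucket_grid((1, 128, 4), (128, 128, 2048))))  # 48 decode buckets, as logged
```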
`min` determines the lowest value of the bucket. `step` determines the interval between buckets, and `max` d...
v0.6.2
Full Changelog: v0.5.4...v0.6.2
v0.5.3.post1-Gaudi-1.17.0
vLLM with Intel® Gaudi® AI Accelerators
This README provides instructions on running vLLM with Intel Gaudi devices.
Requirements and Installation
Please follow the instructions provided in the Gaudi Installation Guide to set up the environment. To achieve the best performance, please follow the methods outlined in the Optimizing Training Platform Guide.
Requirements
- OS: Ubuntu 22.04 LTS
- Python: 3.10
- Intel Gaudi accelerator
- Intel Gaudi software version 1.17.0
To verify that the Intel Gaudi software was correctly installed, run:
$ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
$ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed
$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed
Refer to Intel Gaudi Software Stack Verification for more details.
Run Docker Image
It is highly recommended to use the latest Docker image from Intel Gaudi vault. Refer to the Intel Gaudi documentation for more details.
Use the following commands to run a Docker image:
$ docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
Build and Install vLLM
To build and install vLLM from source, run:
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ python setup.py develop
Currently, the latest features and performance optimizations are developed in Gaudi's vLLM-fork, and we periodically upstream them to the vLLM main repo. To install the latest HabanaAI/vLLM-fork, run the following:
$ git clone https://github.com/HabanaAI/vllm-fork.git
$ cd vllm-fork
$ git checkout habana_main
$ python setup.py develop
Supported Features
- Offline batched inference
- Online inference via OpenAI-Compatible Server
- HPU autodetection - no need to manually select device within vLLM
- Paged KV cache with algorithms enabled for Intel Gaudi accelerators
- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, prefill attention, Root Mean Square Layer Normalization, Rotary Positional Encoding
- Tensor parallelism support for multi-card inference
- Inference with HPU Graphs for accelerating low-batch latency and throughput
Unsupported Features
- Beam search
- LoRA adapters
- Attention with Linear Biases (ALiBi)
- Quantization (AWQ, FP8 E5M2, FP8 E4M3)
- Prefill chunking (mixed-batch inferencing)
Supported Configurations
The following configurations have been validated to function with Gaudi2 devices. Configurations that are not listed may or may not work.
- meta-llama/Llama-2-7b on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-7b-chat-hf on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-8B on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-8B-Instruct on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-8B on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-8B-Instruct on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-70b with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-70b-chat-hf with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-70B with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3-70B-Instruct with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-70B with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Meta-Llama-3.1-70B-Instruct with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- mistralai/Mistral-7B-Instruct-v0.3 on single HPU or with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling
- mistralai/Mixtral-8x7B-Instruct-v0.1 with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling
Performance Tuning
Execution modes
Currently, vLLM for HPU supports four execution modes, depending on the selected HPU PyTorch Bridge backend (via the `PT_HPU_LAZY_MODE` environment variable) and the `--enforce-eager` flag.
| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode |
|---|---|---|
| 0 | 0 | torch.compile |
| 0 | 1 | PyTorch eager mode |
| 1 | 0 | HPU Graphs |
| 1 | 1 | PyTorch lazy mode |
Warning
In 1.17.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should only be used for validating functional correctness. Their performance will be improved in upcoming releases. To obtain the best performance in 1.17.0, use HPU Graphs or PyTorch lazy mode.
Bucketing mechanism
Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. The Intel Gaudi Graph Compiler is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently, this is achieved by "bucketing" the model's forward pass across two dimensions: `batch_size` and `sequence_length`.
Note
Bucketing allows us to reduce the number of required graphs significantly, but it does not handle graph compilation or device code generation - these are done during the warmup and HPUGraph capture phase.
Bucketing ranges are determined with 3 parameters: `min`, `step` and `max`. They can be set separately for the prompt and decode phase, and for the batch size and sequence length dimensions. These parameters can be observed in logs during vLLM startup:
INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
min
...
v0.4.2-Gaudi-1.16.0
vLLM with Intel® Gaudi® 2 AI Accelerators
This README provides instructions on running vLLM with Intel Gaudi devices.
Requirements and Installation
Please follow the instructions provided in the Gaudi Installation Guide to set up the environment. To achieve the best performance, please follow the methods outlined in the Optimizing Training Platform Guide.
Note
In this release (1.16.0), we are only targeting functionality and accuracy. Performance will be improved in future releases.
Requirements
- OS: Ubuntu 22.04 LTS
- Python: 3.10
- Intel Gaudi 2 accelerator
- Intel Gaudi software version 1.16.0
To verify that the Intel Gaudi software was correctly installed, run:
$ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
$ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed
$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed
Refer to Intel Gaudi Software Stack Verification for more details.
Run Docker Image
It is highly recommended to use the latest Docker image from Intel Gaudi vault. Refer to the Intel Gaudi documentation for more details.
Use the following commands to run a Docker image:
$ docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest
Build and Install vLLM-fork
To build and install vLLM-fork from source, run:
$ git clone https://github.com/HabanaAI/vllm-fork.git
$ cd vllm-fork
$ git checkout v0.4.2-Gaudi-1.16.0
$ pip install -e . # This may take 5-10 minutes.
Supported Features
- Offline batched inference
- Online inference via OpenAI-Compatible Server
- HPU autodetection - no need to manually select device within vLLM
- Paged KV cache with algorithms enabled for Intel Gaudi 2 accelerators
- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, prefill attention, Root Mean Square Layer Normalization, Rotary Positional Encoding
- Tensor parallelism support for multi-card inference
- Inference with HPU Graphs for accelerating low-batch latency and throughput
Unsupported Features
- Beam search
- LoRA adapters
- Attention with Linear Biases (ALiBi)
- Quantization (AWQ, FP8 E5M2, FP8 E4M3)
- Prefill chunking (mixed-batch inferencing)
Supported Configurations
The following configurations have been validated to function with Gaudi devices. Configurations that are not listed may or may not work.
- meta-llama/Llama-2-7b on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-7b-chat-hf on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-70b with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- meta-llama/Llama-2-70b-chat-hf with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
Performance Tips
- We recommend running inference on Gaudi 2 with a `block_size` of 128 for the BF16 data type. Using the default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine under-utilization (see Gaudi Architecture).
- For max throughput on Llama 7B, we recommend running with a batch size of 128 or 256 and a max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see the troubleshooting section. A command-line sketch combining these settings follows this list.
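A command-line sketch combining these tips; the model name is an assumption, `--block-size` and `--max-model-len` are the standard vLLM engine arguments for block size and context length, and mapping "batch size" to `--max-num-seqs` is an interpretation of the recommendation above:

$ python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype bfloat16 --block-size 128 --max-num-seqs 128 --max-model-len 2048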
Troubleshooting: Tweaking HPU Graphs
If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs as follows:
- Tweak the `gpu_memory_utilization` knob. This decreases the allocation of KV cache, leaving some headroom for capturing graphs with a larger batch size. By default, `gpu_memory_utilization` is set to 0.9. It attempts to allocate ~90% of the HBM left for KV cache after a short profiling run. Note that decreasing it reduces the number of KV cache blocks available, and therefore reduces the effective maximum number of tokens you can handle at a given time.
- If this method is not efficient, you can disable `HPUGraph` completely. With HPU Graphs disabled, you are trading latency and throughput at lower batches for potentially higher throughput on higher batches. You can do that by adding the `--enforce-eager` flag to the server (for online inference), or by passing the `enforce_eager=True` argument to the LLM constructor (for offline inference). A command-line sketch of both knobs follows this list.
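Both knobs can be combined on the command line as sketched here (the model name and the 0.7 value are assumptions); the offline equivalent is passing `gpu_memory_utilization=0.7` and `enforce_eager=True` to the `LLM` constructor.

$ python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --gpu-memory-utilization 0.7 --enforce-eager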