Skip to content

Commit

Permalink
[Misc] Use vllm-flash-attn instead of flash-attn (vllm-project#4686)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored May 8, 2024
1 parent 230c4b3 commit 89579a2
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 31 deletions.
21 changes: 0 additions & 21 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip cache remove vllm_nccl*
#################### EXTENSION Build IMAGE ####################

#################### FLASH_ATTENTION Build IMAGE ####################
FROM dev as flash-attn-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash attention version
ARG flash_attn_version=v2.5.8
ENV FLASH_ATTN_VERSION=${flash_attn_version}

WORKDIR /usr/src/flash-attention-v2

# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
--no-build-isolation --no-deps --no-cache-dir

#################### FLASH_ATTENTION Build IMAGE ####################

#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
Expand All @@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose

RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
--mount=type=cache,target=/root/.cache/pip \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
#################### vLLM installation IMAGE ####################


Expand Down
1 change: 1 addition & 0 deletions requirements-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0
14 changes: 9 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,14 +355,18 @@ def _read_requirements(filename: str) -> List[str]:

if _is_cuda():
requirements = _read_requirements("requirements-cuda.txt")
cuda_major = torch.version.cuda.split(".")[0]
cuda_major, cuda_minor = torch.version.cuda.split(".")
modified_requirements = []
for req in requirements:
if "vllm-nccl-cu12" in req:
modified_requirements.append(
req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}"))
else:
modified_requirements.append(req)
req = req.replace("vllm-nccl-cu12",
f"vllm-nccl-cu{cuda_major}")
elif ("vllm-flash-attn" in req
and not (cuda_major == "12" and cuda_minor == "1")):
# vllm-flash-attn is built only for CUDA 12.1.
# Skip for other versions.
continue
modified_requirements.append(req)
requirements = modified_requirements
elif _is_hip():
requirements = _read_requirements("requirements-rocm.txt")
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import List, Optional, Tuple, Type

import torch
from flash_attn import flash_attn_varlen_func
from vllm_flash_attn import flash_attn_varlen_func

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import flashinfer
import torch
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
Expand Down
7 changes: 4 additions & 3 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
return _Backend.XFORMERS

try:
import flash_attn # noqa: F401
import vllm_flash_attn # noqa: F401
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the flash_attn "
"package is not found. Please install it for better performance.")
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
"package is not found. `pip install vllm-flash-attn` for better "
"performance.")
return _Backend.XFORMERS

backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
Expand Down

0 comments on commit 89579a2

Please sign in to comment.