diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..ddca95c0e8786 100644 --- a/Dockerfile +++ b/Dockerfile @@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.8 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base @@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6548d7a6684b2..ba8c614d205d2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 diff --git a/setup.py b/setup.py index 3768daf9d6fab..d9ba96b82329a 100644 --- a/setup.py +++ b/setup.py @@ -355,14 +355,18 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): requirements = _read_requirements("requirements-cuda.txt") - cuda_major = torch.version.cuda.split(".")[0] + cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: if "vllm-nccl-cu12" in req: - modified_requirements.append( - req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")) - else: - modified_requirements.append(req) + req = req.replace("vllm-nccl-cu12", + f"vllm-nccl-cu{cuda_major}") + elif ("vllm-flash-attn" in req + and not (cuda_major == "12" and cuda_minor == "1")): + # vllm-flash-attn is built only for CUDA 12.1. + # Skip for other versions. + continue + modified_requirements.append(req) requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("requirements-rocm.txt") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c2fec9153f2d8..4bad226512b69 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -8,7 +8,7 @@ from typing import List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8f13f3525512b..36e162671f944 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,8 +3,8 @@ import flashinfer import torch -from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 34da0f6c6cdfc..f4446bac6b8d2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: return _Backend.XFORMERS try: - import flash_attn # noqa: F401 + import vllm_flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention-2 backend because the flash_attn " - "package is not found. Please install it for better performance.") + "Cannot use FlashAttention-2 backend because the vllm_flash_attn " + "package is not found. `pip install vllm-flash-attn` for better " + "performance.") return _Backend.XFORMERS backend_by_env_var = envs.VLLM_ATTENTION_BACKEND