merge bert-debug
Signed-off-by: Max de Bayser <[email protected]>
maxdebayser committed Oct 23, 2024
1 parent faa29eb commit db56b52
Showing 12 changed files with 305 additions and 25 deletions.
6 changes: 6 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -739,6 +739,9 @@ void paged_attention_v1_launcher(
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 32:
      LAUNCH_PAGED_ATTENTION_V1(32);
      break;
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
@@ -903,6 +906,9 @@ void paged_attention_v2_launcher(
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 32:
      LAUNCH_PAGED_ATTENTION_V2(32);
      break;
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
6 changes: 6 additions & 0 deletions csrc/cpu/attention.cpp
@@ -375,6 +375,9 @@ void paged_attention_v1_impl_launcher(
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  switch (head_size) {
    case 32:
      LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
      break;
    case 64:
      LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
      break;
@@ -692,6 +695,9 @@ void paged_attention_v2_impl_launcher(
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  switch (head_size) {
    case 32:
      LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
      break;
    case 64:
      LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
      break;
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -376,7 +376,7 @@ Text Generation
  * - :code:`InternVLChatModel`
    - InternVL2
    - T + I\ :sup:`E+`
    - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
    - :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
    -
    - ✅︎
  * - :code:`LlavaForConditionalGeneration`
15 changes: 13 additions & 2 deletions tests/models/decoder_only/vision_language/test_internvl.py
@@ -19,15 +19,20 @@
    "cherry_blossom":
    "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in short.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501

models = [
    "OpenGVLab/InternVL2-1B",
    "OpenGVLab/InternVL2-2B",
    # NOTE: Mono-InternVL-2B doesn't work with fp16,
    # it will result NaN during inference.
    # See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9
    "OpenGVLab/Mono-InternVL-2B",
    # Broken due to outdated implementation of Phi-3
    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
    # "OpenGVLab/InternVL2-4B",
]
target_dtype = "bfloat16"


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
@@ -52,9 +57,15 @@ def generate(

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
    forward_kwargs = dict(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
    )
    if getattr(self, "use_visual_token_mask", False):
        visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
        forward_kwargs["visual_token_mask"] = visual_token_mask
    outputs = self.language_model.generate(
        **forward_kwargs,
        **generate_kwargs,
    )

36 changes: 36 additions & 0 deletions tests/test_config.py
@@ -4,6 +4,42 @@
from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType


@pytest.mark.parametrize(("model_id", "expected_task"), [
    ("facebook/opt-125m", "generate"),
    ("intfloat/e5-mistral-7b-instruct", "embedding"),
])
def test_auto_task(model_id, expected_task):
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("facebook/opt-125m", "embedding"),
    ("intfloat/e5-mistral-7b-instruct", "generate"),
])
def test_incorrect_task(model_id, bad_task):
    with pytest.raises(ValueError, match=r"does not support the .* task"):
        ModelConfig(
            model_id,
            task=bad_task,
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )


@pytest.mark.parametrize(("model_id", "expected_task"), [
    ("facebook/opt-125m", "generate"),
    ("intfloat/e5-mistral-7b-instruct", "embedding"),
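The two tests above pin down ModelConfig's task resolution: with task="auto" a decoder-only model resolves to "generate" and an embedding model to "embedding", while an explicitly wrong task raises a ValueError. For reference, a minimal standalone sketch of the same check outside pytest, reusing only the constructor arguments already exercised by the tests (the model config is fetched from the Hugging Face Hub):

from vllm.config import ModelConfig

config = ModelConfig(
    "intfloat/e5-mistral-7b-instruct",
    task="auto",  # let vLLM infer the task from the model architecture
    tokenizer="intfloat/e5-mistral-7b-instruct",
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="float16",
)
print(config.task)  # expected to print "embedding", as asserted in test_auto_task above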
2 changes: 1 addition & 1 deletion vllm/attention/ops/ipex_attn.py
@@ -10,7 +10,7 @@ class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]
        return [32, 64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
@@ -34,7 +34,7 @@ class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 120, 128, 192, 256]
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
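Together with the kernel changes above, adding 32 to get_supported_head_sizes() lets models with 32-dimensional attention heads take the paged-attention path. A minimal sketch of the kind of early validation a caller can do against this method; check_head_size is a hypothetical helper for illustration, not part of this diff:

from vllm.attention.ops.paged_attn import PagedAttention


def check_head_size(head_size: int) -> None:
    # Fail fast if there is no compiled kernel for this head size.
    supported = PagedAttention.get_supported_head_sizes()
    if head_size not in supported:
        raise ValueError(f"head_size {head_size} is not supported; "
                         f"supported sizes: {supported}")


check_head_size(32)  # passes once this change is in place
check_head_size(64)  # was already supported before this change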
2 changes: 1 addition & 1 deletion vllm/model_executor/models/eagle.py
@@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
        self.model = model_cls(self.config.model, *args, **kwargs)
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "bias", False))
                            bias=getattr(self.config, "eagle_fc_bias", False))

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
31 changes: 31 additions & 0 deletions vllm/model_executor/models/intern_vit.py
@@ -97,6 +97,37 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        return embeddings


class InternVisionPatchModel(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        if pixel_values is None and pixel_embeds is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values is not None:
            if pixel_values.ndim == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(
                    f'wrong pixel_values size: {pixel_values.shape}')

        return hidden_states


class InternParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

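InternVisionPatchModel above runs only the patch and position embedding stage and returns those embeddings directly, skipping the transformer encoder; this lines up with the Mono-InternVL support added elsewhere in this commit, where visual tokens are routed through the language model's visual-expert MLPs (see internlm2_ve.py below) rather than a separate ViT encoder. A rough usage sketch, assuming the config only needs the hidden_size, image_size and patch_size fields read by InternVisionEmbeddings; the values below are illustrative, not taken from a real checkpoint:

import torch
from transformers import PretrainedConfig

from vllm.model_executor.models.intern_vit import InternVisionPatchModel

# Illustrative config; a real one comes from the vision_config of the HF
# checkpoint and may carry additional fields.
vision_config = PretrainedConfig(hidden_size=1024,
                                 image_size=448,
                                 patch_size=14,
                                 num_channels=3)
vision = InternVisionPatchModel(vision_config)

pixel_values = torch.randn(1, 3, 448, 448)  # (batch, channels, height, width)
patch_embeds = vision(pixel_values=pixel_values)
# Expected shape: (batch, num_patches + 1, hidden_size), straight from the embeddings.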
166 changes: 166 additions & 0 deletions vllm/model_executor/models/internlm2_ve.py
@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
                                                  InternLM2ForCausalLM,
                                                  InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors

from .utils import make_layers


class InternLM2VEDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.feed_forward_ve = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        if visual_token_mask is not None and visual_token_mask.any():
            visual_token_mask = visual_token_mask.repeat(
                1, self.hidden_size).bool()
            text_token_mask = ~visual_token_mask
            hidden_states[visual_token_mask] = self.feed_forward_ve(
                hidden_states[visual_token_mask].reshape(
                    -1, self.hidden_size)).flatten()
            if text_token_mask.any():
                hidden_states[text_token_mask] = self.feed_forward(
                    hidden_states[text_token_mask].reshape(
                        -1, self.hidden_size)).flatten()
        else:
            hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2VEModel(InternLM2Model):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, cache_config, quant_config)
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
                                                   quant_config),
            prefix=f"{prefix}.layers")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.tok_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
                visual_token_mask=visual_token_mask,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2VEForCausalLM(InternLM2ForCausalLM):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__(config, cache_config, quant_config)
        self.model = InternLM2VEModel(config, cache_config, quant_config)
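The core of InternLM2VEDecoderLayer above is the conditional MLP routing: tokens flagged by visual_token_mask go through feed_forward_ve (the visual-expert MLP) and the remaining tokens through the regular feed_forward, with the per-token mask first expanded to the hidden dimension. A self-contained toy sketch of that routing pattern, using plain linear layers as stand-ins for the vLLM MLP classes:

import torch
import torch.nn as nn

hidden_size = 8
feed_forward = nn.Linear(hidden_size, hidden_size)     # stand-in for the text MLP
feed_forward_ve = nn.Linear(hidden_size, hidden_size)  # stand-in for the visual-expert MLP

hidden_states = torch.randn(6, hidden_size)  # 6 flattened tokens
visual_token_mask = torch.tensor([[1.], [1.], [0.], [0.], [1.], [0.]])  # (num_tokens, 1)

with torch.no_grad():
    # Expand the (num_tokens, 1) mask to the hidden dimension, as the layer above does.
    mask = visual_token_mask.repeat(1, hidden_size).bool()
    text_mask = ~mask

    out = hidden_states.clone()
    out[mask] = feed_forward_ve(
        hidden_states[mask].reshape(-1, hidden_size)).flatten()
    if text_mask.any():
        out[text_mask] = feed_forward(
            hidden_states[text_mask].reshape(-1, hidden_size)).flatten()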