Improvement to enable quantization of Merge Models #623

Open · wants to merge 4 commits into base: master
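All of the changes below follow a single pattern: every module's forward path now routes its returned hidden states (and logits) through a new `substitute_inf_with_max` helper imported from `exllamav2.util`, so that ±inf values produced while running merged models are replaced with finite values before the measurement/quantization code sees them. The helper's implementation is not part of this diff; the sketch below is a hypothetical minimal version, assuming it simply clamps floating-point tensors to the finite range of their dtype (mirroring the existing `clamp_(-65504, 65504)` applied to fp16 hidden states) and passes non-tensor results, such as intermediates dicts or None, through unchanged.

```python
import torch

def substitute_inf_with_max(x):
    # Hypothetical sketch only; the real exllamav2.util implementation is not shown in this PR's diff.
    # Pass through anything that is not a floating-point tensor (e.g. intermediates dicts, None).
    if not isinstance(x, torch.Tensor) or not x.is_floating_point():
        return x
    finfo = torch.finfo(x.dtype)
    # +inf becomes the dtype's largest finite value, -inf its smallest;
    # finite values (and NaNs) are left unchanged by the clamp.
    return x.clamp(min=finfo.min, max=finfo.max)
```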
19 changes: 10 additions & 9 deletions exllamav2/attn.py
@@ -16,6 +16,7 @@
import torch.nn.functional as F
import inspect
import os
+from exllamav2.util import substitute_inf_with_max

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -964,37 +965,37 @@ def forward(self,
use_flash_attn = has_flash_attn and not cfg.no_flash_attn

if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
-return self.forward_paged(
+return substitute_inf_with_max(self.forward_paged(
hidden_states,
cache,
attn_params,
loras = loras,
**kwargs
-)
+))

if self.is_tp:
if cache is not None and use_flash_attn:
-return self.forward_tp(
+return substitute_inf_with_max(self.forward_tp(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
**kwargs,
-)
+))
else:
# TODO: Can't use the optimized forward function because it writes directly to a fixed output
# tensor, and flash-attn currently has a bug that prevents that from working when q_len == 1
-return self.forward_tp_old(
+return substitute_inf_with_max(self.forward_tp_old(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
**kwargs,
-)
+))

if self.q_handle is None or intermediates:
return self.forward_torch(
@@ -1113,7 +1114,7 @@ def forward(self,
if cfg.arch.clamp_hidden_states:
hidden_states.clamp_(-65504, 65504)

-return hidden_states
+return substitute_inf_with_max(hidden_states)

def forward_tp(
self,
@@ -1428,9 +1429,9 @@ def forward_torch(
if intermediates:
return {"post_norm": post_norm,
"attn_output": attn_output,
"hidden_states": hidden_states}
"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)


def update_loras(self):
6 changes: 4 additions & 2 deletions exllamav2/embedding.py
@@ -9,6 +9,8 @@
if TYPE_CHECKING:
from exllamav2.model import ExLlamaV2

+from exllamav2.util import substitute_inf_with_max
+
EMBEDDING_INDEX: int = 1000000

class ExLlamaV2Embedding(ExLlamaV2Module):
@@ -185,6 +187,6 @@ def forward(
hidden_states = ctx.copy_pinned(0, hidden_states)

if intermediates:
return {"hidden_states": hidden_states}
return {"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)
10 changes: 6 additions & 4 deletions exllamav2/headnorm.py
@@ -8,6 +8,8 @@
if TYPE_CHECKING:
from exllamav2.model import ExLlamaV2

+from exllamav2.util import substitute_inf_with_max
+
class ExLlamaV2HeadNorm(ExLlamaV2Module):

name: str = "LayerNorm"
@@ -122,9 +124,9 @@ def forward(
self.variance_epsilon)

if intermediates:
return {"hidden_states": hidden_states}
return {"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)

def forward_torch(
self,
@@ -146,8 +148,8 @@ def forward_torch(
hidden_states = hidden_states.to(input_dtype)

if intermediates:
return {"hidden_states": hidden_states}
return {"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)


10 changes: 6 additions & 4 deletions exllamav2/layernorm.py
@@ -8,6 +8,8 @@
if TYPE_CHECKING:
from exllamav2.model import ExLlamaV2

+from exllamav2.util import substitute_inf_with_max
+
class ExLlamaV2LayerNorm(ExLlamaV2Module):

name: str = "LayerNorm"
@@ -119,9 +121,9 @@ def forward(
hidden_states = norm.view(output_shape)

if intermediates:
return {"hidden_states": hidden_states}
return {"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)


def forward_torch(
@@ -139,8 +141,8 @@
hidden_states = self.layernorm(hidden_states)

if intermediates:
return {"hidden_states": hidden_states}
return {"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)


21 changes: 10 additions & 11 deletions exllamav2/linear.py
@@ -7,7 +7,7 @@
from exllamav2.module import ExLlamaV2Module
from exllamav2.compat import safe_move_tensor
from exllamav2.tensor_p import BROADCAST_VC
-from exllamav2.util import unpack_4bit, pack_4bit
+from exllamav2.util import unpack_4bit, pack_4bit, substitute_inf_with_max
import gc

from typing import TYPE_CHECKING
@@ -295,8 +295,7 @@ def temp_fwd_size(self) -> int:
max_len = self.model.config.max_input_len if self.max_out_len is None else \
min(self.max_out_len, self.model.config.max_input_len)
return self.out_features * max_len * self.model.config.max_batch_size * 4 + 128



def forward(
self,
hidden_states: torch.Tensor,
@@ -312,7 +311,7 @@ def forward(

if self.is_tp:
if self.out_features_tp:
-return self.forward_tp(
+return substitute_inf_with_max(self.forward_tp(
hidden_states,
cache,
attn_params,
@@ -322,9 +321,9 @@
force_recons,
force_cuda,
**kwargs
-)
+))
elif self.in_features_tp:
-return self.forward_tp_row(
+return substitute_inf_with_max(self.forward_tp_row(
hidden_states,
cache,
attn_params,
@@ -334,7 +333,7 @@
force_recons,
force_cuda,
**kwargs
-)
+))
else:
assert False, "Unitialized TP linear layer"

@@ -344,9 +343,9 @@
hidden_states_out = loras[0].lm_head(hidden_states)

if intermediates:
return {"hidden_states": hidden_states_out}
return {"hidden_states": substitute_inf_with_max(hidden_states_out)}
else:
-return hidden_states_out
+return substitute_inf_with_max(hidden_states_out)

if self.q_handle is not None and not force_recons:

@@ -380,9 +379,9 @@ def forward(
hidden_states_out += torch.matmul(temp, lora_b)

if intermediates:
return {"hidden_states": hidden_states_out}
return {"hidden_states": substitute_inf_with_max(hidden_states_out)}
else:
-return hidden_states_out
+return substitute_inf_with_max(hidden_states_out)


def forward_tp(
11 changes: 6 additions & 5 deletions exllamav2/mlp.py
@@ -9,6 +9,7 @@
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from exllamav2.lora import ExLlamaV2Lora
from exllamav2.tensor_p import BROADCAST_ID, BROADCAST_RS
+from exllamav2.util import substitute_inf_with_max
# from line_profiler import profile

from typing import TYPE_CHECKING
@@ -288,15 +289,15 @@ def forward(
) -> torch.Tensor | dict[str: torch.Tensor]:

if self.is_tp:
-return self.forward_tp(
+return substitute_inf_with_max(self.forward_tp(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
**kwargs
-)
+))

cfg = self.model.config

@@ -319,7 +320,7 @@
if cfg.arch.clamp_hidden_states:
hidden_states.clamp_(-65504, 65504)

-return hidden_states
+return substitute_inf_with_max(hidden_states)


# @profile
@@ -457,9 +458,9 @@ def forward_torch(
if intermediates:
return {"post_norm": post_norm,
"pre_down": y,
"hidden_states": hidden_states}
"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)


def update_loras(self):
9 changes: 5 additions & 4 deletions exllamav2/model.py
@@ -2,6 +2,7 @@
import os, sys

from exllamav2.architecture import RopeStyle
+from exllamav2.util import substitute_inf_with_max

min_version = (3, 8)
if sys.version_info < min_version:
@@ -820,9 +821,9 @@ def forward(
if abort_event and abort_event.is_set(): return

if "last_state" in result:
return result.get("logits"), result["last_state"]
return substitute_inf_with_max(result.get("logits")), substitute_inf_with_max(result["last_state"])
else:
return result.get("logits")
return substitute_inf_with_max(result.get("logits"))

# Confirm that the input fits within the allocated cache space

@@ -893,9 +894,9 @@ def forward(
last_state = r.get("last_state")

if last_state is None:
-return result
+return substitute_inf_with_max(result)
else:
-return result, last_state
+return substitute_inf_with_max(result), substitute_inf_with_max(last_state)


@torch.inference_mode()
3 changes: 2 additions & 1 deletion exllamav2/module.py
@@ -4,6 +4,7 @@
from exllamav2.config import ExLlamaV2Config
from exllamav2.fasttensors import STFile
from exllamav2.compat import safe_move_tensor
+from exllamav2.util import substitute_inf_with_max

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -282,4 +283,4 @@ def forward(self, hidden_states, *args, **kwargs):
hidden_states = self.post_forward(hidden_states, *args, **kwargs)
hidden_states = safe_move_tensor(hidden_states, dev)

-return hidden_states
+return substitute_inf_with_max(hidden_states)
7 changes: 4 additions & 3 deletions exllamav2/moe_mlp.py
@@ -7,6 +7,7 @@
from exllamav2.linear import ExLlamaV2Linear
from exllamav2.lora import ExLlamaV2Lora
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
+from exllamav2.util import substitute_inf_with_max

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -244,7 +245,7 @@ def forward(
# ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]), pass_loras, pass_lora_temp)
ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]))

-return hidden_states
+return substitute_inf_with_max(hidden_states)


def forward_torch(
@@ -313,9 +314,9 @@ def forward_torch(

if intermediates:
result["hidden_states"] = final_hidden_states
-return result
+return substitute_inf_with_max(result)
else:
-return final_hidden_states
+return substitute_inf_with_max(final_hidden_states)


def update_loras(self):
3 changes: 2 additions & 1 deletion exllamav2/parallel_decoder.py
@@ -9,6 +9,7 @@
from exllamav2.lora import ExLlamaV2Lora
from exllamav2.layernorm import ExLlamaV2LayerNorm
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
+from exllamav2.util import substitute_inf_with_max

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -119,7 +120,7 @@ def forward(
b = self.mlp.forward(b, cache, attn_params, past_len, intermediates, loras, **kwargs)
hidden_states += a
hidden_states += b
-return hidden_states
+return substitute_inf_with_max(hidden_states)


def forward_interm(
5 changes: 3 additions & 2 deletions exllamav2/pos_embedding.py
@@ -4,6 +4,7 @@
from exllamav2.module import ExLlamaV2Module
from exllamav2.attn import ExLlamaV2Attention
from exllamav2.compat import safe_move_tensor
+from exllamav2.util import substitute_inf_with_max

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -118,6 +119,6 @@ def forward(
hidden_states[b, target_a:target_b] += emb_slice

if intermediates:
return {"hidden_states": hidden_states}
return {"hidden_states": substitute_inf_with_max(hidden_states)}
else:
-return hidden_states
+return substitute_inf_with_max(hidden_states)
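For reference, a minimal usage example of the pattern these diffs apply at every module boundary, again assuming the hypothetical sketch of `substitute_inf_with_max` given above: any ±inf in a module's fp16 output is mapped to the finite fp16 limits before the value is returned to the caller.

```python
import torch

h = torch.tensor([1.0, float("inf"), float("-inf")], dtype=torch.float16)
h = substitute_inf_with_max(h)  # -> values [1.0, 65504.0, -65504.0]
```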