From 10a591666ffa399913557986cb8ac2806d9c0f45 Mon Sep 17 00:00:00 2001 From: vince62s Date: Thu, 14 Dec 2023 18:01:05 +0100 Subject: [PATCH 1/8] first try new cache --- onmt/decoders/transformer.py | 37 ++++++++++++++++++++++++++++--- onmt/model_builder.py | 1 + onmt/modules/multi_headed_attn.py | 28 ++++++++++++++--------- onmt/translate/translator.py | 3 +++ 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index 05797d36c3..a632590a70 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -43,6 +43,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + max_length=256, ): """ Args: @@ -260,6 +261,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + max_length=256, ): """ Args: @@ -289,6 +291,7 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, + max_length=max_length, ) self.context_attn = MultiHeadedAttention( heads, @@ -448,6 +451,7 @@ def from_opt(cls, opt, embeddings): else 1, sliding_window=opt.sliding_window, rotary_interleave=opt.rotary_interleave, + max_length=opt.max_length, ) def init_state(self, src, enc_out, enc_final_hs): @@ -556,6 +560,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + max_length=256, ): super(TransformerDecoder, self).__init__( d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps @@ -587,6 +592,7 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, + max_length=max_length, ) for i in range(num_layers) ] @@ -823,6 +829,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + max_length=256, ): super(TransformerLMDecoder, self).__init__( d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps @@ -853,10 +860,18 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, + max_length=max_length, ) for i in range(num_layers) ] ) + if num_kv == 0: + self.num_kv_heads = heads + + else: + self.num_kv_heads = num_kv + self.dimperhead = d_model // heads + self.max_length = max_length def init_state(self, src=None, enc_out=None, enc_final_hs=None): super(TransformerLMDecoder, self).init_state(None, None, None) @@ -916,9 +931,25 @@ def _init_cache(self, tgt=None): else: layer.self_attn.layer_cache = ( True, - { - "keys": torch.tensor([], device=tgt.device), - "values": torch.tensor([], device=tgt.device), + { # [batchsize x heads x length x dimperhead] + "keys": torch.zeros( + [ + tgt.size(0), + self.num_kv_heads, + self.max_length + tgt.size(1), + self.dimperhead, + ], + device=tgt.device, + ).half(), + "values": torch.zeros( + [ + tgt.size(0), + self.num_kv_heads, + self.max_length + tgt.size(1), + self.dimperhead, + ], + device=tgt.device, + ).half(), }, ) if hasattr(layer.self_attn, "rope"): diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 37831c50d1..e58e63ea70 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -145,6 +145,7 @@ def load_test_model(opt, device_id=0, model_path=None): 0.0 # required to force no dropout at inference with flash ) + model_opt.max_length = opt.max_length model = build_base_model(model_opt, vocabs) precision = torch.float32 diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 43eb7e8536..29ee912b97 100644 --- a/onmt/modules/multi_headed_attn.py +++ 
b/onmt/modules/multi_headed_attn.py @@ -416,9 +416,9 @@ def forward( key = shape(key, self.dim_per_head) value = shape(value, self.dim_per_head) + start_pos = step + seqlen = query.size(2) if self.max_relative_positions == -1: # Rotary Embeddings - start_pos = step - seqlen = query.size(2) if seqlen > self.rope.size(0): self.rope = rotaryembeddings( self.dim_per_head, maxseqlen=(seqlen + 2048) @@ -428,15 +428,21 @@ def forward( query, key, rope, interleave=self.rotary_interleave ) - if self.layer_cache[1]["keys"].numel() != 0: - key = torch.cat((self.layer_cache[1]["keys"], key), dim=2) - value = torch.cat((self.layer_cache[1]["values"], value), dim=2) - if sliding_window > 0 and key.size(2) > sliding_window: - key = key[:, :, 1:, :] - value = value[:, :, 1:, :] - - self.layer_cache[1]["keys"] = key - self.layer_cache[1]["values"] = value + self.layer_cache[1]["keys"][ + :, :, start_pos : start_pos + seqlen, : + ] = key + self.layer_cache[1]["values"][ + :, :, start_pos : start_pos + seqlen, : + ] = value + + """ + if sliding_window > 0 and key.size(2) > sliding_window: + key = key[:, :, 1:, :] + value = value[:, :, 1:, :] + """ + + key = self.layer_cache[1]["keys"][:, :, : start_pos + seqlen, :] + value = self.layer_cache[1]["values"][:, :, : start_pos + seqlen, :] elif self.attn_type == "context": query = self.linear_query(query) query = shape(query, self.dim_per_head) diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index a27fa72ae0..2a63b1ce95 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -1102,6 +1102,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): ) # (4) Begin decoding step by step: + beg_time = time() for step in range(decode_strategy.max_length): decoder_input = ( src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1) @@ -1139,6 +1140,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): if parallel_paths > 1 or any_finished: # select indexes in model state/cache self.model.decoder.map_state(lambda state, dim: state[select_indices]) + if step == 0: + print("step0 time: ", time() - beg_time) return self.report_results( gold_score, From aa2ae4e05eae5bafe040807775a02db384b56d66 Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 19 Dec 2023 15:13:45 +0100 Subject: [PATCH 2/8] use flash_attn_with_kvcache --- onmt/decoders/transformer.py | 81 +++---------------- onmt/model_builder.py | 1 - onmt/modules/multi_headed_attn.py | 125 ++++++++++++++++++++++-------- onmt/modules/position_ffn.py | 9 ++- onmt/modules/rmsnorm.py | 27 +++++-- onmt/translate/translator.py | 2 +- 6 files changed, 134 insertions(+), 111 deletions(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index 3e919a5c75..25fe60c8c4 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -10,11 +10,7 @@ from onmt.modules.position_ffn import PositionwiseFeedForward from onmt.modules.position_ffn import ActivationFunction from onmt.utils.misc import sequence_mask - -try: - from apex.normalization import FusedRMSNorm as RMSNorm -except ImportError: - from onmt.modules.rmsnorm import RMSNorm +from onmt.modules.rmsnorm import RMSNorm class TransformerDecoderLayerBase(nn.Module): @@ -43,7 +39,6 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, - max_length=256, ): """ Args: @@ -262,7 +257,6 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, - max_length=256, ): """ Args: @@ -292,7 +286,6 @@ 
def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, - max_length=max_length, ) self.context_attn = MultiHeadedAttention( heads, @@ -453,7 +446,6 @@ def from_opt(cls, opt, embeddings): else 1, sliding_window=opt.sliding_window, rotary_interleave=opt.rotary_interleave, - max_length=opt.max_length, ) def init_state(self, src, enc_out, enc_final_hs): @@ -573,7 +565,6 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, - max_length=256, ): super(TransformerDecoder, self).__init__( d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps @@ -605,18 +596,10 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, - max_length=max_length, ) for i in range(num_layers) ] ) - if num_kv == 0: - self.num_kv_heads = heads - - else: - self.num_kv_heads = num_kv - self.dimperhead = d_model // heads - self.max_length = max_length def detach_state(self): self.state["src"] = self.state["src"].detach() @@ -631,7 +614,7 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs): if enc_out is None: enc_out = self.embeddings(tgt) if step == 0: - self._init_cache(tgt, enc_out) + self._init_cache(enc_out) elif step is None: for layer in self.transformer_layers: if isinstance(layer.self_attn, AverageAttention): @@ -686,7 +669,7 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs): # TODO change the way attns is returned dict => list or tuple (onnx) return dec_out, attns - def _init_cache(self, tgt, enc_out): + def _init_cache(self, enc_out): batch_size = enc_out.size(0) depth = enc_out.size(-1) @@ -709,28 +692,9 @@ def _init_cache(self, tgt, enc_out): else: layer.self_attn.layer_cache = ( True, - { # [batchsize x heads x length x dimperhead] - "keys": torch.zeros( - [ - tgt.size(0), - self.num_kv_heads, - self.max_length + tgt.size(1), - self.dimperhead, - ], - device=tgt.device, - ).half(), - "values": torch.zeros( - [ - tgt.size(0), - self.num_kv_heads, - self.max_length + tgt.size(1), - self.dimperhead, - ], - device=tgt.device, - ).half(), - "key_pad_mask": tgt[:, :, 0] - .eq(self.embeddings.word_padding_idx) - .unsqueeze(1), + { + "keys": torch.tensor([], device=enc_out.device), + "values": torch.tensor([], device=enc_out.device), }, ) if hasattr(layer.self_attn, "rope"): @@ -868,7 +832,6 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, - max_length=256, ): super(TransformerLMDecoder, self).__init__( d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps @@ -899,18 +862,10 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, - max_length=max_length, ) for i in range(num_layers) ] ) - if num_kv == 0: - self.num_kv_heads = heads - - else: - self.num_kv_heads = num_kv - self.dimperhead = d_model // heads - self.max_length = max_length def init_state(self, src=None, enc_out=None, enc_final_hs=None): super(TransformerLMDecoder, self).init_state(None, None, None) @@ -974,25 +929,9 @@ def _init_cache(self, tgt=None): else: layer.self_attn.layer_cache = ( True, - { # [batchsize x heads x length x dimperhead] - "keys": torch.zeros( - [ - tgt.size(0), - self.num_kv_heads, - self.max_length + tgt.size(1), - self.dimperhead, - ], - device=tgt.device, - ).half(), - "values": torch.zeros( - [ - tgt.size(0), - self.num_kv_heads, - self.max_length + tgt.size(1), - self.dimperhead, - ], - device=tgt.device, - ).half(), + { + "keys": torch.tensor([], device=tgt.device), + 
"values": torch.tensor([], device=tgt.device), "key_pad_mask": tgt[:, :, 0] .eq(self.embeddings.word_padding_idx) .unsqueeze(1), @@ -1000,3 +939,5 @@ def _init_cache(self, tgt=None): ) if hasattr(layer.self_attn, "rope"): layer.self_attn.rope = layer.self_attn.rope.to(tgt.device) + layer.self_attn.cos = layer.self_attn.cos.to(tgt.device) + layer.self_attn.sin = layer.self_attn.sin.to(tgt.device) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 820f4f2500..3ce19aa597 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -148,7 +148,6 @@ def load_test_model(opt, device_id=0, model_path=None): 0.0 # required to force no dropout at inference with flash ) - model_opt.max_length = opt.max_length model = build_base_model(model_opt, vocabs) precision = torch.float32 diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 6d49662bf9..8b60ca4ad4 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -11,7 +11,6 @@ from torch.distributed import all_reduce from importlib import import_module - # Help functions for Rotary Embeddings # https://arxiv.org/pdf/2104.09864.pdf # too convoluted to make maxseqlen a parameter. @@ -353,8 +352,17 @@ def __init__( if max_relative_positions == -1: # rotary embeddings self.rope = rotaryembeddings(self.dim_per_head) + self.cos = ( + self.rope[:, : self.rope.size(1) // 2].real.contiguous().half() + ) + self.sin = ( + self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half() + ) self.rotary_interleave = rotary_interleave - + else: + self.cos = None + self.sin = None + self.rotary_interleave = None if max_relative_positions == -2: # alibi positional bias self.alibi = AlibiPositionalBias(head_count) @@ -367,6 +375,9 @@ def __init__( and torch.cuda.get_device_capability()[0] >= 8 ): self.flash_attn_func = getattr(flash_pack, "flash_attn_func") + self.flash_attn_with_kvcache = getattr( + flash_pack, "flash_attn_with_kvcache" + ) self.flash2 = True except ImportError: self.flash2 = False @@ -422,36 +433,88 @@ def forward( start_pos = step seqlen = query.size(2) - print(start_pos, seqlen) - if self.max_relative_positions == -1: # Rotary Embeddings - if seqlen > self.rope.size(0): - self.rope = rotaryembeddings( - self.dim_per_head, maxseqlen=(seqlen + 2048) - ).to(self.rope.device) - rope = self.rope[start_pos : start_pos + seqlen] - query, key = apply_rotary_emb( - query, key, rope, interleave=self.rotary_interleave - ) - print(key.size()) - self.layer_cache[1]["keys"][ - :, :, start_pos : start_pos + seqlen, : - ] = key - print(self.layer_cache[1]["keys"].size()) - self.layer_cache[1]["values"][ - :, :, start_pos : start_pos + seqlen, : - ] = value - - """ - if sliding_window > 0 and key.size(2) > sliding_window: - key = key[:, :, 1:, :] - value = value[:, :, 1:, :] - """ - - key = self.layer_cache[1]["keys"][:, :, : start_pos + seqlen, :] - print(key.size()) - print() - value = self.layer_cache[1]["values"][:, :, : start_pos + seqlen, :] + if ( + step == 0 + or not self.flash2 + or self.max_relative_positions not in [0, -1] + or query.size(0) > 8 + ): + if self.max_relative_positions == -1: # Rotary Embeddings + if seqlen > self.rope.size(0): + self.rope = rotaryembeddings( + self.dim_per_head, maxseqlen=(seqlen + 2048) + ).to(self.rope.device) + rope = self.rope[start_pos : start_pos + seqlen] + query, key = apply_rotary_emb( + query, key, rope, interleave=self.rotary_interleave + ) + + if self.layer_cache[1]["keys"].numel() != 0: + key = 
torch.cat((self.layer_cache[1]["keys"], key), dim=2) + value = torch.cat((self.layer_cache[1]["values"], value), dim=2) + if sliding_window > 0 and key.size(2) > sliding_window: + key = key[:, :, 1:, :] + value = value[:, :, 1:, :] + + self.layer_cache[1]["keys"] = key + self.layer_cache[1]["values"] = value + + else: + if self.max_relative_positions == -1: # Rotary Embeddings + if seqlen > self.rope.size(0): + self.rope = rotaryembeddings( + self.dim_per_head, maxseqlen=(seqlen + 2048) + ).to(self.rope.device) + self.cos = ( + self.rope[:, : self.rope.size(1) // 2] + .real.contiguous() + .half() + ) + self.sin = ( + self.rope[:, : self.rope.size(1) // 2] + .imag.contiguous() + .half() + ) + if start_pos >= self.layer_cache[1]["keys"].size(2): + self.layer_cache[1]["keys"] = torch.cat( + [ + self.layer_cache[1]["keys"], + torch.zeros( + self.layer_cache[1]["keys"].shape[:-2] + + (32,) + + self.layer_cache[1]["keys"].shape[-1:], + device=query.device, + ).half(), + ], + dim=-2, + ) + self.layer_cache[1]["values"] = torch.cat( + [ + self.layer_cache[1]["values"], + torch.zeros( + self.layer_cache[1]["values"].shape[:-2] + + (32,) + + self.layer_cache[1]["values"].shape[-1:], + device=query.device, + ).half(), + ], + dim=-2, + ) + context = self.flash_attn_with_kvcache( + query.transpose(1, 2), + self.layer_cache[1]["keys"].transpose(1, 2), + self.layer_cache[1]["values"].transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + rotary_cos=self.cos, + rotary_sin=self.sin, + cache_seqlens=step, + rotary_interleaved=self.rotary_interleave, + ).transpose(1, 2) + attn_output = self.final_linear(unshape(context)) + return attn_output, None + elif self.attn_type == "context": query = self.linear_query(query) query = shape(query, self.dim_per_head) diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index 7ef207228d..ccf81d5256 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -5,9 +5,14 @@ from torch.utils.checkpoint import checkpoint try: - from apex.normalization import FusedRMSNorm as RMSNorm -except ImportError: + import awq_inference_engine from onmt.modules.rmsnorm import RMSNorm +except ImportError: + try: + from apex.normalization import FusedRMSNorm as RMSNorm + except ImportError: + from onmt.modules.rmsnorm import RMSNorm + from torch.nn.utils import skip_init from torch.distributed import all_reduce diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index 50d05529a4..0d6f33d2e6 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -3,11 +3,19 @@ import torch import torch.nn as nn +try: + import awq_inference_engine + + AWQ_INFERENCE_ENGINE = True +except: + AWQ_INFERENCE_ENGINE = False + class RMSNorm(torch.nn.Module): """RMSNorm: https://arxiv.org/abs/1910.07467 Args: - hidden_size (int): layer hidden_sizeension. + hidden_size (int): layer hidden_size dimension. + eps: variance epsilon. 
""" def __init__(self, hidden_size: int, eps: float = 1e-6): @@ -16,8 +24,15 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) def forward(self, hidden_states): - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - hidden_states = hidden_states.to(self.weight.dtype) - return hidden_states * self.weight + if AWQ_INFERENCE_ENGINE: + output = torch.empty_like(hidden_states) + awq_inference_engine.layernorm_forward_cuda( + hidden_states, self.weight, output, self.eps + ) + return output + else: + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.to(self.weight.dtype) + return hidden_states * self.weight diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 56be30ec0c..c576d7cda9 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -1113,7 +1113,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): batch, src_len=decode_strategy.src_len, src_map=src_map, - step=step if step == 0 else step + max(src_len.tolist()) - 1, + step=step if step == 0 else step + max(src_len.tolist()), batch_offset=decode_strategy.batch_offset, ) From 32af49922e8210c0f392a052623fe568977e14dc Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 19 Dec 2023 16:10:02 +0100 Subject: [PATCH 3/8] fix flake --- onmt/modules/multi_headed_attn.py | 1 + onmt/modules/position_ffn.py | 11 +---------- onmt/modules/rmsnorm.py | 4 ++-- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 8b60ca4ad4..6e7ebebd34 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -439,6 +439,7 @@ def forward( or not self.flash2 or self.max_relative_positions not in [0, -1] or query.size(0) > 8 + or query.dtype != torch.float16 ): if self.max_relative_positions == -1: # Rotary Embeddings if seqlen > self.rope.size(0): diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index ccf81d5256..4fa85a6bb9 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -3,16 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint - -try: - import awq_inference_engine - from onmt.modules.rmsnorm import RMSNorm -except ImportError: - try: - from apex.normalization import FusedRMSNorm as RMSNorm - except ImportError: - from onmt.modules.rmsnorm import RMSNorm - +from onmt.modules.rmsnorm import RMSNorm from torch.nn.utils import skip_init from torch.distributed import all_reduce diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index 0d6f33d2e6..a65ceda01a 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -6,8 +6,8 @@ try: import awq_inference_engine - AWQ_INFERENCE_ENGINE = True -except: + AWQ_INFERENCE_ENGINE = False +except ImportError: AWQ_INFERENCE_ENGINE = False From c920d65d3d8736783a2175ff5d50a123fd3e3d93 Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 19 Dec 2023 16:28:49 +0100 Subject: [PATCH 4/8] fix --- onmt/modules/multi_headed_attn.py | 2 ++ onmt/modules/rmsnorm.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onmt/modules/multi_headed_attn.py 
b/onmt/modules/multi_headed_attn.py index 6e7ebebd34..ce86eefdc7 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -514,6 +514,8 @@ def forward( rotary_interleaved=self.rotary_interleave, ).transpose(1, 2) attn_output = self.final_linear(unshape(context)) + if self.parallel_gpu > 1: + all_reduce(attn_output) return attn_output, None elif self.attn_type == "context": diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index a65ceda01a..2bf1d9cc4a 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -6,7 +6,7 @@ try: import awq_inference_engine - AWQ_INFERENCE_ENGINE = False + AWQ_INFERENCE_ENGINE = True except ImportError: AWQ_INFERENCE_ENGINE = False From 3f81b8f5b975f058534a0b080f7144e6e97a973d Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 19 Dec 2023 16:52:55 +0100 Subject: [PATCH 5/8] patch rmsnorm for multiexperts --- onmt/modules/rmsnorm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index 2bf1d9cc4a..53df70988e 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -26,9 +26,13 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): def forward(self, hidden_states): if AWQ_INFERENCE_ENGINE: output = torch.empty_like(hidden_states) + if hidden_states.dim() == 2: # patch for multi experts + hidden_states = hidden_states.unsqueeze(0) awq_inference_engine.layernorm_forward_cuda( hidden_states, self.weight, output, self.eps ) + if hidden_states.dim() == 2: # patch for multi experts + output = output.unsqueeze(0) return output else: hidden_states = hidden_states.to(torch.float32) From ea900d88daac9de3a6823ee2b4e288299a03416e Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 19 Dec 2023 16:57:35 +0100 Subject: [PATCH 6/8] black is black --- onmt/modules/rmsnorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index 53df70988e..3d8515aa5f 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -26,12 +26,12 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): def forward(self, hidden_states): if AWQ_INFERENCE_ENGINE: output = torch.empty_like(hidden_states) - if hidden_states.dim() == 2: # patch for multi experts + if hidden_states.dim() == 2: # patch for multi experts hidden_states = hidden_states.unsqueeze(0) awq_inference_engine.layernorm_forward_cuda( hidden_states, self.weight, output, self.eps ) - if hidden_states.dim() == 2: # patch for multi experts + if hidden_states.dim() == 2: # patch for multi experts output = output.unsqueeze(0) return output else: From d0ec7a807d171ab48e0b0987f442ad35b5b8a769 Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 26 Dec 2023 10:37:20 +0100 Subject: [PATCH 7/8] rope theta as an option --- onmt/decoders/transformer.py | 12 ++++++++++++ onmt/encoders/transformer.py | 6 ++++++ onmt/modules/moe.py | 4 ++++ onmt/modules/multi_headed_attn.py | 21 ++++++++++++++++----- onmt/opts.py | 7 +++++++ onmt/utils/distributed.py | 2 +- tools/convert_HF_llamalike.py | 7 +++++++ 7 files changed, 53 insertions(+), 6 deletions(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index f8b17ee443..557c566e20 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -40,6 +40,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + rotary_theta=1e4, num_experts=0, num_experts_per_tok=2, ): @@ -85,6 +86,7 @@ def __init__( sliding_window (int): Width of the band 
mask and KV cache (cf Mistral Model) rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied + rotary_theta (int): rotary base theta """ super(TransformerDecoderLayerBase, self).__init__() @@ -96,6 +98,7 @@ def __init__( max_relative_positions=max_relative_positions, relative_positions_buckets=relative_positions_buckets, rotary_interleave=rotary_interleave, + rotary_theta=rotary_theta, attn_type="self", self_attn_type=self_attn_type, add_qkvbias=add_qkvbias, @@ -276,6 +279,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + rotary_theta=1e4, num_experts=0, num_experts_per_tok=2, ): @@ -307,6 +311,7 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, + rotary_theta=rotary_theta, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, ) @@ -469,6 +474,7 @@ def from_opt(cls, opt, embeddings): else 1, sliding_window=opt.sliding_window, rotary_interleave=opt.rotary_interleave, + rotary_theta=opt.rotary_theta, num_experts=opt.num_experts, num_experts_per_tok=opt.num_experts_per_tok, ) @@ -559,6 +565,7 @@ class TransformerDecoder(TransformerDecoderBase): parallel_gpu (int): Number of gpu for tensor parallelism sliding_window (int): Width of the band mask and KV cache (cf Mistral Model) rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied + rotary_theta (int): rotary base theta """ def __init__( @@ -590,6 +597,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + rotary_theta=1e4, num_experts=0, num_experts_per_tok=2, ): @@ -623,6 +631,7 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, + rotary_theta=rotary_theta, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, ) @@ -830,6 +839,7 @@ class TransformerLMDecoder(TransformerDecoderBase): parallel_gpu (int): Number of gpu for tensor parallelism sliding_window (int): Width of the band mask and KV cache (cf Mistral Model) rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied + rotary_theta (int): rotary base theta """ def __init__( @@ -861,6 +871,7 @@ def __init__( parallel_gpu=1, sliding_window=0, rotary_interleave=True, + rotary_theta=1e4, num_experts=0, num_experts_per_tok=2, ): @@ -893,6 +904,7 @@ def __init__( parallel_gpu=parallel_gpu, sliding_window=sliding_window, rotary_interleave=rotary_interleave, + rotary_theta=rotary_theta, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, ) diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 12998957dc..35d32ce709 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -40,6 +40,7 @@ class TransformerEncoderLayer(nn.Module): parallel_gpu (int): Number of gpu for tensor parallelism rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied + rotary_theta (int): rotary base theta """ def __init__( @@ -61,6 +62,7 @@ def __init__( use_ckpting=[], parallel_gpu=1, rotary_interleave=True, + rotary_theta=1e4, ): super(TransformerEncoderLayer, self).__init__() @@ -72,6 +74,7 @@ def __init__( max_relative_positions=max_relative_positions, relative_positions_buckets=relative_positions_buckets, rotary_interleave=rotary_interleave, + rotary_theta=rotary_theta, attn_type="self", add_qkvbias=add_qkvbias, num_kv=num_kv, @@ -177,6 +180,7 @@ def __init__( use_ckpting=[], parallel_gpu=1, rotary_interleave=True, + 
rotary_theta=1e4, ): super(TransformerEncoder, self).__init__() @@ -201,6 +205,7 @@ def __init__( use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, rotary_interleave=rotary_interleave, + rotary_theta=rotary_theta, ) for i in range(num_layers) ] @@ -239,6 +244,7 @@ def from_opt(cls, opt, embeddings): if opt.parallel_mode == "tensor_parallel" else 1, rotary_interleave=opt.rotary_interleave, + rotary_theta=opt.rotary_theta, ) def forward(self, src, src_len=None): diff --git a/onmt/modules/moe.py b/onmt/modules/moe.py index 2e1c959636..f356130d97 100644 --- a/onmt/modules/moe.py +++ b/onmt/modules/moe.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from onmt.modules.position_ffn import PositionwiseFeedForward +from torch.distributed import all_reduce class MoE(nn.Module): @@ -40,12 +41,15 @@ def __init__( ) self.gate = nn.Linear(d_model, num_experts, bias=False) self.num_experts_per_tok = num_experts_per_tok + self.parallel_gpu = parallel_gpu def forward(self, x): orig_shape = x.shape x = x.view(-1, x.shape[-1]) scores = self.gate(x) + if self.parallel_gpu > 1: + all_reduce(scores) expert_weights, expert_indices = torch.topk( scores, self.num_experts_per_tok, dim=-1 ) diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index ce86eefdc7..251e55814f 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -257,6 +257,7 @@ def __init__( max_relative_positions: int = 0, relative_positions_buckets: int = 0, rotary_interleave: bool = True, + rotary_theta: int = 1e4, attn_type: str = None, self_attn_type: str = None, add_qkvbias=False, @@ -351,7 +352,7 @@ def __init__( self.relative_attention_bias = None if max_relative_positions == -1: # rotary embeddings - self.rope = rotaryembeddings(self.dim_per_head) + self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta) self.cos = ( self.rope[:, : self.rope.size(1) // 2].real.contiguous().half() ) @@ -359,6 +360,7 @@ def __init__( self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half() ) self.rotary_interleave = rotary_interleave + self.rotary_theta = rotary_theta else: self.cos = None self.sin = None @@ -438,13 +440,15 @@ def forward( step == 0 or not self.flash2 or self.max_relative_positions not in [0, -1] - or query.size(0) > 8 + or query.size(0) > 128 or query.dtype != torch.float16 ): if self.max_relative_positions == -1: # Rotary Embeddings if seqlen > self.rope.size(0): self.rope = rotaryembeddings( - self.dim_per_head, maxseqlen=(seqlen + 2048) + self.dim_per_head, + maxseqlen=(seqlen + 2048), + base=self.rotary_theta, ).to(self.rope.device) rope = self.rope[start_pos : start_pos + seqlen] query, key = apply_rotary_emb( @@ -465,7 +469,9 @@ def forward( if self.max_relative_positions == -1: # Rotary Embeddings if seqlen > self.rope.size(0): self.rope = rotaryembeddings( - self.dim_per_head, maxseqlen=(seqlen + 2048) + self.dim_per_head, + maxseqlen=(seqlen + 2048), + base=self.rotary_theta, ).to(self.rope.device) self.cos = ( self.rope[:, : self.rope.size(1) // 2] @@ -502,6 +508,9 @@ def forward( ], dim=-2, ) + if sliding_window > 0 and key.size(2) > sliding_window: + self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][:, :, 1:, :] + self.layer_cache[1]["values"] = self.layer_cache[1]["values"][:, :, 1:, :] context = self.flash_attn_with_kvcache( query.transpose(1, 2), self.layer_cache[1]["keys"].transpose(1, 2), @@ -561,7 +570,9 @@ def forward( seqlen = query.size(2) if seqlen > self.rope.size(0): self.rope = rotaryembeddings( - self.dim_per_head, 
maxseqlen=(seqlen + 2048) + self.dim_per_head, + maxseqlen=(seqlen + 2048), + base=self.rotary_theta, ).to(self.rope.device) rope = self.rope[start_pos : start_pos + seqlen].to(query.device) query, key = apply_rotary_emb( diff --git a/onmt/opts.py b/onmt/opts.py index 1fa3305fe0..c0237f377f 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -880,6 +880,13 @@ def model_opts(parser): "True = default Llama from Meta (original)" "False = used by all Hugging face models", ) + group.add( + "--rotary_theta", + "-rotary_theta", + type=int, + default=10000, + help="Rotary theta base length" "1e4 for Llama2.Mistral" "1e6 for Mixtral", + ) group.add( "--heads", "-heads", diff --git a/onmt/utils/distributed.py b/onmt/utils/distributed.py index e6779c397f..cb0f55c4f9 100644 --- a/onmt/utils/distributed.py +++ b/onmt/utils/distributed.py @@ -212,7 +212,7 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result): device_id=device_id, ) scores, preds = translator._translate( - infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug + infer_iter, infer_iter.transforms, opt.attn_debug, opt.align_debug ) queue_result.put(scores) queue_result.put(preds) diff --git a/tools/convert_HF_llamalike.py b/tools/convert_HF_llamalike.py index 8fe9d4efc6..32ace36aee 100755 --- a/tools/convert_HF_llamalike.py +++ b/tools/convert_HF_llamalike.py @@ -223,8 +223,14 @@ def __init__(self, model_path: str): norm_eps = config["layer_norm_epsilon"] else: norm_eps = 1e-6 + if "rope_theta" in config.keys(): + rope_theta = config["rope_theta"] + else: + rope_theta = 1e4 if "sliding_window" in config.keys(): sliding_window = config["sliding_window"] + if sliding_window is None: + sliding_window = 4096 else: sliding_window = 0 @@ -633,6 +639,7 @@ def get_weight(checkpoint, tensor_name): self_attn_type="scaled-dot", max_relative_positions=-1, rotary_interleave=False, + rotary_theta=rope_theta, heads=heads, sliding_window=sliding_window, transformer_ff=transformer_ff, From 17373ca2b6f10f537d960f987866bbc5f4622f67 Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 26 Dec 2023 10:41:50 +0100 Subject: [PATCH 8/8] black --- onmt/modules/multi_headed_attn.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 251e55814f..d71ffbc460 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -509,8 +509,12 @@ def forward( dim=-2, ) if sliding_window > 0 and key.size(2) > sliding_window: - self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][:, :, 1:, :] - self.layer_cache[1]["values"] = self.layer_cache[1]["values"][:, :, 1:, :] + self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][ + :, :, 1:, : + ] + self.layer_cache[1]["values"] = self.layer_cache[1]["values"][ + :, :, 1:, : + ] context = self.flash_attn_with_kvcache( query.transpose(1, 2), self.layer_cache[1]["keys"].transpose(1, 2),
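
For reference, below is a minimal standalone sketch of the decode-time path these patches converge on: a preallocated KV cache updated in place by flash-attn's flash_attn_with_kvcache, with the rotary cos/sin tables applied inside the kernel (patches 2 and 7). It is illustrative commentary, not part of any commit: it assumes flash-attn 2.x on a CUDA device with fp16 tensors, and all sizes and tensor names (batch, heads, head_dim, max_len) are hypothetical, not OpenNMT-py API. Only the call signature mirrors the hunks above.

    # Minimal sketch, assuming flash-attn 2.x, CUDA, fp16.
    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, heads, head_dim, max_len = 2, 8, 64, 256
    device = torch.device("cuda")

    # Preallocated cache in (batch, seqlen_cache, heads, head_dim) layout -- the same
    # layout the patches reach by calling .transpose(1, 2) on the
    # (batch, heads, seqlen, head_dim) tensors stored in layer_cache.
    k_cache = torch.zeros(batch, max_len, heads, head_dim,
                          dtype=torch.float16, device=device)
    v_cache = torch.zeros_like(k_cache)

    # Rotary tables: cos/sin of shape (max_len, head_dim // 2), built from a base
    # theta (1e4 for Llama2/Mistral, 1e6 for Mixtral, cf. the --rotary_theta option).
    theta = 10000.0
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(max_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)
    cos, sin = freqs.cos().half(), freqs.sin().half()

    step = 5  # number of tokens already written into the cache
    # Projections for the single new token decoded at this step (random stand-ins).
    q = torch.randn(batch, 1, heads, head_dim, dtype=torch.float16, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # The kernel rotates q/k at position `step`, writes k/v into the cache in place
    # at offset `step`, and attends over cache[:, : step + 1].
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v,
        rotary_cos=cos, rotary_sin=sin,
        cache_seqlens=step,
        rotary_interleaved=False,
    )
    # out: (batch, 1, heads, head_dim); transpose back to (batch, heads, 1, head_dim)
    # before the final linear projection, mirroring the .transpose(1, 2) in the patch.

In the actual multi_headed_attn.py hunks the cache is not sized to max_length up front (that first attempt is reverted in patch 2); instead it is grown by chunks of 32 positions along the sequence dimension whenever start_pos reaches the current cache size, and the slow non-flash path is kept for step 0, non-fp16 inputs, relative-position modes, and large batches.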