left padding for LM inference (#2525)
* left padding for LM inference
l-k-11235 authored Nov 24, 2023
1 parent efd316d commit 78c8908
Showing 5 changed files with 65 additions and 33 deletions.
19 changes: 15 additions & 4 deletions onmt/decoders/transformer.py
@@ -186,7 +186,8 @@ def _forward(self, *args, **kwargs):

def _compute_dec_mask(self, tgt_pad_mask, future):
tgt_len = tgt_pad_mask.size(-1)
if not future: # apply future_mask, result mask in (B, T, T)
if not future:
# Add triangular future_mask and pad_mask, result mask in (B, T, T).
future_mask = torch.ones(
[tgt_len, tgt_len],
device=tgt_pad_mask.device,
@@ -197,9 +198,14 @@ def _compute_dec_mask(self, tgt_pad_mask, future):
future_mask = future_mask.triu_(-self.sliding_window)
future_mask = future_mask.bool()
future_mask = ~future_mask.view(1, tgt_len, tgt_len)

# Patch for scaled dot product attention.
patch_mask = ~torch.all(
tgt_pad_mask + future_mask, dim=2, keepdim=True
).expand_as(tgt_pad_mask + future_mask)
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
else: # only mask padding, result mask in (B, 1, T)
dec_mask = torch.logical_and(dec_mask, patch_mask)
else:
# Only mask padding, result mask in (B, 1, T).
dec_mask = tgt_pad_mask
return dec_mask

@@ -717,7 +723,9 @@ def _forward(
dec_mask = None

if layer_in.size(1) > 1:
# masking is necessary when sequence length is greater than one
# Masking is necessary when sequence length is greater than one
# The decoding has not started yet,
# we compute the scores on the source tokens in one shot.
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
dec_mask = dec_mask.unsqueeze(1)
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
@@ -859,8 +867,11 @@ def detach_state(self):
def forward(self, tgt, enc_out=None, step=None, **kwargs):
"""Decode, possibly stepwise."""
if step == 0:
# decoding mode.
# Initialize KV cache.
self._init_cache(tgt)
elif step is None:
# training mode.
for layer in self.transformer_layers:
layer.self_attn.layer_cache = (
False,
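The `patch_mask` added to `_compute_dec_mask` above works around a property of `torch.nn.functional.scaled_dot_product_attention`: with left padding, the decoder-mask rows that correspond to pad positions can end up entirely masked, and a fully masked row makes the softmax produce NaN. The patch simply un-masks such rows, since their outputs belong to pad positions and are discarded anyway. Below is a minimal, self-contained sketch of that behaviour; batch size, sequence length and tensor names are illustrative and not taken from OpenNMT-py.

```python
import torch
import torch.nn.functional as F

B, T, D = 1, 4, 8
q, k, v = (torch.randn(B, T, D) for _ in range(3))

# Left-padded target: the first two positions are padding (True = masked).
tgt_pad_mask = torch.tensor([[[True, True, False, False]]])        # (B, 1, T)
future_mask = ~torch.ones(T, T).tril().bool().view(1, T, T)        # (1, T, T), True = future

dec_mask = tgt_pad_mask | future_mask                              # (B, T, T), True = masked
# Rows 0 and 1 are fully masked: a pad query only sees pad or future tokens.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=~dec_mask)  # mask: True = attend
print(torch.isnan(out[0, 0]).any())   # tensor(True): softmax over an all -inf row

# Patch: fully un-mask any row that would otherwise be entirely masked;
# those outputs correspond to pad positions and are thrown away later.
patch_mask = ~torch.all(dec_mask, dim=2, keepdim=True).expand_as(dec_mask)
dec_mask = torch.logical_and(dec_mask, patch_mask)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=~dec_mask)
print(torch.isnan(out).any())         # tensor(False)
```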
10 changes: 8 additions & 2 deletions onmt/inputters/dynamic_iterator.py
@@ -1,7 +1,7 @@
"""Module that contain iterator used for dynamic data."""
import torch
from itertools import cycle
from onmt.constants import CorpusTask
from onmt.constants import CorpusTask, ModelTask
from onmt.inputters.text_corpus import get_corpora, build_corpora_iters
from onmt.inputters.text_utils import (
text_sort_key,
@@ -164,6 +164,10 @@ def __init__(
self.skip_empty_level = skip_empty_level
self.random_shuffler = RandomShuffler()
self.bucket_idx = 0
if task != CorpusTask.TRAIN and vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
self.left_pad = True
else:
self.left_pad = False

@classmethod
def from_opt(
@@ -354,7 +358,9 @@ def __iter__(self):
# within the batch
if self.task == CorpusTask.TRAIN:
minibatch.sort(key=lambda x: self.sort_key(x[0]), reverse=True)
tensor_batch = tensorify(self.vocabs, minibatch, self.device)
tensor_batch = tensorify(
self.vocabs, minibatch, self.device, self.left_pad
)
yield (tensor_batch, bucket_idx)


43 changes: 31 additions & 12 deletions onmt/inputters/text_utils.py
@@ -168,7 +168,7 @@ def parse_align_idx(align_pharaoh):
return flatten_align_idx


def tensorify(vocabs, minibatch, device):
def tensorify(vocabs, minibatch, device, left_pad=False):
"""
This function transforms a batch of example in tensors
Each example looks like
@@ -193,21 +193,37 @@ def tensorify(vocabs, minibatch, device):
}
"""
tensor_batch = {}
tbatchsrc = [
torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
for ex, indice in minibatch
]
if left_pad:
tbatchsrc = [
torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device).flip(
dims=[0]
)
for ex, indice in minibatch
]
else:
tbatchsrc = [
torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
for ex, indice in minibatch
]
padidx = vocabs["src"][DefaultTokens.PAD]
tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx)
if "feats" in minibatch[0][0]["src"]:
tbatchfs = [tbatchsrc]
for feat_id in range(len(minibatch[0][0]["src"]["feats"])):
tbatchfeat = [
torch.tensor(
ex["src"]["feats"][feat_id], dtype=torch.long, device=device
)
for ex, indice in minibatch
]
if left_pad:
tbatchfeat = [
torch.tensor(
ex["src"]["feats"][feat_id], dtype=torch.long, device=device
).flip(dims=[0])
for ex, indice in minibatch
]
else:
tbatchfeat = [
torch.tensor(
ex["src"]["feats"][feat_id], dtype=torch.long, device=device
)
for ex, indice in minibatch
]
padidx = vocabs["src_feats"][feat_id][DefaultTokens.PAD]
tbatchfeat = pad_sequence(
tbatchfeat, batch_first=True, padding_value=padidx
@@ -218,7 +234,10 @@ def tensorify(vocabs, minibatch, device):
# Need to add features in last dimensions
tbatchsrc = tbatchsrc[:, :, None]

tensor_batch["src"] = tbatchsrc
if left_pad:
tensor_batch["src"] = tbatchsrc.flip(dims=[1])
else:
tensor_batch["src"] = tbatchsrc

tensor_batch["srclen"] = torch.tensor(
[len(ex["src"]["src_ids"]) for ex, indice in minibatch],
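`pad_sequence` only pads on the right, so `tensorify` obtains left padding by flipping each example, right-padding the flipped sequences, and flipping the padded batch back. For a decoder-only LM at inference time this keeps every prompt flush against the position where generation starts, instead of leaving pad tokens between the prompt and the first generated token. A minimal sketch of the trick (pad index and token ids are made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD = 0
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Right padding (training): pad_sequence appends PAD at the end.
right = pad_sequence(seqs, batch_first=True, padding_value=PAD)
# tensor([[5, 6, 7],
#         [8, 9, 0]])

# Left padding (LM inference): flip each sequence, right-pad, flip the batch back.
left = pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=PAD
).flip(dims=[1])
# tensor([[5, 6, 7],
#         [0, 8, 9]])
```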
5 changes: 4 additions & 1 deletion onmt/modules/multi_headed_attn.py
@@ -405,6 +405,7 @@ def forward(
# 1) Project key, value, and query.
# as a reminder at training layer_cache[0] remains False
if self.layer_cache[0]:
# Retrieve keys and values from the KV cache (decoding mode only).
if self.attn_type == "self":
query, key, value = (
self.linear_query(query),
@@ -451,6 +452,7 @@ def forward(
self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value
else:
# Retrieve keys and values from linear layers (training mode).
key = self.maybe_ckpt(self.linear_keys, key)
value = self.maybe_ckpt(self.linear_values, value)
query = self.maybe_ckpt(self.linear_query, query)
@@ -491,12 +493,12 @@ def forward(
self.flash2
and l > 256 # https://github.com/Dao-AILab/flash-attention/issues/591
)

if (
self.max_relative_positions in [-1, 0]
and not return_attn
and query.device != torch.device("cpu")
):
# Apply flash2 attention.
causal = self.is_decoder and self.attn_type == "self" and mask is not None
if self.is_decoder and self.attn_type == "self" and flash2:
if causal:
@@ -514,6 +516,7 @@ def forward(
window_size=window_size,
).transpose(1, 2)
else:
# Apply scaled dot product attention.
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_math=True, enable_mem_efficient=True
):
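The comments added here distinguish the decoding path, which reads keys and values from the per-layer cache populated by `_init_cache`, from the training path, which recomputes them through the linear projections (optionally under activation checkpointing). A minimal sketch of the KV-cache idea; this is illustrative and not the library's actual cache layout:

```python
import torch

class TinySelfAttnCache:
    """Illustrative key/value cache for incremental decoding: each step only
    projects the newest token and appends it to what was cached before."""

    def __init__(self):
        self.keys = None     # (batch, heads, cached_len, dim_per_head)
        self.values = None

    def update(self, key_step, value_step):
        # key_step, value_step: projections of the current step, (batch, heads, 1, dim).
        if self.keys is None:
            self.keys, self.values = key_step, value_step
        else:
            self.keys = torch.cat([self.keys, key_step], dim=2)
            self.values = torch.cat([self.values, value_step], dim=2)
        return self.keys, self.values

cache = TinySelfAttnCache()
for _ in range(3):  # three decoding steps
    k, v = cache.update(torch.randn(2, 8, 1, 64), torch.randn(2, 8, 1, 64))
print(k.shape)      # torch.Size([2, 8, 3, 64])
```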
21 changes: 7 additions & 14 deletions onmt/translate/translator.py
@@ -658,7 +658,6 @@ def _decode_and_generate(
step=step,
return_attn=self.global_scorer.has_cov_pen or return_attn,
)

# Generator forward.
if not self.copy_attn:
if "std" in dec_attn:
@@ -988,16 +987,6 @@ def _align_forward(self, batch, predictions):

def translate_batch(self, batch, attn_debug):
"""Translate a batch of sentences."""
batch_size = len(batch["srclen"])
if batch_size != 1:
warning_msg = (
"GeneratorLM does not support batch_size != 1"
" nicely. You can remove this limitation here."
" With batch_size > 1 the end of each input is"
" repeated until the input is finished. Then"
" generation will start."
)
self._log(warning_msg)
with torch.no_grad():
if self.sample_from_topk != 0 or self.sample_from_topp != 0:
decode_strategy = GreedySearchLM(
@@ -1061,7 +1050,7 @@ def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
log_probs = log_probs[:, -1, :]
return log_probs

def _translate_batch_with_strategy(self, batch, decode_strategy):
def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
"""Translate a batch of sentences step by step using cache.
Args:
@@ -1081,7 +1070,12 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
src = batch["src"]
src_len = batch["srclen"]

src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len)
if left_pad:
target_prefix = None
else:
src, src_len, target_prefix = self.split_src_to_prevent_padding(
src, src_len
)

# (2) init decoder
self.model.decoder.init_state(src, None, None)
@@ -1109,7 +1103,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
decoder_input = (
src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1)
)

log_probs, attn = self._decode_and_generate(
decoder_input,
None,
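With left-padded batches every prompt ends at the same position, so the decoder can consume the whole padded prompt at step 0 and each later step feeds only the last prediction per example; the previous workaround of truncating the batch via `split_src_to_prevent_padding` and replaying the remainder as a target prefix is kept only for the non-left-padded case, which is why the old batch_size warning is removed. A schematic sketch of the step-0 versus step-N decoder inputs (token ids are invented, shapes follow the `decoder_input` line in the diff):

```python
import torch

PAD = 0
# Two left-padded prompts of different lengths, shaped (batch, seq, feats).
src = torch.tensor([[11, 12, 13, 14],
                    [PAD, PAD, 21, 22]]).unsqueeze(-1)

step0_input = src                                  # full padded prompt in one shot
current_predictions = torch.tensor([31, 41])       # tokens produced at the previous step
stepN_input = current_predictions.view(-1, 1, 1)   # one token per example afterwards
```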
