Skip to content

Commit

Permalink
rolling ppl with sliding window (#2553)
Browse files Browse the repository at this point in the history
* rolling ppl with window size 4096 and stride 512
  • Loading branch information
l-k-11235 authored Jan 23, 2024
1 parent b67e492 commit 1c27987
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 135 deletions.
146 changes: 44 additions & 102 deletions eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,141 +1,83 @@
import copy
import json
import numpy as np
import os
import pyonmttok
import time
from onmt.constants import CorpusTask, DefaultTokens
from onmt.inference_engine import InferenceEnginePY
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
import onmt.opts as opts
from onmt.utils.logging import init_logger
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from onmt.transforms import get_transforms_cls


def compute_file_ppl(output_filename):
with open(output_filename, "r") as f:
run_results = json.load(f)
nlls = []
lengths = []
for i, _res in enumerate(run_results["scored_results"]):
print(_res)
nlls.append(_res[0])
lengths.append(_res[1])
file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
print("wikitext-2 ppl: %.4f" % file_ppl)


def tokenize_dataset(opt, context_length):
print("Tokenization...")

# Prepare the dataset
# Clean and Concat the dataset
x = open(opt.src, "r").readlines()
x = [_x.rstrip("\n") for _x in x]
y = DefaultTokens.SEP.join(x)

with open(opt.src + ".temp", "w") as writer:
writer.write(y)

# ########################## #
# Build the dataset iterator #
# ########################## #

# Build the vocab
vocab_path_in = "/nas-labs/LM/big_llms/llama/7B/llama.vocab"
voc = []
with open(vocab_path_in, "r", encoding="utf-8") as reader:
for line in reader:
line = line.strip("\n")
voc.append(line)
vocabs = {}
src_vocab = pyonmttok.build_vocab_from_tokens(voc)
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["data_task"] = "lm"
vocabs["decoder_start_token"] = "<s>"

transforms_cls = get_transforms_cls(opt._all_transform)

new_opt = opt
new_opt.gpu = -1
new_opt.parallel_mode = "data_parallel"
new_opt.src = opt.src + ".temp"

dataset_iter = build_dynamic_dataset_iter(
new_opt, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=-1
)

input_tokens = []
for batch, i in dataset_iter:
for i in range(batch["src"].size()[0]):
start_ids = batch["src"][i, :, 0].cpu().numpy().tolist()
input_tokens += [
vocabs["src"].lookup_index(id)
for id in start_ids
if id != vocabs["src"].lookup_token(DefaultTokens.PAD)
]

def make_chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]

# #################### #
# Tokenize the dataset #
# ################### #
with open(opt.src + f".tokenized.context_{context_length}", "w") as writer:
for _chunk in make_chunks(input_tokens, context_length - 1):
writer.write(" ".join(_chunk) + "\n")
print(len(_chunk))
xx = [_x for _x in x if _x != " \n"]
from onmt.transforms.tokenize import SentencePieceTransform

tokenizer = SentencePieceTransform(opt)
tokenizer.warm_up()
tokens = tokenizer._tokenize(xx)
print("Done !")

z = open(opt.src + f".tokenized.context_{context_length}", "r").readlines()
print(len(z[0].split(" ")))
return tokens


def evaluate(opt):
"""Score the wikitext2 testset"""
"""Score the wikitext2 testset
The perplexity of the file is calculated with a window size of max_seq_length = 4096 tokens.
At each step, the window shifts by 512 tokens, and its first max_seq_length - stride
tokens are considered as context tokens. This means that their logits are not
taken into account, allowing this rolling perplexity to be calculated without overlap."""

ArgumentParser.validate_translate_opts(opt)
ArgumentParser._get_all_transform_translate(opt)
ArgumentParser._validate_transforms_opts(opt)
ArgumentParser.validate_translate_opts_dynamic(opt)
logger = init_logger(opt.log_file)
set_random_seed(opt.seed, use_gpu(opt))

run_results = {}
dir_name = os.path.dirname(opt.models[0])
base_name = os.path.basename(opt.models[0])

output_filename = os.path.join(
dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3]
)
# Tokenize the dataset.
opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
tokens = tokenize_dataset(opt, context_length=512)

# Build the translator (along with the model.
engine_opt = copy.copy(opt)
engine_opt._all_transform = []
engine = InferenceEnginePY(engine_opt)

# Tokenize the dataset.
opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
tokenize_dataset(opt, context_length=512)

# Score the tokeznized dataset
engine.opt.src = opt.src + f".tokenized.context_{512}"
start_time = time.time()
scored_results = engine.score_file()
engine.terminate()
run_results["scored_results"] = scored_results
# Score the dataset.
stride = 512
max_seq_length = 4096

with open(output_filename, "w") as f:
json.dump(run_results, f, ensure_ascii=False, indent=2)
seq_len = len(tokens)
src = []
for begin_loc in range(0, seq_len, stride):
end_loc = min(begin_loc + max_seq_length, seq_len)
src.append(" ".join(tokens[begin_loc:end_loc]))

compute_file_ppl(output_filename)
start_time = time.time()
engine.translator.return_gold_log_probs = True
score_results = engine.score_list(src=src)
nlls = []
lengths = []
for _, log_probs, _ in score_results:
lengths.append(stride)
# zero out the context tokens
nlls += [
log_probs[i][0]
for i, _ in enumerate(log_probs)
if i > (max_seq_length - stride)
]
ppl = np.exp(-np.sum(nlls) / np.sum(lengths))

engine.terminate()
end_time = time.time()
logger.info("total run time %.2f" % (end_time - start_time))
logger.info(
"wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" # noqa: E501
% (ppl)
)


def _get_parser():
Expand Down
1 change: 1 addition & 0 deletions onmt/inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def _translate(self, infer_iter):

def _score(self, infer_iter):
self.translator.with_scores = True
self.return_gold_log_probs = True
return self.translator._score(infer_iter)

def score_list_parallel(self, src):
Expand Down
2 changes: 1 addition & 1 deletion onmt/tests/pull_request_chk.sh
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model_lm.pt \
-ban_unk_token \
-length_penalty none \
-out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1
diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(python -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling
diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(${PYTHON} -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}
rm $TMP_OUT_DIR/gen_sampling
Expand Down
73 changes: 41 additions & 32 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
logger=None,
seed=-1,
with_score=False,
return_gold_log_probs=False,
):
self.model = model
self.vocabs = vocabs
Expand Down Expand Up @@ -205,6 +206,8 @@ def __init__(
set_random_seed(seed, self._use_cuda)
self.with_score = with_score

self.return_gold_log_probs = return_gold_log_probs

@classmethod
def from_opt(
cls,
Expand Down Expand Up @@ -280,26 +283,17 @@ def _log(self, msg):
print(msg)

def _gold_score(
self,
batch,
enc_out,
src_len,
use_src_map,
enc_final_hs,
batch_size,
src,
self, batch, enc_out, src_len, use_src_map, enc_final_hs, batch_size, src
):
if "tgt" in batch.keys() and not self.tgt_file_prefix:
gs = self._score_target(
batch,
enc_out,
src_len,
batch["src_map"] if use_src_map else None,
gs, glp = self._score_target(
batch, enc_out, src_len, batch["src_map"] if use_src_map else None
)
self.model.decoder.init_state(src, enc_out, enc_final_hs)
else:
gs = [0] * batch_size
return gs
glp = None
return gs, glp

def _translate(
self,
Expand Down Expand Up @@ -584,12 +578,25 @@ def _score(self, infer_iter):
self.with_scores = True
scored_bucket = {}
for batch, bucket_idx in infer_iter:
batch_data = self.translate_batch(batch, attn_debug=False)
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
if self.return_gold_log_probs:
batch_gold_log_probs = (
batch_data["gold_log_probs"].cpu().numpy().tolist()
)
else:
batch_gold_log_probs = None
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
batch_inds_in_bucket = batch["ind_in_bucket"]
for i, _score in enumerate(batch_gold_scores):
scored_bucket[batch_inds_in_bucket[i]] = (_score, batch_tgt_lengths[i])
log_probs = (
batch_gold_log_probs[i] if self.return_gold_log_probs else None
)
scored_bucket[batch_inds_in_bucket[i]] = (
_score,
log_probs,
batch_tgt_lengths[i],
)
score_results = [scored_bucket[i] for i in range(len(scored_bucket))]
return score_results

Expand Down Expand Up @@ -720,6 +727,7 @@ def _score_target(self, batch, enc_out, src_len, src_map):
def report_results(
self,
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
Expand All @@ -730,6 +738,7 @@ def report_results(
"attention": None,
"batch": batch,
"gold_score": gold_score,
"gold_log_probs": gold_log_probs,
}

results["scores"] = decode_strategy.scores
Expand Down Expand Up @@ -900,7 +909,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):

self.model.decoder.init_state(src, enc_out, enc_final_hs)

gold_score = self._gold_score(
gold_score, gold_log_probs = self._gold_score(
batch,
enc_out,
src_len,
Expand Down Expand Up @@ -961,6 +970,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):

return self.report_results(
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
Expand All @@ -982,7 +992,7 @@ def _score_target(self, batch, enc_out, src_len, src_map):
gold = tgt[:, 1:, :]
gold_scores = log_probs.gather(2, gold)
gold_scores = gold_scores.sum(dim=1).view(-1)
return gold_scores
return gold_scores, None


class GeneratorLM(Inference):
Expand All @@ -1001,8 +1011,9 @@ def _align_forward(self, batch, predictions):
"""
raise NotImplementedError

def translate_batch(self, batch, attn_debug):
def translate_batch(self, batch, attn_debug, scoring=False):
"""Translate a batch of sentences."""
max_length = 0 if scoring else self.max_length
with torch.no_grad():
if self.sample_from_topk != 0 or self.sample_from_topp != 0:
decode_strategy = GreedySearchLM(
Expand All @@ -1015,7 +1026,7 @@ def translate_batch(self, batch, attn_debug):
batch_size=len(batch["srclen"]),
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=self.max_length,
max_length=max_length,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
return_attention=attn_debug or self.replace_unk,
Expand All @@ -1039,7 +1050,7 @@ def translate_batch(self, batch, attn_debug):
n_best=self.n_best,
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=self.max_length,
max_length=max_length,
return_attention=attn_debug or self.replace_unk,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
Expand Down Expand Up @@ -1095,14 +1106,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):

# (2) init decoder
self.model.decoder.init_state(src, None, None)
gold_score = self._gold_score(
batch,
None,
src_len,
use_src_map,
None,
batch_size,
src,
gold_score, gold_log_probs = self._gold_score(
batch, None, src_len, use_src_map, None, batch_size, src
)

# (3) prep decode_strategy. Possibly repeat src objects.
Expand Down Expand Up @@ -1158,6 +1163,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):

return self.report_results(
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
Expand All @@ -1177,7 +1183,10 @@ def _score_target(self, batch, enc_out, src_len, src_map):
)

log_probs[:, :, self._tgt_pad_idx] = 0
gold_scores = log_probs.gather(2, tgt)
gold_scores = gold_scores.sum(dim=1).view(-1)
gold_log_probs = log_probs.gather(2, tgt)
gold_scores = gold_log_probs.sum(dim=1).view(-1)

if self.return_gold_log_probs:
return gold_scores, gold_log_probs

return gold_scores
return gold_scores, None

0 comments on commit 1c27987

Please sign in to comment.