From 96a53710fa698ed5720b2a87eed8a65eb3deef8b Mon Sep 17 00:00:00 2001
From: ayissi-msft
Date: Thu, 24 Oct 2024 15:16:02 -0700
Subject: [PATCH] Updating the perplexity script, adding the logsoftmax logic

---
 .../model_validation/perplexity_metrics.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/tools/python/model_validation/perplexity_metrics.py b/tools/python/model_validation/perplexity_metrics.py
index 707515d90..bbb0dd7ff 100644
--- a/tools/python/model_validation/perplexity_metrics.py
+++ b/tools/python/model_validation/perplexity_metrics.py
@@ -13,33 +13,32 @@ def perplexity_eval(model_dir):
 
     dataset = get_wikitext2()
 
+    total_log_probs = 0
+    total_token_count = 0
+
     for batch in dataset:
         text = batch["text"]
         input_ids = tokenizer.encode_batch([text])
 
         params = og.GeneratorParams(model)
-        params.set_model_input("input_ids", input_ids)
 
         generator = og.Generator(model, params)
 
-        logits = generator.compute_logits()
+        logits = generator.compute_logits("logits")
 
         targets = np.roll(input_ids, -1, axis=1)
 
-        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
+        # Use LogSoftMax here
+        log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=1).numpy()
 
         batch_size, seq_length = targets.shape
-        target_probs = probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]
-
-        log_probs = np.log(target_probs)
-
-        total_log_probs += np.sum(log_probs)
+        target_log_probs = log_probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]
+        total_log_probs += np.sum(target_log_probs)
 
         total_token_count += targets.size
 
     avg_log_prob = total_log_probs / total_token_count
-
     perplexity = np.exp(-avg_log_prob)
 
     print(f"The perplexity of {model_dir} is {perplexity}")
-    return perplexity
+    return perplexity
\ No newline at end of file
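
Note: as a sanity check on the math this patch implements, below is a minimal, self-contained sketch of the same reduction (log_softmax over the logits, gather each target token's log-probability, average, exponentiate the negative). It feeds random placeholder logits instead of the onnxruntime-genai generator output, so perplexity_from_logits and the dummy shapes are illustrative assumptions, not part of the patched script. The sketch normalizes over the last axis (the vocabulary) with dim=-1; whether dim=1 in the patch addresses the same axis depends on the shape returned by compute_logits.

    # Minimal sketch of the perplexity reduction, with placeholder logits.
    import numpy as np
    import torch

    def perplexity_from_logits(logits: np.ndarray, input_ids: np.ndarray) -> float:
        # Next-token targets: shift the input ids left by one position.
        targets = np.roll(input_ids, -1, axis=1)

        # log_softmax computes log(softmax(x)) in one pass (log-sum-exp), which
        # avoids the underflow that softmax followed by np.log can hit when a
        # target token is assigned a very small probability.
        log_probs = torch.log_softmax(torch.tensor(logits), dim=-1).numpy()

        # Pick out the log-probability assigned to each target token.
        batch_size, seq_length = targets.shape
        target_log_probs = log_probs[np.arange(batch_size)[:, None],
                                     np.arange(seq_length),
                                     targets]

        # Perplexity = exp(-mean log-probability of the targets).
        avg_log_prob = target_log_probs.sum() / targets.size
        return float(np.exp(-avg_log_prob))

    if __name__ == "__main__":
        rng = np.random.default_rng(0)
        batch, seq, vocab = 1, 8, 100
        dummy_logits = rng.standard_normal((batch, seq, vocab)).astype(np.float32)
        dummy_ids = rng.integers(0, vocab, size=(batch, seq))
        print(perplexity_from_logits(dummy_logits, dummy_ids))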