diff --git a/tools/python/model_validation/perplexity_metrics.py b/tools/python/model_validation/perplexity_metrics.py
index 707515d90..bbb0dd7ff 100644
--- a/tools/python/model_validation/perplexity_metrics.py
+++ b/tools/python/model_validation/perplexity_metrics.py
@@ -13,33 +13,32 @@ def perplexity_eval(model_dir):
     dataset = get_wikitext2()
 
+    total_log_probs = 0
+    total_token_count = 0
+
     for batch in dataset:
         text = batch["text"]
         input_ids = tokenizer.encode_batch([text])
 
         params = og.GeneratorParams(model)
-        params.set_model_input("input_ids", input_ids)
 
         generator = og.Generator(model, params)
-        logits = generator.compute_logits()
+        logits = generator.compute_logits("logits")
 
         targets = np.roll(input_ids, -1, axis=1)
 
-        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
+        # Use LogSoftMax here
+        log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=1).numpy()
 
         batch_size, seq_length = targets.shape
-        target_probs = probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]
-
-        log_probs = np.log(target_probs)
-
-        total_log_probs += np.sum(log_probs)
+        target_log_probs = log_probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]
+        total_log_probs += np.sum(target_log_probs)
 
         total_token_count += targets.size
 
     avg_log_prob = total_log_probs / total_token_count
-
     perplexity = np.exp(-avg_log_prob)
 
     print(f"The perplexity of {model_dir} is {perplexity}")
-    return perplexity
+    return perplexity
\ No newline at end of file
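
For reference, below is a minimal standalone sketch of the perplexity math this change converges on: take log-softmax over the logits, gather the log-probability assigned to each target token, then exponentiate the negative mean. It is illustrative only and not part of the patch; the helper name perplexity_from_logits and the assumption that logits has shape (batch_size, seq_length, vocab_size) are mine, and the sketch bypasses the og.Generator API entirely.

import numpy as np
import torch

def perplexity_from_logits(logits: np.ndarray, input_ids: np.ndarray) -> float:
    # Targets are the inputs shifted left by one position: token t predicts token t+1.
    # (As with np.roll in the patch, the final position's target wraps to the first token.)
    targets = np.roll(input_ids, -1, axis=1)

    # log_softmax over the vocabulary axis is more numerically stable than
    # softmax followed by np.log, which is what the patch replaces.
    log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=-1).numpy()

    batch_size, seq_length = targets.shape
    # Gather the log-probability of each target token: result has shape (batch_size, seq_length).
    target_log_probs = log_probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]

    # Perplexity = exp(-mean per-token log-probability).
    avg_log_prob = target_log_probs.sum() / targets.size
    return float(np.exp(-avg_log_prob))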