Updating the perplexity script, adding the log_softmax logic
ayissi-msft committed Oct 24, 2024
1 parent 4aafc2b commit 96a5371
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions tools/python/model_validation/perplexity_metrics.py
@@ -13,33 +13,32 @@ def perplexity_eval(model_dir):

     dataset = get_wikitext2()

     total_log_probs = 0
     total_token_count = 0

     for batch in dataset:
         text = batch["text"]

         input_ids = tokenizer.encode_batch([text])

         params = og.GeneratorParams(model)
         params.set_model_input("input_ids", input_ids)
         generator = og.Generator(model, params)

-        logits = generator.compute_logits()
+        logits = generator.compute_logits("logits")

         targets = np.roll(input_ids, -1, axis=1)

-        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
+        # Use LogSoftMax here
+        log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=1).numpy()

         batch_size, seq_length = targets.shape
-        target_probs = probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]
-
-        log_probs = np.log(target_probs)
-
-        total_log_probs += np.sum(log_probs)
+        target_log_probs = log_probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]

+        total_log_probs += np.sum(target_log_probs)
         total_token_count += targets.size

     avg_log_prob = total_log_probs / total_token_count

     perplexity = np.exp(-avg_log_prob)

     print(f"The perplexity of {model_dir} is {perplexity}")
-    return perplexity
\ No newline at end of file
+    return perplexity
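Why the change matters: taking softmax first and log afterwards can underflow to zero for large-magnitude logits, turning the subsequent log into -inf, while log_softmax computes the same quantity through the log-sum-exp trick and stays finite. A minimal sketch of the difference (not part of the commit, plain PyTorch):

    import torch

    logits = torch.tensor([[1000.0, 0.0, -1000.0]])

    # softmax underflows the small entries to 0, so log() produces -inf
    naive = torch.log(torch.softmax(logits, dim=-1))          # tensor([[0., -inf, -inf]])

    # log_softmax evaluates logits - logsumexp(logits) directly and stays finite
    stable = torch.nn.functional.log_softmax(logits, dim=-1)  # tensor([[0., -1000., -2000.]])

    print(naive)
    print(stable)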

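The fancy-indexing line gathers, for each batch row b and position t, the entry log_probs[b, t, targets[b, t]], i.e. the log-probability the model assigned to the token that actually came next; perplexity is then the exponential of the negative mean of those values. A toy NumPy sketch with made-up shapes and numbers (hypothetical, for illustration only):

    import numpy as np

    batch_size, seq_length, vocab_size = 1, 3, 4

    # A "model" that assigns uniform probability 1/4 to every token
    log_probs = np.log(np.full((batch_size, seq_length, vocab_size), 0.25))
    targets = np.array([[2, 0, 3]])  # next-token ids, e.g. np.roll(input_ids, -1, axis=1)

    # Select log_probs[b, t, targets[b, t]] for every (b, t) pair
    target_log_probs = log_probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]

    avg_log_prob = np.sum(target_log_probs) / targets.size
    perplexity = np.exp(-avg_log_prob)
    print(perplexity)  # 4.0 -- a uniform model over 4 tokens has perplexity 4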