diff --git a/tools/python/model_validation/perplexity_metrics.py b/tools/python/model_validation/perplexity_metrics.py
new file mode 100644
index 000000000..bbb0dd7ff
--- /dev/null
+++ b/tools/python/model_validation/perplexity_metrics.py
@@ -0,0 +1,44 @@
+from datasets import load_dataset
+import numpy as np
+import onnxruntime_genai as og
+import torch
+
+def get_wikitext2():
+    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+    return testdata
+
+def perplexity_eval(model_dir):
+    model = og.Model(f'{model_dir}')
+    tokenizer = og.Tokenizer(model)
+
+    dataset = get_wikitext2()
+
+    total_log_probs = 0
+    total_token_count = 0
+
+    for batch in dataset:
+        text = batch["text"]
+
+        input_ids = tokenizer.encode_batch([text])
+
+        params = og.GeneratorParams(model)
+        generator = og.Generator(model, params)
+
+        logits = generator.compute_logits("logits")
+
+        targets = np.roll(input_ids, -1, axis=1)
+
+        # Log-softmax over the vocabulary (last) dimension gives per-token log-probabilities
+        log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=-1).numpy()
+
+        batch_size, seq_length = targets.shape
+        target_log_probs = log_probs[np.arange(batch_size)[:, None], np.arange(seq_length), targets]
+
+        total_log_probs += np.sum(target_log_probs)
+        total_token_count += targets.size
+
+    avg_log_prob = total_log_probs / total_token_count
+    perplexity = np.exp(-avg_log_prob)
+
+    print(f"The perplexity of {model_dir} is {perplexity}")
+    return perplexity
\ No newline at end of file
diff --git a/tools/python/model_validation/validation_config.json b/tools/python/model_validation/validation_config.json
index c7259481d..9e06f7114 100644
--- a/tools/python/model_validation/validation_config.json
+++ b/tools/python/model_validation/validation_config.json
@@ -2,23 +2,29 @@
     "models": [
         {
             "name": "Qwen/Qwen2.5-7B-Instruct",
-            "chat_template": "<|im_start|>\n <|user|> \n {input} <|im_end>\n'<|im_start|>assistant\n"
+            "chat_template": "<|im_start|>\n <|user|> \n {input} <|im_end>\n'<|im_start|>assistant\n",
+            "metrics": []
         },
         {
             "name": "meta-llama/Llama-2-7b-chat-hf",
-            "chat_template": "[INST]<>\n{input}<>[INST]"
+            "chat_template": "[INST]<>\n{input}<>[INST]",
+            "metrics": []
+
         },
         {
             "name": "mistralai/Mistral-7B-Instruct-v0.3",
-            "chat_template": " \"[INST] \" + {input} + \"[/INST]\""
+            "chat_template": " \"[INST] \" + {input} + \"[/INST]\"",
+            "metrics": []
         },
         {
             "name": "microsoft/Phi-3.5-mini-instruct",
-            "chat_template": "<|user|>\n{input} <|end|>\n<|assistant|>"
+            "chat_template": "<|user|>\n{input} <|end|>\n<|assistant|>",
+            "metrics": []
         },
         {
             "name": "google/gemma-2-2b-it",
-            "chat_template": "'' + <|user|> '\n' + {input} + '\n"
+            "chat_template": "'' + <|user|> '\n' + {input} + '\n",
+            "metrics": []
         }
     ],
     "inputs": [
diff --git a/tools/python/model_validation/validation_tool.py b/tools/python/model_validation/validation_tool.py
index fad36aa36..a798b5c18 100644
--- a/tools/python/model_validation/validation_tool.py
+++ b/tools/python/model_validation/validation_tool.py
@@ -4,6 +4,7 @@
 import json
 import os
 import pandas as pd
+from perplexity_metrics import perplexity_eval
 
 def create_table(output):
     df = pd.DataFrame(output, columns=['Model Name', 'Validation Completed', 'Exceptions / Failures'])
@@ -113,6 +114,9 @@ def validate_model(args, model_dict, model_dir):
         print(f'Failure after validation model {e}')
         exception = True
         output.append([model_dict["name"], validation_complete, e])
+
+    # TODO: confirm perplexity_eval should run here, outside the try/except, for every model
+    perplexity_eval(output_path)
 
     if not exception:
         output.append([model_dict["name"], validation_complete, e])
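
For reference, below is a minimal, self-contained sketch of the arithmetic that `perplexity_eval` performs: log-softmax over the vocabulary axis, gathering the log-probability of each next token, averaging, and exponentiating the negative mean. It runs on synthetic data so it has no dependency on `onnxruntime_genai`; the function name `perplexity_from_logits`, the random logits, and the slice-based target shift are illustrative assumptions, not part of the tool.

```python
# Sketch of the perplexity computation on synthetic inputs.
# Assumptions: logits has shape (batch, seq_len, vocab) and input_ids has
# shape (batch, seq_len); these mirror the intent of perplexity_metrics.py.
import numpy as np
import torch


def perplexity_from_logits(logits: np.ndarray, input_ids: np.ndarray) -> float:
    # The target for position t is the token at position t + 1, so drop the
    # final position instead of wrapping it around (as np.roll would).
    targets = input_ids[:, 1:]                          # (batch, seq_len - 1)
    shifted_logits = torch.tensor(logits[:, :-1, :])    # predictions for those targets

    # Log-softmax over the vocabulary dimension (the last axis).
    log_probs = torch.nn.functional.log_softmax(shifted_logits, dim=-1).numpy()

    batch_size, seq_len = targets.shape
    target_log_probs = log_probs[
        np.arange(batch_size)[:, None], np.arange(seq_len), targets
    ]                                                   # (batch, seq_len - 1)

    avg_log_prob = target_log_probs.sum() / targets.size
    return float(np.exp(-avg_log_prob))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vocab_size = 50
    ids = rng.integers(0, vocab_size, size=(1, 16))
    fake_logits = rng.normal(size=(1, 16, vocab_size)).astype(np.float32)
    # Uninformative random logits give a perplexity on the order of vocab_size.
    print(perplexity_from_logits(fake_logits, ids))
```

Slicing the targets (`input_ids[:, 1:]`) rather than rolling them also avoids scoring the wrapped-around first token as the target of the last position, which the `np.roll` approach in the diff quietly folds into the average.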