From 63db5d656574f7bc56c221ed6a9d0d336dfa94ad Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Mon, 9 Sep 2024 09:25:31 -0400
Subject: [PATCH 1/3] stopped using compute_metrics to calculate
 roc_auc/pr_auc/accuracy for the evaluation step, because for large eval
 datasets the evaluation step can run out of CPU memory as it keeps all
 predictions on the CPU

---
 .../runners/hf_cehrbert_finetune_runner.py | 86 ++++++++++++-------
 1 file changed, 55 insertions(+), 31 deletions(-)

diff --git a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
index 87499acd..467c6bd4 100644
--- a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -1,13 +1,13 @@
 import json
 import os
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
 from datasets import DatasetDict, load_from_disk
 from peft import LoraConfig, get_peft_model
 from scipy.special import expit as sigmoid
-from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_auc_score
+from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
 from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.utils import logging
 
@@ -33,38 +33,59 @@
 LOG = logging.get_logger("transformers")
 
 
-def compute_metrics(eval_pred):
-    outputs, labels = eval_pred
-    logits = outputs[0]
+def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
+    """
+    Compute a set of evaluation metrics including accuracy, ROC-AUC, and PR-AUC.
 
-    # Convert logits to probabilities using sigmoid
-    probabilities = sigmoid(logits)
-
-    if probabilities.shape[1] == 2:
-        positive_probs = probabilities[:, 1]
-    else:
-        positive_probs = probabilities.squeeze()  # Ensure it's a 1D array
-
-    # Calculate predictions based on probability threshold of 0.5
-    predictions = (positive_probs > 0.5).astype(np.int32)
+    Args:
+        references (list or array-like): Ground truth (correct) labels for each sample.
+        logits (list or array-like): Predicted scores for each sample, typically the model's output.
 
-    # Calculate accuracy
-    accuracy = accuracy_score(labels, predictions)
+    Returns:
+        Dict[str, Any]: A dictionary containing the computed metrics where keys represent the metric names
+        (e.g., 'accuracy', 'roc_auc', 'pr_auc') and values are the corresponding metric values.
 
+    This function uses the `evaluate` library to compute the following metrics:
+    - Accuracy: The proportion of correct predictions.
+    - ROC-AUC: The area under the receiver operating characteristic curve.
+    """
+    # Convert logits to probabilities using sigmoid
+    probabilities = sigmoid(logits)
+    # # Calculate PR-AUC
     # Calculate ROC-AUC
-    roc_auc = roc_auc_score(labels, positive_probs)
-
-    # Calculate PR-AUC
-    precision, recall, _ = precision_recall_curve(labels, positive_probs)
+    roc_auc = roc_auc_score(references, probabilities)
+    precision, recall, _ = precision_recall_curve(references, probabilities)
     pr_auc = auc(recall, precision)
-
-    return {"accuracy": accuracy, "roc_auc": roc_auc, "pr_auc": pr_auc}
+    return {"roc_auc": roc_auc, "pr_auc": pr_auc}
 
 
 def load_pretrained_model_and_tokenizer(
     model_args,
 ) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
-    # Try to load the pretrained tokenizer
+    """
+    Loads a pretrained model and tokenizer based on the given model arguments.
+
+    Args:
+        model_args (Namespace): An argument object containing the following fields:
+            - tokenizer_name_or_path (str): The path or name of the pretrained tokenizer to load.
+            - model_name_or_path (str): The path or name of the pretrained model to load.
+            - finetune_model_type (str): The type of fine-tuning model to use. Must be one of the values in `FineTuneModelType`.
+
+    Returns:
+        Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
+            - CehrBertPreTrainedModel: The loaded pretrained model (either a classification or LSTM model).
+            - CehrBertTokenizer: The loaded pretrained tokenizer.
+
+    Raises:
+        ValueError: If the tokenizer cannot be loaded from the specified path, or if the fine-tuning model type is invalid.
+
+    Notes:
+        - If loading the model fails, the function will attempt to create a new model using the provided model arguments
+          and the tokenizer's configuration.
+        - The function supports two types of models for fine-tuning:
+            - `CehrBertForClassification` for pooling-based models.
+            - `CehrBertLstmForClassification` for LSTM-based models.
+    """
     try:
         tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     except Exception:
@@ -284,7 +305,6 @@ def assign_split(example):
         data_collator=collator,
         train_dataset=processed_dataset["train"],
         eval_dataset=processed_dataset["validation"],
-        compute_metrics=compute_metrics,
         callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
         args=training_args,
     )
@@ -297,6 +317,8 @@ def assign_split(example):
 
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
         trainer.save_state()
 
     if training_args.do_predict:
@@ -309,12 +331,6 @@ def assign_split(example):
             trainer._load_from_checkpoint(training_args.output_dir)
 
         test_results = trainer.predict(processed_dataset["test"])
-        # Save results to JSON
-        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
-        with open(test_results_path, "w") as f:
-            json.dump(test_results.metrics, f, indent=4)
-
-        LOG.info(f"Test results: {test_results.metrics}")
 
         person_ids = [row["person_id"] for row in processed_dataset["test"]]
 
@@ -330,6 +346,14 @@ def assign_split(example):
         prediction_pd = pd.DataFrame({"person_id ": person_ids, "prediction": predictions, "label": labels})
         prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False)
 
+        # Save results to JSON
+        metrics = compute_metrics(references=labels, logits=predictions)
+        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
+        with open(test_results_path, "w") as f:
+            json.dump(metrics, f, indent=4)
+
+        LOG.info(f"Test results: {metrics}")
+
 
 if __name__ == "__main__":
     main()

From 58b761c9fcf25807b5835fd456ccf80ea64627ee Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Mon, 9 Sep 2024 09:27:29 -0400
Subject: [PATCH 2/3] added backward compatibility for the case where the
 fine-tuning data is constructed from OMOP data and age_at_index is labelled
 as age

---
 .../data_generators/hf_data_generator/hf_dataset_mapping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
index 2a413bb6..576817e8 100644
--- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
+++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
@@ -453,6 +453,6 @@ class HFFineTuningMapping(DatasetMapping):
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "age_at_index": record["age_at_index"],
+            "age_at_index": record["age"] if "age" in record else record["age_at_index"],
             "classifier_label": record["label"],
         }
 

From 1b1ccb8eb25afc36a61ffc0edba964befe5c0e77 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Mon, 9 Sep 2024 09:30:30 -0400
Subject: [PATCH 3/3] updated the docstring of compute_metrics

---
 .../runners/hf_cehrbert_finetune_runner.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
index 467c6bd4..8746c324 100644
--- a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -35,19 +35,20 @@
 
 def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
     """
-    Compute a set of evaluation metrics including accuracy, ROC-AUC, and PR-AUC.
+    Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC, based on reference labels and model logits.
 
     Args:
-        references (list or array-like): Ground truth (correct) labels for each sample.
-        logits (list or array-like): Predicted scores for each sample, typically the model's output.
+        references (List[float]): Ground truth binary labels (0 or 1).
+        logits (List[float]): Logits output from the model (raw prediction scores), which will be converted to probabilities using the sigmoid function.
 
     Returns:
-        Dict[str, Any]: A dictionary containing the computed metrics where keys represent the metric names
-        (e.g., 'accuracy', 'roc_auc', 'pr_auc') and values are the corresponding metric values.
+        Dict[str, Any]: A dictionary containing:
+            - 'roc_auc': The Area Under the Receiver Operating Characteristic Curve (ROC-AUC).
+            - 'pr_auc': The Area Under the Precision-Recall Curve (PR-AUC).
 
-    This function uses the `evaluate` library to compute the following metrics:
-    - Accuracy: The proportion of correct predictions.
-    - ROC-AUC: The area under the receiver operating characteristic curve.
+    Notes:
+        - The `sigmoid` function is used to convert the logits into probabilities.
+        - ROC-AUC measures the model's ability to distinguish between classes, while PR-AUC focuses on performance when dealing with imbalanced data.
     """
     # Convert logits to probabilities using sigmoid
     probabilities = sigmoid(logits)
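Editor's note: below is a minimal, self-contained sketch (not part of the patches) of the pattern the first commit switches to. Instead of handing a compute_metrics callback to the Hugging Face Trainer, which accumulates every evaluation prediction in CPU memory, the metrics are computed once from the collected test-set outputs with scikit-learn, mirroring the reworked compute_metrics helper. The labels and logits here are synthetic stand-ins for the runner's trainer.predict(...) results.

import json

import numpy as np
from scipy.special import expit as sigmoid
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score


def compute_metrics(references, logits):
    """Mirror of the patched helper: ROC-AUC and PR-AUC computed from raw logits."""
    # Convert logits to probabilities using the sigmoid function
    probabilities = sigmoid(np.asarray(logits, dtype=float))
    roc_auc = roc_auc_score(references, probabilities)
    precision, recall, _ = precision_recall_curve(references, probabilities)
    pr_auc = auc(recall, precision)
    return {"roc_auc": roc_auc, "pr_auc": pr_auc}


if __name__ == "__main__":
    # Synthetic stand-ins for the test-set labels and the logits returned by trainer.predict(...)
    rng = np.random.default_rng(42)
    labels = rng.integers(0, 2, size=1000)
    logits = rng.normal(loc=labels * 2.0 - 1.0, scale=1.0)

    metrics = compute_metrics(references=labels, logits=logits)
    print(json.dumps(metrics, indent=4))

Because only the final label and prediction arrays are needed, nothing has to be held per-step during evaluation, which is the CPU-memory problem the first commit message describes.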