Stopped using the transformers (4.39.3) trainer compute_metrics to calculate custom metrics as large validation sets will crash #54

Merged · 3 commits · Sep 9, 2024
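Background for this change: the Hugging Face Trainer buffers every evaluation batch's logits (plus any extra model outputs) into a single in-memory array before calling compute_metrics, so a sufficiently large validation set can exhaust host memory. This PR therefore drops compute_metrics from the Trainer and computes ROC-AUC and PR-AUC once, over the saved test predictions. A minimal sketch of the reworked helper with toy inputs (the labels and logits below are hypothetical, not from the repo):

import numpy as np
from scipy.special import expit as sigmoid
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score

def compute_metrics(references, logits):
    # Convert raw logits to probabilities, then score both AUCs.
    probabilities = sigmoid(np.asarray(logits))
    roc_auc = roc_auc_score(references, probabilities)
    precision, recall, _ = precision_recall_curve(references, probabilities)
    return {"roc_auc": roc_auc, "pr_auc": auc(recall, precision)}

print(compute_metrics(references=[0, 1, 1, 0], logits=[-1.2, 0.8, 2.1, -0.3]))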
@@ -453,6 +453,6 @@ class HFFineTuningMapping(DatasetMapping):

     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "age_at_index": record["age_at_index"],
+            "age_at_index": record["age"] if "age" in record else record["age_at_index"],
             "classifier_label": record["label"],
         }
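A toy check of the new fallback in HFFineTuningMapping.transform (hypothetical records, not from the repo's tests): it prefers a bare "age" field when present and otherwise falls back to the legacy "age_at_index".

def transform(record):
    return {
        "age_at_index": record["age"] if "age" in record else record["age_at_index"],
        "classifier_label": record["label"],
    }

print(transform({"age": 63, "label": 1}))           # -> {'age_at_index': 63, 'classifier_label': 1}
print(transform({"age_at_index": 71, "label": 0}))  # -> {'age_at_index': 71, 'classifier_label': 0}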
87 changes: 56 additions & 31 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -1,13 +1,13 @@
 import json
 import os
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
 from datasets import DatasetDict, load_from_disk
 from peft import LoraConfig, get_peft_model
 from scipy.special import expit as sigmoid
-from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_auc_score
+from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
 from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.utils import logging

@@ -33,38 +33,60 @@
 LOG = logging.get_logger("transformers")
 
 
-def compute_metrics(eval_pred):
-    outputs, labels = eval_pred
-    logits = outputs[0]
-
-    # Convert logits to probabilities using sigmoid
-    probabilities = sigmoid(logits)
-
-    if probabilities.shape[1] == 2:
-        positive_probs = probabilities[:, 1]
-    else:
-        positive_probs = probabilities.squeeze()  # Ensure it's a 1D array
-
-    # Calculate predictions based on probability threshold of 0.5
-    predictions = (positive_probs > 0.5).astype(np.int32)
-
-    # Calculate accuracy
-    accuracy = accuracy_score(labels, predictions)
-
-    # Calculate ROC-AUC
-    roc_auc = roc_auc_score(labels, positive_probs)
-
-    # Calculate PR-AUC
-    precision, recall, _ = precision_recall_curve(labels, positive_probs)
-    pr_auc = auc(recall, precision)
-
-    return {"accuracy": accuracy, "roc_auc": roc_auc, "pr_auc": pr_auc}
+def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
+    """
+    Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC, based on reference labels and model logits.
+
+    Args:
+        references (List[float]): Ground truth binary labels (0 or 1).
+        logits (List[float]): Logits output from the model (raw prediction scores), which will be converted to probabilities using the sigmoid function.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing:
+            - 'roc_auc': The Area Under the Receiver Operating Characteristic Curve (ROC-AUC).
+            - 'pr_auc': The Area Under the Precision-Recall Curve (PR-AUC).
+
+    Notes:
+        - The `sigmoid` function is used to convert the logits into probabilities.
+        - ROC-AUC measures the model's ability to distinguish between classes, while PR-AUC focuses on performance when dealing with imbalanced data.
+    """
+    # Convert logits to probabilities using sigmoid
+    probabilities = sigmoid(logits)
+    # Calculate ROC-AUC
+    roc_auc = roc_auc_score(references, probabilities)
+    # Calculate PR-AUC
+    precision, recall, _ = precision_recall_curve(references, probabilities)
+    pr_auc = auc(recall, precision)
+
+    return {"roc_auc": roc_auc, "pr_auc": pr_auc}


 def load_pretrained_model_and_tokenizer(
     model_args,
 ) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
-    # Try to load the pretrained tokenizer
+    """
+    Loads a pretrained model and tokenizer based on the given model arguments.
+
+    Args:
+        model_args (Namespace): An argument object containing the following fields:
+            - tokenizer_name_or_path (str): The path or name of the pretrained tokenizer to load.
+            - model_name_or_path (str): The path or name of the pretrained model to load.
+            - finetune_model_type (str): The type of fine-tuning model to use. Must be one of the values in `FineTuneModelType`.
+
+    Returns:
+        Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
+            - CehrBertPreTrainedModel: The loaded pretrained model (either a classification or LSTM model).
+            - CehrBertTokenizer: The loaded pretrained tokenizer.
+
+    Raises:
+        ValueError: If the tokenizer cannot be loaded from the specified path, or if the fine-tuning model type is invalid.
+
+    Notes:
+        - If loading the model fails, the function will attempt to create a new model using the provided model arguments
+          and the tokenizer's configuration.
+        - The function supports two types of models for fine-tuning:
+            - `CehrBertForClassification` for pooling-based models.
+            - `CehrBertLstmForClassification` for LSTM-based models.
+    """
     try:
         tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     except Exception:
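Aside, for readers who would rather keep Trainer-side metrics than remove them as this PR does: transformers exposes two knobs that shrink what the evaluation loop buffers, preprocess_logits_for_metrics (trims each batch's outputs before accumulation) and TrainingArguments.eval_accumulation_steps (periodically flushes accumulated tensors to CPU). A hedged sketch, not part of this diff:

import torch

def preprocess_logits_for_metrics(logits, labels):
    # Keep only the raw class logits; drop any extra outputs (e.g. hidden
    # states) so the Trainer buffers far less per evaluation step.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    return logits

# Wired in as:
#   Trainer(..., preprocess_logits_for_metrics=preprocess_logits_for_metrics)
# alongside TrainingArguments(eval_accumulation_steps=10, ...).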
@@ -284,7 +306,6 @@ def assign_split(example):
         data_collator=collator,
         train_dataset=processed_dataset["train"],
         eval_dataset=processed_dataset["validation"],
-        compute_metrics=compute_metrics,
         callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
         args=training_args,
     )
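With compute_metrics removed from the Trainer above, evaluate() during training reports only the eval loss (plus runtime statistics), so load_best_model_at_end and the EarlyStoppingCallback must track a metric that still exists. A hedged configuration sketch with placeholder values:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                   # placeholder path
    evaluation_strategy="epoch",
    save_strategy="epoch",              # must match the eval strategy
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # roc_auc/pr_auc are no longer produced in-loop
    greater_is_better=False,            # lower loss is better
)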
@@ -297,6 +318,8 @@ def assign_split(example):

     trainer.log_metrics("train", metrics)
     trainer.save_metrics("train", metrics)
+    trainer.log_metrics("eval", metrics)
+    trainer.save_metrics("eval", metrics)
     trainer.save_state()
 
     if training_args.do_predict:
@@ -309,12 +332,6 @@ def assign_split(example):
             trainer._load_from_checkpoint(training_args.output_dir)
 
         test_results = trainer.predict(processed_dataset["test"])
-        # Save results to JSON
-        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
-        with open(test_results_path, "w") as f:
-            json.dump(test_results.metrics, f, indent=4)
-
-        LOG.info(f"Test results: {test_results.metrics}")
 
         person_ids = [row["person_id"] for row in processed_dataset["test"]]
 
@@ -330,6 +347,14 @@ def assign_split(example):
         prediction_pd = pd.DataFrame({"person_id ": person_ids, "prediction": predictions, "label": labels})
         prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False)
 
+        # Save results to JSON
+        metrics = compute_metrics(references=labels, logits=predictions)
+        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
+        with open(test_results_path, "w") as f:
+            json.dump(metrics, f, indent=4)
+
+        LOG.info(f"Test results: {metrics}")

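Because the metrics JSON is now derived from the same arrays as the predictions CSV, the saved numbers can be reproduced offline. A hedged sketch (output_dir is a placeholder; it assumes the "prediction" column holds the raw logits passed to compute_metrics above, and note the trailing space in the "person_id " column name as written in this diff):

import os
import pandas as pd
from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics

output_dir = "out"  # placeholder; the runner writes to training_args.output_dir
df = pd.read_csv(os.path.join(output_dir, "test_predictions.csv"))
print(compute_metrics(references=df["label"].tolist(), logits=df["prediction"].tolist()))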

 if __name__ == "__main__":
     main()