Skip to content

Commit

Permalink
implemented a custom loop to generate the test predictions in batches…
Browse files Browse the repository at this point in the history
… because the transformers train will crash due to a CPU OOM error since it tries to keep all the predictions around on the CPU
  • Loading branch information
ChaoPang committed Sep 10, 2024
1 parent e144717 commit ab26c50
Showing 1 changed file with 37 additions and 20 deletions.
57 changes: 37 additions & 20 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import json
import os
from typing import Any, Dict, List, Tuple
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pandas as pd
from datasets import DatasetDict, load_from_disk
from pandas import Series
from peft import LoraConfig, get_peft_model
from scipy.special import expit as sigmoid
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
from torch.utils.data import DataLoader
from transformers import EarlyStoppingCallback, Trainer, set_seed
from transformers.utils import logging

Expand All @@ -33,7 +36,9 @@
LOG = logging.get_logger("transformers")


def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
def compute_metrics(
references: Union[List[float], Series[float]], logits: Union[List[float], Series[float]]
) -> Dict[str, Any]:
"""
Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC, based on reference labels and model logits.
Expand Down Expand Up @@ -318,8 +323,6 @@ def assign_split(example):

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
trainer.save_state()

if training_args.do_predict:
Expand All @@ -331,24 +334,38 @@ def assign_split(example):
)
trainer._load_from_checkpoint(training_args.output_dir)

test_results = trainer.predict(processed_dataset["test"])

person_ids = [row["person_id"] for row in processed_dataset["test"]]

if isinstance(test_results.predictions, np.ndarray):
predictions = np.squeeze(test_results.predictions).tolist()
else:
predictions = np.squeeze(test_results.predictions[0]).tolist()
if isinstance(test_results.label_ids, np.ndarray):
labels = np.squeeze(test_results.label_ids).tolist()
else:
labels = np.squeeze(test_results.label_ids[0]).tolist()

prediction_pd = pd.DataFrame({"person_id ": person_ids, "prediction": predictions, "label": labels})
prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False)
# Create the prediction folder if not exists
test_prediction_folder = os.path.join(training_args.output_dir, "test_predictions")
Path(test_prediction_folder).mkdir(parents=True, exist_ok=True)
test_dataloader = DataLoader(
dataset=processed_dataset["test"],
batch_size=training_args.per_device_eval_batch_size,
num_workers=training_args.dataloader_num_workers,
collate_fn=collator,
pin_memory=training_args.dataloader_pin_memory,
)
LOG.info(
"Started generating predictions for test set at %s",
test_prediction_folder,
)
for index, batch in enumerate(test_dataloader):
batched_person_ids = batch["person_id"]
batched_labels = batch["classifier_label"]
cehrbert_output = model(**batch)
cehrbert_output.logits
cehrbert_output.loss
prediction_pd = pd.DataFrame(
{"person_id ": batched_person_ids, "prediction": cehrbert_output.logits, "label": batched_labels}
)
prediction_pd.to_parquet(os.path.join(test_prediction_folder, f"{index}.parquet"))

LOG.info(
"Started computing metrics using the test set predictions at %s",
test_prediction_folder,
)
test_prediction_pd = pd.read_parquet(test_prediction_folder)
# Save results to JSON
metrics = compute_metrics(references=labels, logits=predictions)
metrics = compute_metrics(references=test_prediction_pd.label, logits=test_prediction_pd.prediction)
test_results_path = os.path.join(training_args.output_dir, "test_results.json")
with open(test_results_path, "w") as f:
json.dump(metrics, f, indent=4)
Expand Down

0 comments on commit ab26c50

Please sign in to comment.