Skip to content

Commit

Permalink
Detectron model health (#664)
Browse files Browse the repository at this point in the history
* add model health calc and example notebook

* move notebook

* move notebook

* fix transforms

* add perf metrics to notebook

---------

Co-authored-by: Amrit Krishnan <[email protected]>
  • Loading branch information
a-kore and amrit110 authored Jul 17, 2024
1 parent f391ebf commit 5d768cb
Show file tree
Hide file tree
Showing 2 changed files with 935 additions and 91 deletions.
269 changes: 178 additions & 91 deletions cyclops/monitor/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
import pandas as pd
import sklearn
from datasets import Dataset, DatasetDict, concatenate_datasets
from datasets.utils.logging import disable_progress_bar
from scipy.special import expit as sigmoid
from scipy.special import softmax
from sklearn.base import BaseEstimator

from cyclops.data.transforms import Lambdad
from cyclops.data.utils import apply_transforms
from cyclops.models.catalog import wrap_model
from cyclops.models.utils import is_pytorch_model, is_sklearn_model
from cyclops.models.wrappers import PTModel, SKModel
from cyclops.monitor.utils import DetectronModule, DummyCriterion, get_args
from cyclops.utils.optional import import_optional_module


disable_progress_bar()


if TYPE_CHECKING:
import torch
from alibi_detect.cd import (
Expand Down Expand Up @@ -705,33 +711,55 @@ def __init__(
self.model = base_model
else:
self.model = model
if isinstance(base_model, nn.Module):
if is_pytorch_model(base_model):
self.base_model = wrap_model(
base_model,
batch_size=batch_size,
)
self.base_model.initialize()
else:
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
if transforms:
self.transforms = partial(apply_transforms, transforms=transforms)
model_transforms = transforms
model_transforms.transforms = model_transforms.transforms + (
Lambdad(
keys=("mask", "labels"),
func=lambda x: np.array(x),
allow_missing_keys=True,
),
)
self.model_transforms = partial(
apply_transforms,
transforms=model_transforms,
)
else:
self.transforms = None
self.model_transforms = None
elif is_sklearn_model(base_model):
self.base_model = wrap_model(base_model)
self.base_model.initialize()
self.feature_column = feature_column
if transforms:
self.transforms = partial(apply_transforms, transforms=transforms)
model_transforms = transforms
model_transforms.transforms = model_transforms.transforms + (
Lambdad(
keys=("mask", "labels"),
func=lambda x: np.array(x),
allow_missing_keys=True,
),
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
self.model_transforms = partial(
apply_transforms,
transforms=model_transforms,
self.transforms = transforms
self.model_transforms = transforms
elif isinstance(base_model, SKModel):
self.base_model = base_model
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
else:
self.transforms = transforms
self.model_transforms = transforms
elif isinstance(base_model, PTModel):
self.base_model = base_model
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
else:
raise ValueError("base_model must be a PyTorch or sklearn model.")

self.feature_column = feature_column
self.splits_mapping = splits_mapping
self.num_runs = num_runs
self.sample_size = sample_size
Expand All @@ -741,8 +769,7 @@ def __init__(
self.lr = lr
self.num_workers = num_workers
self.task = task
if save_dir is None:
self.save_dir = "detectron"
self.save_dir = "detectron" if save_dir is None else save_dir

self.fit(X_s)

Expand All @@ -759,24 +786,35 @@ def fit(self, X_s: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
for seed in range(self.num_runs):
# train ensemble of for split 'p*'
for e in range(1, self.ensemble_size + 1):
if is_pytorch_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
elif is_sklearn_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
alpha = 1 / (len(X_s) * self.sample_size + 1)
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
if is_pytorch_model(self.base_model.model):
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
model.initialize()
elif is_sklearn_model(self.base_model.model):
model = self.base_model
if isinstance(X_s, (Dataset, DatasetDict)):
# create p/p* splits

p = (
X_s[self.splits_mapping["train"]]
.shuffle()
Expand Down Expand Up @@ -808,26 +846,39 @@ def fit(self, X_s: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
np.array(pstar_pseudolabels),
)
pstar = pstar.add_column("labels", pstar_pseudolabels.tolist())
if is_sklearn_model(self.base_model.model):
pstar = pstar.map(
lambda x: x.update({"labels": int(1 - x["labels"])})
)

p_pstar = concatenate_datasets([p, pstar], axis=0)
p_pstar = p_pstar.train_test_split(test_size=0.5, shuffle=True)

train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_pstar,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)
if is_pytorch_model(self.base_model.model):
train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_pstar,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)
model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
log=False,
)
elif is_sklearn_model(self.base_model.model):
model.fit(
X=p_pstar,
feature_columns=self.feature_column,
target_columns="labels",
transforms=self.model_transforms,
)

model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
)
pstar_logits = model.predict(
X=pstar,
feature_columns=self.feature_column,
Expand Down Expand Up @@ -862,22 +913,33 @@ def predict(self, X_t: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
for seed in range(self.num_runs):
# train ensemble of for split 'p*'
for e in range(1, self.ensemble_size + 1):
if is_pytorch_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
elif is_sklearn_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
alpha = 1 / (len(X_t) * self.sample_size + 1)
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
model.initialize()
if is_pytorch_model(self.base_model.model):
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
model.initialize()
elif is_sklearn_model(self.base_model.model):
model = self.base_model
if isinstance(X_t, (Dataset, DatasetDict)):
# create p/q splits
p = (
Expand Down Expand Up @@ -908,24 +970,36 @@ def predict(self, X_t: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
)
q_pseudolabels = self.format_pseudolabels(np.array(q_pseudolabels))
q = q.add_column("labels", q_pseudolabels.tolist())
if is_sklearn_model(self.base_model.model):
q = q.map(lambda x: x.update({"labels": int(1 - x["labels"])}))
p_q = concatenate_datasets([p, q], axis=0)
p_q = p_q.train_test_split(test_size=0.5, shuffle=True)
train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_q,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)

model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
)
if is_pytorch_model(self.base_model.model):
train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_q,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)

model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
log=False,
)
elif is_sklearn_model(self.base_model.model):
model.fit(
X=p_q,
feature_columns=self.feature_column,
target_columns="labels",
transforms=self.model_transforms,
)
q_logits = model.predict(
X=q,
feature_columns=self.feature_column,
Expand All @@ -950,18 +1024,21 @@ def predict(self, X_t: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):

def format_pseudolabels(self, labels):
"""Format pseudolabels."""
if self.task in ("binary", "multilabel"):
labels = (
(labels > 0.5).astype("float32")
if ((labels <= 1).all() and (labels >= 0).all())
else (sigmoid(labels) > 0.5).astype("float32")
)
elif self.task == "multiclass":
labels = (
labels.argmax(dim=-1)
if np.isclose(labels.sum(axis=-1), 1).all()
else softmax(labels, axis=-1).argmax(axis=-1)
)
if is_pytorch_model(self.base_model.model):
if self.task in ("binary", "multilabel"):
labels = (
(labels > 0.5).astype("float32")
if ((labels <= 1).all() and (labels >= 0).all())
else (sigmoid(labels) > 0.5).astype("float32")
)
elif self.task == "multiclass":
labels = (
labels.argmax(dim=-1)
if np.isclose(labels.sum(axis=-1), 1).all()
else softmax(labels, axis=-1).argmax(axis=-1)
)
elif is_sklearn_model(self.base_model.model):
return labels
else:
raise ValueError(
f"Task must be either 'binary', 'multiclass' or 'multilabel', got {self.task} instead.",
Expand Down Expand Up @@ -1015,15 +1092,25 @@ def get_results(self, max_ensemble_size=None) -> float:
test_count = self.counts("test", max_ensemble_size)[0]
cdf = self.ecdf(cal_counts)
p_value = cdf(test_count)
self.model_health = self.get_model_health(max_ensemble_size)
return {
"data": {
"model_health": self.model_health,
"p_val": p_value,
"distance": test_count,
"cal_record": self.cal_record,
"test_record": self.test_record,
},
}

def get_model_health(self, max_ensemble_size=None) -> float:
    """Compute a model health score in [0, 1].

    The score is the ratio of the disagreement count on the test split to
    the mean disagreement count on the calibration split, clamped to 1.0.
    A value near 1 means the test-time behavior matches the calibration
    baseline; lower values indicate degradation.

    Parameters
    ----------
    max_ensemble_size : int, optional
        Maximum ensemble size passed through to ``self.counts``; ``None``
        uses all ensemble members.

    Returns
    -------
    float
        ``min(1.0, test_count / baseline)``.

    """
    self.cal_counts = self.counts("calibration", max_ensemble_size)
    self.test_count = self.counts("test", max_ensemble_size)[0]
    self.baseline = self.cal_counts.mean()
    # NOTE(review): raises ZeroDivisionError if every calibration count is
    # zero — confirm whether that can happen in practice.
    self.model_health = self.test_count / self.baseline
    # Clamp with a float literal so the annotated return type holds even
    # when the ratio exceeds 1 (min(1, x) would return the int 1).
    # The unclamped ratio remains available as ``self.model_health``.
    return min(1.0, self.model_health)

@staticmethod
def split_dataset(X: Union[Dataset, DatasetDict]) -> DatasetDict:
"""Split dataset into train and test splits."""
Expand Down
Loading

0 comments on commit 5d768cb

Please sign in to comment.