Multilabel metrics based on threshold should accept a threshold per label #2652
Hi @Yann-CV, thanks for creating this issue and sorry for not getting back to you earlier. You can get a similar result today by wrapping several instances of the metric, one per threshold, in a `MetricCollection`:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelConfusionMatrix

target = torch.randint(0, 2, (10, 3))
preds = torch.rand(10, 3)
metric = MetricCollection({
    f"mcm_{str(t).replace('.', '_')}": MultilabelConfusionMatrix(num_labels=3, threshold=t)
    for t in [0.1, 0.5, 0.9]
})
output = metric(preds, target)
print(output)
```

This is a bit slower than if it were implemented directly in the metric, but it should still work. Closing the issue for now.

Thanks @SkafteNicki
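For reference, what the issue actually asks for, a *different* threshold per label rather than several global thresholds, can be approximated outside the library by binarizing the predictions manually. Below is a minimal pure-PyTorch sketch; the function name is hypothetical and the code is illustrative, not part of torchmetrics, though the `[[tn, fp], [fn, tp]]` matrix layout follows torchmetrics' convention:

```python
import torch


def per_label_confusion_matrix(preds, target, thresholds):
    """Per-label confusion matrices with a per-label threshold.

    preds: (N, L) float probabilities; target: (N, L) ints in {0, 1};
    thresholds: (L,) floats. Returns an (L, 2, 2) tensor of counts in
    [[tn, fp], [fn, tp]] layout per label.
    """
    # Broadcast compares each column of preds against its own threshold.
    binarized = (preds >= thresholds).long()
    mats = []
    for label in range(preds.shape[1]):
        p, t = binarized[:, label], target[:, label]
        tp = ((p == 1) & (t == 1)).sum()
        tn = ((p == 0) & (t == 0)).sum()
        fp = ((p == 1) & (t == 0)).sum()
        fn = ((p == 0) & (t == 1)).sum()
        mats.append(torch.tensor([[tn, fp], [fn, tp]]))
    return torch.stack(mats)


preds = torch.rand(10, 3)
target = torch.randint(0, 2, (10, 3))
out = per_label_confusion_matrix(preds, target, torch.tensor([0.6, 0.5, 0.8]))
print(out.shape)  # torch.Size([3, 2, 2])
```

The same trick works with the real metrics too: since torchmetrics treats integer predictions as already-binarized labels, the manually thresholded tensor can be passed straight into `MultilabelConfusionMatrix`.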
🚀 Feature
Multilabel metrics that binarize the preds to 0 and 1 during computation (e.g. MultilabelStatScores, MultilabelConfusionMatrix) should be able to use a threshold per label.
Motivation
In a multilabel classification task, one might want to compute the stats with a different threshold per label, allowing the sensitivity to be tuned label-wise.
Pitch
```python
conf_mat = MultilabelConfusionMatrix(num_labels=3, thresholds=torch.tensor([0.6, 0.5, 0.8]))
```
Alternatives
Additional context
None