
Multilabel metrics based on threshold should accept a threshold per label #2652

Closed
Yann-CV opened this issue Jul 23, 2024 · 2 comments

Labels
enhancement New feature or request

Comments

Yann-CV commented Jul 23, 2024

🚀 Feature

Multilabel metrics that convert the preds to 0 and 1 during computation (e.g. MultilabelStatScores, MultilabelConfusionMatrix) should be able to use a threshold per label.

Motivation

In a multilabel classification task, one might want to compute the stats with a different threshold per label, making the metric more or less sensitive label-wise.

Pitch

conf_mat = MultilabelConfusionMatrix(num_labels=3, thresholds=torch.tensor([0.6, 0.5, 0.8]))

Alternatives

import torch
from torchmetrics.classification import MultilabelConfusionMatrix

thresholds = torch.tensor([0.6, 0.5, 0.8])
conf_mat = MultilabelConfusionMatrix(num_labels=3)
# binarize each label's column with its own threshold before updating the metric
for idx, t in enumerate(thresholds):
    preds[:, idx] = (preds[:, idx] >= t).float()

conf_mat(preds, target)
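The loop can also be written as a single broadcasted comparison, which is what a built-in per-label thresholds argument could do internally (a minimal sketch, assuming preds of shape (N, num_labels) as above):

import torch

thresholds = torch.tensor([0.6, 0.5, 0.8])
preds = torch.rand(10, 3)

# broadcasting compares each column of preds against its own threshold,
# since thresholds of shape (3,) aligns with the last dimension of preds
binarized = (preds >= thresholds).int()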

Additional context

None

SkafteNicki (Member) commented

Hi @Yann-CV, thanks for creating this issue and sorry for not getting back to you earlier.
A similar issue, #2612, was opened around the same time, requesting that multilabel metrics accept a list of thresholds, with all labels evaluated against all thresholds. Not the same as what you are asking for, but somewhat related.
By the same argument as in that issue, the proposed change would be a significant change to the codebase, and we are therefore not going to support it in the near future. If it gets requested again, we may reconsider. In the meantime, this can be achieved using a MetricCollection to essentially create a single "metric" that does this:

import torch
from torchmetrics.classification import MultilabelConfusionMatrix
from torchmetrics import MetricCollection

target = torch.randint(0, 2, (10, 3))
preds = torch.rand(10, 3)

# one confusion-matrix metric per threshold, wrapped in a single collection
metric = MetricCollection({
    f"mcm_{str(t).replace('.', '_')}": MultilabelConfusionMatrix(num_labels=3, threshold=t) for t in [0.1, 0.5, 0.9]
})
output = metric(preds, target)
print(output)

It is a bit slower than if this were implemented directly in the metric, but it should still work. Closing the issue for now.
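To emulate a true per-label threshold with this workaround, one can then pick each label's 2x2 matrix from the collection entry computed at that label's desired threshold. A sketch, assuming the hypothetical mapping label 0 → 0.1, label 1 → 0.5, label 2 → 0.9, and relying on MultilabelConfusionMatrix returning a tensor of shape (num_labels, 2, 2):

# hypothetical mapping of labels to thresholds: 0 -> 0.1, 1 -> 0.5, 2 -> 0.9
per_label = torch.stack([
    output["mcm_0_1"][0],  # label 0, evaluated at threshold 0.1
    output["mcm_0_5"][1],  # label 1, evaluated at threshold 0.5
    output["mcm_0_9"][2],  # label 2, evaluated at threshold 0.9
])
print(per_label)  # shape (3, 2, 2)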

Yann-CV (Author) commented Sep 18, 2024

Thanks @SkafteNicki
