
top_k for MulticlassF1Score is not working correctly #1653

Open
eneserdo opened this issue Mar 25, 2023 · 4 comments
Labels: bug / fix · good first issue · v0.11.x

Comments

eneserdo commented Mar 25, 2023

πŸ› Bug

The top_k argument of MulticlassF1Score is not working as expected. It is supposed to yield higher (or at least equal) scores as top_k increases, but that is not always what happens.

According to the docs:

top_k (int) – Number of highest probability or logit score predictions considered to find the correct label.

So the score should never decrease as top_k grows.

Also, when top_k=num_classes, the metric is expected to return 1.0 (100%), but that is not happening either.
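
The top_k=num_classes case follows directly from the definition: the top-k set then contains every class, so the true label is always in it. A minimal sketch in plain torch (not the torchmetrics internals) illustrating this:

import torch

preds = torch.randn(8, 5).softmax(dim=-1)
target = torch.randint(5, (8,))

# with k == num_classes, the top-k indices cover all classes, so every
# target label is guaranteed to appear among them
topk = preds.topk(k=5, dim=1).indices
assert (topk == target.unsqueeze(1)).any(dim=1).all()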

To Reproduce

Run the code sample below.

Code sample
import torch
from torchmetrics.classification import MulticlassF1Score

preds = torch.randn(200, 5).softmax(dim=-1)
target = torch.randint(5, (200,))

f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")
f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
f1_val_top5 = MulticlassF1Score(num_classes=5, top_k=5, average="macro")

print(f1_val_top1(preds, target), f1_val_top3(preds, target), f1_val_top5(preds, target))

It returns (tensor(0.1774), tensor(0.2740), tensor(0.3318))

As far as I understood from the documentation, setting top_k=5 must give 1.0, because there are only 5 classes anyway.
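
As a sanity check (a sketch, not part of the report above): the best case a top_k=num_classes evaluation can reach is predictions that equal the targets, and feeding those in does return 1.0:

import torch
from torchmetrics.classification import MulticlassF1Score

target = torch.arange(5).repeat(40)  # every class present
perfect = MulticlassF1Score(num_classes=5, average="macro")
print(perfect(target, target))  # tensor(1.)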

More explicitly, I expected the following two calls to produce the same output:

import torch, functorch
from torchmetrics.classification import MulticlassF1Score

preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")

pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
pred_top_1 = pred_top_3[:, 0]

# Replace each top-1 prediction with the correct label, but only when the
# correct label is among the top-3 predictions
pred_corrected_top3 = torch.where(functorch.vmap(lambda t1, t2: torch.isin(t1, t2))(target, pred_top_3), target, pred_top_1)

print(f1_val_top3(preds, target), f1_val_top1(pred_corrected_top3, target))

But the results are different.
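
For reference, the functorch.vmap trick above can also be written with plain broadcasting; this is a sketch of an equivalent construction, reusing the pred_top_3 and pred_top_1 tensors from the snippet above:

# Equivalent mask without functorch: compare each target against its row of
# top-3 predictions via broadcasting
in_top3 = (pred_top_3 == target.unsqueeze(1)).any(dim=1)
pred_corrected_top3 = torch.where(in_top3, target, pred_top_1)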

Environment

  • TorchMetrics 0.11.3 (installed via pip)
  • Python 3.8.16
  • PyTorch 1.12.0
@eneserdo eneserdo added bug / fix Something isn't working help wanted Extra attention is needed labels Mar 25, 2023
@github-actions

Hi! Thanks for your contribution, great first issue!

@arijitde92

Hi, I am interested in solving this issue. Can I work on it?

@Lightning-AI Lightning-AI deleted a comment from stale bot Aug 25, 2023
@Borda
Member

Borda commented Aug 25, 2023

@arijitde92 sorry for the late reply, sure, you are welcome to take it 💜

@Borda Borda added this to the v1.1.x milestone Aug 25, 2023
@Borda Borda added the v0.11.x label Aug 25, 2023
@Borda Borda modified the milestones: v1.1.x, v1.2.x Sep 24, 2023
@Borda Borda modified the milestones: v1.2.x, v1.3.x Jan 11, 2024
@Borda Borda added the good first issue Good for newcomers label Aug 29, 2024
@rittik9
Contributor

rittik9 commented Sep 10, 2024

@Borda pls assign it to me

@Borda Borda removed the help wanted Extra attention is needed label Sep 10, 2024