Skip to content

Commit

Permalink
fix test_bertscore_sorting bug + validate idf arg
Browse files Browse the repository at this point in the history
  • Loading branch information
Guilherme Paulino-Passos @ DoC-cluster committed Sep 13, 2024
1 parent fa351e8 commit cc97f16
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ def bert_score(
preds = list(preds)
if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call
target = list(target)
if not isinstance(idf, bool):
raise ValueError(f"The value of idf must be a boolean. Value passed:{idf=}")

if verbose and (not _TQDM_AVAILABLE):
raise ModuleNotFoundError(
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_bertscore_differentiability(
@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
"idf",
["idf"],
[(False,), (True,)],
)
def test_bertscore_sorting(idf: bool):
Expand Down

0 comments on commit cc97f16

Please sign in to comment.