diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index 171e7196..5a1c27fa 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -5,6 +5,7 @@
 from .accuracy import Accuracy
 from .ava_map import AVAMeanAP
 from .bleu import BLEU
+from .char_recall_precision import CharRecallPrecision
 from .coco_detection import COCODetection
 from .connectivity_error import ConnectivityError
 from .dota_map import DOTAMeanAP
@@ -36,7 +37,7 @@
     'StructuralSimilarity', 'SignalNoiseRatio', 'MultiLabelMetric',
     'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP',
     'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
-    'ConnectivityError', 'ROUGE'
+    'ConnectivityError', 'ROUGE', 'CharRecallPrecision'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/char_recall_precision.py b/mmeval/metrics/char_recall_precision.py
new file mode 100644
index 00000000..85dfa6e5
--- /dev/null
+++ b/mmeval/metrics/char_recall_precision.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from difflib import SequenceMatcher
+from typing import Dict, Sequence, Tuple
+
+from mmeval.core import BaseMetric
+
+
+class CharRecallPrecision(BaseMetric):
+    """Calculate the character-level recall & precision.
+
+    Args:
+        letter_case (str): There are three options to alter the letter case:
+
+            - unchanged: Do not change prediction texts and labels.
+            - upper: Convert prediction texts and labels into uppercase
+              characters.
+            - lower: Convert prediction texts and labels into lowercase
+              characters.
+
+            Usually, it only works for English characters. Defaults to
+            'unchanged'.
+        valid_symbol (str): Regular expression matching the characters that
+            are stripped from predictions and labels before matching.
+            Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
+
+    Examples:
+        >>> from mmeval import CharRecallPrecision
+        >>> metric = CharRecallPrecision()
+        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
+        {'recall': 0.6, 'precision': 0.8571428571428571}
+        >>> metric = CharRecallPrecision(letter_case='upper')
+        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
+        {'recall': 0.7, 'precision': 1.0}
+    """
+
+    def __init__(self,
+                 letter_case: str = 'unchanged',
+                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
+                 **kwargs):
+        super().__init__(**kwargs)
+        assert letter_case in ['unchanged', 'upper', 'lower']
+        self.letter_case = letter_case
+        self.valid_symbol = re.compile(valid_symbol)
+
+    def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Process one batch of data and predictions.
+
+        Args:
+            predictions (list[str]): The prediction texts.
+            labels (list[str]): The ground truth texts.
+        """
+        for pred, label in zip(predictions, labels):
+            if self.letter_case in ['upper', 'lower']:
+                pred = getattr(pred, self.letter_case)()
+                label = getattr(label, self.letter_case)()
+            label_ignore = self.valid_symbol.sub('', label)
+            pred_ignore = self.valid_symbol.sub('', pred)
+            # Count character-level true positives for recall & precision.
+            true_positive_char_num = self._cal_true_positive_char(
+                pred_ignore, label_ignore)
+            self._results.append(
+                (len(label_ignore), len(pred_ignore), true_positive_char_num))
+
+    def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list[tuple]): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the
+            metrics, and the values are corresponding results.
+ """ + gt_sum, pred_sum, true_positive_sum = 0.0, 0.0, 0.0 + for gt, pred, true_positive in results: + gt_sum += gt + pred_sum += pred + true_positive_sum += true_positive + char_recall = true_positive_sum / max(gt_sum, 1.0) + char_precision = true_positive_sum / max(pred_sum, 1.0) + eval_res = {} + eval_res['recall'] = char_recall + eval_res['precision'] = char_precision + return eval_res + + def _cal_true_positive_char(self, pred: str, gt: str) -> int: + """Calculate correct character number in prediction. + + Args: + pred (str): Prediction text. + gt (str): Ground truth text. + + Returns: + true_positive_char_num (int): The true positive number. + """ + + all_opt = SequenceMatcher(None, pred, gt) + true_positive_char_num = 0 + for opt, _, _, s2, e2 in all_opt.get_opcodes(): + if opt == 'equal': + true_positive_char_num += (e2 - s2) + else: + pass + return true_positive_char_num diff --git a/requirements/optional.txt b/requirements/optional.txt index 1bd0bbd5..0375d5f2 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,4 @@ +difflib opencv-python!=4.5.5.62,!=4.5.5.64 pycocotools scipy diff --git a/tests/test_metrics/test_char_recall_precision.py b/tests/test_metrics/test_char_recall_precision.py new file mode 100644 index 00000000..d5f34a45 --- /dev/null +++ b/tests/test_metrics/test_char_recall_precision.py @@ -0,0 +1,25 @@ +import pytest + +from mmeval import CharRecallPrecision + + +def test_init(): + with pytest.raises(AssertionError): + CharRecallPrecision(letter_case='fake') + + +def test_char_recall_precision_metric(): + metric = CharRecallPrecision(letter_case='lower') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['recall'] - 0.7) < 1e-7 + assert abs(res['precision'] - 1) < 1e-7 + + metric = CharRecallPrecision(letter_case='upper') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['recall'] - 0.7) < 1e-7 + assert abs(res['precision'] - 1) < 1e-7 + + metric = CharRecallPrecision(letter_case='unchanged') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['recall'] - 0.6) < 1e-7 + assert abs(res['precision'] - 6.0 / 7) < 1e-7