diff --git a/metrics/aggregation.py b/metrics/aggregation.py index 57f75d5..a3c4e79 100644 --- a/metrics/aggregation.py +++ b/metrics/aggregation.py @@ -50,7 +50,8 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: value=state[:, 0], weight=state[:, 1], ) - return torch.stack([mean, weight_sum]) + merged_accumulated_mean = torch.stack([mean, weight_sum]) + return merged_accumulated_mean class StableMean(torchmetrics.Metric): @@ -94,4 +95,5 @@ def compute(self) -> torch.Tensor: """ Compute and return the accumulated mean. """ - return self.mean_and_weight_sum[0] + accumulated_mean = self.mean_and_weight_sum[0] + return accumulated_mean diff --git a/metrics/auroc.py b/metrics/auroc.py index 6979c20..d04e27d 100644 --- a/metrics/auroc.py +++ b/metrics/auroc.py @@ -159,4 +159,6 @@ def compute(self) -> torch.Tensor: ) # Compute auroc with the weight set to 1/2 when positive & negative have identical scores. - return auroc_le - (auroc_le - auroc_lt) / 2.0 + auroc = auroc_le - (auroc_le - auroc_lt) / 2.0 + return auroc + \ No newline at end of file diff --git a/metrics/rce.py b/metrics/rce.py index b6ada1d..be5dfa4 100644 --- a/metrics/rce.py +++ b/metrics/rce.py @@ -21,7 +21,8 @@ def _smooth( label_smoothing: smoothing constant. Returns: Smoothed values. """ - return value * (1.0 - label_smoothing) + 0.5 * label_smoothing + smoothed_values = value * (1.0 - label_smoothing) + 0.5 * label_smoothing + return smoothed_values def _binary_cross_entropy_with_clipping( @@ -178,7 +179,8 @@ def compute(self) -> torch.Tensor: pred_ce = self.binary_cross_entropy.compute() - return (1.0 - (pred_ce / baseline_ce)) * 100 + rce = (1.0 - (pred_ce / baseline_ce)) * 100 + return rce def reset(self): """