From 0a5a0cab4af0b0305750de685cd3d1f3846d66a2 Mon Sep 17 00:00:00 2001 From: Chao Pang Date: Thu, 12 Sep 2024 17:13:09 -0400 Subject: [PATCH] we only add a new value to the running stat if the value is between the lower and upper bounds --- .../hf_models/tokenization_hf_cehrbert.py | 1 + src/cehrbert/utils/stat_utils.py | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py index 0a0b86b3..8ff9f76d 100644 --- a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py +++ b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py @@ -327,6 +327,7 @@ def batched_generator(): current = fixed_stat else: current = agg_statistics(current, fixed_stat) + lab_stats = [ { "concept_id": concept_id, diff --git a/src/cehrbert/utils/stat_utils.py b/src/cehrbert/utils/stat_utils.py index ba3d5eba..a7f8fad4 100644 --- a/src/cehrbert/utils/stat_utils.py +++ b/src/cehrbert/utils/stat_utils.py @@ -6,33 +6,48 @@ class RunningStatistics(OnlineStatistics): def __init__(self, capacity=100, value_outlier_std=2.0): super().__init__() + self.value_outlier_std = value_outlier_std self.excluding_outlier_online_statistics = ExcludingOutlierOnlineStatistics( capacity=capacity, value_outlier_std=value_outlier_std ) + def _update_stats(self): + if self.count == 0: + self.current_mean = self.excluding_outlier_online_statistics.get_current_mean() + self.variance = self.excluding_outlier_online_statistics.get_sum_of_squared() + self.count = self.excluding_outlier_online_statistics.get_count() + def add(self, weight: float, value: float) -> None: if self.excluding_outlier_online_statistics.is_full(): - super().add(weight, value) + std = self.standard_deviation() + if ( + self.current_mean - self.value_outlier_std * std + <= value + <= self.current_mean + self.value_outlier_std * std + ): + super().add(weight, value) else: self.excluding_outlier_online_statistics.add(value) if self.excluding_outlier_online_statistics.is_full(): - self.current_mean = self.excluding_outlier_online_statistics.get_current_mean() - self.variance = self.excluding_outlier_online_statistics.get_sum_of_squared() - self.count = self.excluding_outlier_online_statistics.get_count() + self._update_stats() def mean(self) -> float: """Return the current mean.""" if self.excluding_outlier_online_statistics.is_full(): return super().mean() else: - self.excluding_outlier_online_statistics.get_current_mean() + return self.excluding_outlier_online_statistics.get_current_mean() def standard_deviation(self) -> float: """Return the current standard devation.""" if self.excluding_outlier_online_statistics.is_full(): return super().standard_deviation() else: - return self.excluding_outlier_online_statistics.standard_deviation() + return self.excluding_outlier_online_statistics.get_std() + + def combine(self, other) -> None: + self._update_stats() + super().combine(other) class ExcludingOutlierOnlineStatistics: @@ -80,7 +95,7 @@ def get_sum_of_squared(self) -> float: else: raise ValueError(f"There is no value") - def standard_deviation(self) -> float: + def get_std(self) -> float: self.update_remove_outliers() if self.filtered_data: return np.std(self.filtered_data)