Skip to content

Commit

Permalink
we only add a new value to the running stat if the value is between the lower and upper bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Sep 13, 2024
1 parent af2db54 commit 0a5a0ca
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def batched_generator():
current = fixed_stat
else:
current = agg_statistics(current, fixed_stat)

lab_stats = [
{
"concept_id": concept_id,
Expand Down
29 changes: 22 additions & 7 deletions src/cehrbert/utils/stat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,48 @@
class RunningStatistics(OnlineStatistics):
def __init__(self, capacity=100, value_outlier_std=2.0):
super().__init__()
self.value_outlier_std = value_outlier_std
self.excluding_outlier_online_statistics = ExcludingOutlierOnlineStatistics(
capacity=capacity, value_outlier_std=value_outlier_std
)

def _update_stats(self):
if self.count == 0:
self.current_mean = self.excluding_outlier_online_statistics.get_current_mean()
self.variance = self.excluding_outlier_online_statistics.get_sum_of_squared()
self.count = self.excluding_outlier_online_statistics.get_count()

def add(self, weight: float, value: float) -> None:
if self.excluding_outlier_online_statistics.is_full():
super().add(weight, value)
std = self.standard_deviation()
if (
self.current_mean - self.value_outlier_std * std
<= value
<= self.current_mean + self.value_outlier_std * std
):
super().add(weight, value)
else:
self.excluding_outlier_online_statistics.add(value)
if self.excluding_outlier_online_statistics.is_full():
self.current_mean = self.excluding_outlier_online_statistics.get_current_mean()
self.variance = self.excluding_outlier_online_statistics.get_sum_of_squared()
self.count = self.excluding_outlier_online_statistics.get_count()
self._update_stats()

def mean(self) -> float:
"""Return the current mean."""
if self.excluding_outlier_online_statistics.is_full():
return super().mean()
else:
self.excluding_outlier_online_statistics.get_current_mean()
return self.excluding_outlier_online_statistics.get_current_mean()

def standard_deviation(self) -> float:
"""Return the current standard deviation."""
if self.excluding_outlier_online_statistics.is_full():
return super().standard_deviation()
else:
return self.excluding_outlier_online_statistics.standard_deviation()
return self.excluding_outlier_online_statistics.get_std()

def combine(self, other) -> None:
self._update_stats()
super().combine(other)


class ExcludingOutlierOnlineStatistics:
Expand Down Expand Up @@ -80,7 +95,7 @@ def get_sum_of_squared(self) -> float:
else:
raise ValueError(f"There is no value")

def standard_deviation(self) -> float:
def get_std(self) -> float:
self.update_remove_outliers()
if self.filtered_data:
return np.std(self.filtered_data)
Expand Down

0 comments on commit 0a5a0ca

Please sign in to comment.