Skip to content

Commit

Permalink
Improve precision for mean, std, var.
Browse files Browse the repository at this point in the history
np.bincount always accumulates to float64.
So only cast after the division.
  • Loading branch information
dcherian committed Jul 26, 2024
1 parent 12405c2 commit a4c0c94
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,21 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
else:
sums = np.bincount(group_idx, weights=a, minlength=size).astype(
dtype, copy=False
)
sums = np.bincount(group_idx, weights=a, minlength=size)

with np.errstate(divide="ignore", invalid="ignore"):
ret = sums.astype(dtype, copy=False) / counts
ret = sums / counts
if not np.isnan(fill_value):
ret[counts == 0] = fill_value
return ret
return ret.astype(dtype, copy=False)


def _sum_of_squres(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
ret = np.bincount(group_idx, weights=a * a, minlength=size)
if fill_value != 0:
counts = np.bincount(group_idx, minlength=size)
ret[counts == 0] = fill_value
return ret
return ret.astype(dtype, copy=False)


def _var(
Expand All @@ -176,7 +174,7 @@ def _var(
counts = np.bincount(group_idx, minlength=size)
sums = np.bincount(group_idx, weights=a, minlength=size)
with np.errstate(divide="ignore", invalid="ignore"):
means = sums.astype(dtype, copy=False) / counts
means = sums / counts
counts = np.where(counts > ddof, counts - ddof, 0)
ret = (
np.bincount(group_idx, (a - means[group_idx]) ** 2, minlength=size) / counts
Expand All @@ -185,7 +183,7 @@ def _var(
ret = np.sqrt(ret) # this is now std not var
if not np.isnan(fill_value):
ret[counts == 0] = fill_value
return ret
return ret.astype(dtype, copy=False)


def _std(group_idx, a, size, fill_value, dtype=np.dtype(np.float64), ddof=0):
Expand Down

0 comments on commit a4c0c94

Please sign in to comment.