Skip to content

Commit

Permalink
add reparameterization trick to loss
Browse files Browse the repository at this point in the history
  • Loading branch information
gyoge0 committed Jul 26, 2024
1 parent c1d01eb commit 3459e1f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sparkle_stats/training/loss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch


def likelihood_loss(predictions, labels, epsilon=1e-5):
def likelihood_loss(predictions, labels, backing_loss_fn, epsilon=1e-5):
# labels: B, 2*classes
# predictions: B, classes
assert (
Expand Down Expand Up @@ -32,5 +32,6 @@ def likelihood_loss(predictions, labels, epsilon=1e-5):
means = predictions[:, :, 0]
variances = torch.abs(predictions[:, :, 1]) + epsilon
dist = torch.distributions.Normal(means, variances)
probs = dist.log_prob(labels)
return probs.sum(axis=1).mean()
# reparameterization trick
sample = dist.rsample()
return backing_loss_fn(sample, labels)

0 comments on commit 3459e1f

Please sign in to comment.