
Normalization of Weights within SMCFilter class #3382

Open
MarkusKramer1 opened this issue Jul 12, 2024 · 2 comments
Labels: enhancement, help wanted

Comments

MarkusKramer1 commented Jul 12, 2024

The weights reported by the get_empirical method of the SMCFilter class will usually not sum to one. I think it would be nice to have the option to directly get the normalized weights (maybe that should also be the default) to comply with the literature (see for example An Introduction to Sequential Monte Carlo, page 130).
The weights are actually normalized in this line. However, the result is saved in a local variable instead of being written back to the state's weights variable. It seems like the values of the log_weights variable are only rescaled to lie between 0 and 1 using this command.
What is the reason for not normalizing the weights variable directly?
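
For concreteness, here is a tiny standalone illustration of the normalization I mean, using made-up log weights in place of state._log_weights:

import torch

# Made-up unnormalized particle log weights.
log_weights = torch.tensor([-3.1, -2.4, -5.0])

# The exponentiated weights do not sum to one in general.
print(log_weights.exp().sum())  # ~0.14, not 1

# Normalizing in log space via logsumexp yields a proper probability vector.
normalized = log_weights - log_weights.logsumexp(-1)
print(normalized.exp().sum())  # tensor(1.)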

The easiest fix for the issue would be to add an optional argument to the get_empirical function that, if true, normalizes the weights, e.g.:

def get_empirical(self, normalize_weights=True):
    """
    :param bool normalize_weights: If True, normalize the log weights
        before creating the empirical distribution.
    :returns: a marginal distribution over all state tensors.
    :rtype: a dictionary with keys which are latent variables and values
        which are :class:`~pyro.distributions.Empirical` objects.
    """
    if normalize_weights:
        # Normalize the log weights so the corresponding weights sum to one.
        log_weights = self.state._log_weights - self.state._log_weights.logsumexp(-1)
    else:
        log_weights = self.state._log_weights

    return {
        key: dist.Empirical(value, log_weights)
        for key, value in self.state.items()
    }
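
Usage would then look something like this (model, guide, initial, and data are placeholders standing in for a real SMC setup):

smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)
smc.init(initial=initial)  # arguments are forwarded to the model/guide
for y in data:
    smc.step(y)

# Marginals whose weights now sum to one.
marginals = smc.get_empirical(normalize_weights=True)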
fritzo (Member) commented Jul 14, 2024

Our general stance, as in the statistics field broadly, is against unnecessary normalization. The logsumexp of the log_weights is a meaningful quantity and can be used for a variety of tasks (the first is sketched below):

  • as a loss to backprop through
  • as an expert log weight in a mixture of experts model, e.g. an SMCFilter and a non-normalized Gaussian (as in our pyro.ops.Gaussian library or Funsor)
  • as a measure of goodness of fit
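
For instance, a minimal sketch of the first point, with made-up log weights:

import torch

# Made-up particle log weights from a filter step.
log_weights = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)

# The logsumexp of the unnormalized log weights estimates the log evidence;
# normalizing the weights in place would destroy this quantity.
log_evidence = torch.logsumexp(log_weights, dim=-1)

# It can serve directly as a differentiable loss.
loss = -log_evidence
loss.backward()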

fritzo closed this as completed Jul 14, 2024
fritzo reopened this Jul 15, 2024
fritzo (Member) commented Jul 15, 2024

Sorry, I misread. I do think it's fine to implement a .normalized_weights() method or property, as long as we preserve the original unnormalized weights.
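
Something like this minimal sketch (name and placement are just suggestions):

@property
def normalized_log_weights(self):
    # Return normalized log weights without mutating state._log_weights,
    # so the original unnormalized weights are preserved.
    log_weights = self.state._log_weights
    return log_weights - log_weights.logsumexp(-1)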

fritzo added the enhancement and help wanted labels Jul 15, 2024