Amortized point estimators #121

Open
marvinschmitt opened this issue Feb 7, 2024 · 32 comments
Comments

@marvinschmitt
Member

The paper "Likelihood-free parameter estimation with neural Bayes estimators" (Sainsbury-Dale, Zammit-Mangion, & Huser, 2023) enables neural amortized point estimation, which is generally faster than fully Bayesian neural inference with conditional generative models (the current standard in BayesFlow).

Implementing amortized point estimators in BayesFlow would help with fast prototyping and sanity checks (and surely other tasks, too!).

Some BayesFlow pointers/proposals:

  • Implement the point estimation features in a class AmortizedPointEstimator in bayesflow.amortizers
  • Implement appropriate loss functions in bayesflow.losses (depending on the target quantity)
  • Implement a suitable inference network in bayesflow.inference_networks which takes the summary network output (e.g., DeepSet outputs) and regresses to the point estimate
  • Summary networks are already implemented in bayesflow.summary_networks and probably work out-of-the-box
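To make the idea concrete, here is a minimal sketch of amortized point estimation that does not use BayesFlow at all. It assumes a toy conjugate model (theta ~ N(0, 1), x_i | theta ~ N(theta, 1)) and replaces the neural network with a linear map fitted by least squares on simulated pairs; all function names are hypothetical. Under squared-error loss, the fitted estimator should approach the analytic posterior mean, whose slope in the sample mean is n / (n + 1).

```python
import numpy as np

def simulate(num_sims, n_obs, rng):
    """Draw (theta, x) pairs: theta ~ N(0, 1), x_i | theta ~ N(theta, 1)."""
    theta = rng.normal(0.0, 1.0, size=num_sims)
    x = rng.normal(theta[:, None], 1.0, size=(num_sims, n_obs))
    return theta, x

def fit_linear_point_estimator(theta, x):
    """Least-squares fit of theta on [1, mean(x)].

    This minimizes the empirical squared-error risk over linear functions
    of the sample mean (a stand-in for training a neural network).
    """
    features = np.column_stack([np.ones(len(theta)), x.mean(axis=1)])
    coef, *_ = np.linalg.lstsq(features, theta, rcond=None)
    return coef  # (intercept, slope)

rng = np.random.default_rng(0)
n_obs = 10
theta, x = simulate(50_000, n_obs, rng)
intercept, slope = fit_linear_point_estimator(theta, x)

# Analytic posterior mean for this model: n * mean(x) / (n + 1).
analytic_slope = n_obs / (n_obs + 1)
```

Once fitted, applying the estimator to a new data set is a single cheap evaluation, which is the speed advantage over sampling from a generative network.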

Some links:

@marvinschmitt marvinschmitt added the feature New feature or request label Feb 7, 2024
@Jice-Zeng

Great proposals! I read the paper. But I am not sure if the method can generate posterior samples, which may be helpful to detect multi-modal distributions.

@stefanradev93
Contributor

Yes, the method will not deal with multimodality out of the box. I think its main use case will be as a fast and frugal heuristic to obtain estimates very quickly.

@paul-buerkner
Contributor

paul-buerkner commented Feb 13, 2024 via email

@andrewzm

andrewzm commented Feb 13, 2024 via email

@stefanradev93
Contributor

stefanradev93 commented Feb 13, 2024

I like the idea too! We could use the point summary network or the point estimates themselves as summaries, with an auxiliary, smaller summary stream for capturing additional information relevant for the joint.

@andrewzm

andrewzm commented Feb 14, 2024 via email

@paul-buerkner
Contributor

paul-buerkner commented Feb 14, 2024 via email

@paul-buerkner
Contributor

paul-buerkner commented Feb 14, 2024 via email

@ukoethe

ukoethe commented Feb 14, 2024

I also think we need to do it in two steps:

  • train summary statistics jointly with a normalizing flow as usual
  • train point estimates (with an additional small network) from the frozen summary network output

Thus, only inference will potentially be faster, not training. I suspect that attempts to train summary statistics directly via point estimates might result in poor summaries, because it will be harder for the summary network to capture the uncertainty.

@ukoethe

ukoethe commented Feb 14, 2024

A luxury method to train point estimates would enhance the point estimator network with additional outputs for the (co)variance of the point estimate and the fraction of the full posterior explained by the MAP mode.

These outputs could be trained by distillation from the SBI posterior: (1) Find the MAP mode in the NF posterior output => target value of the MAP point estimate. (2) Fit a Gaussian to the MAP mode (= Laplace approximation) to get the (co)variance and the fraction of probability mass explained by the MAP mode.

The distillation approach has the additional benefit that the point estimator would return the true MAP estimates (subject to the accuracy of the normalizing flow) instead of the ground-truth parameterizations of the training runs.
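The two distillation targets can be illustrated on a one-dimensional grid. The sketch below is only an illustration under assumptions: a two-component Gaussian mixture stands in for the normalizing-flow posterior, the mode is found by grid search, and the curvature is taken by finite differences. In practice the flow's own log-density would be evaluated, and the function names here are hypothetical.

```python
import numpy as np

def laplace_at_map(grid, log_density):
    """Return (MAP location, Laplace variance, mass of the Gaussian fitted at the mode).

    Assumes an evenly spaced grid and a mode that is not on the grid boundary.
    """
    step = grid[1] - grid[0]
    i = int(np.argmax(log_density))
    # Curvature of the log-density at the mode (central finite differences).
    curvature = (log_density[i + 1] - 2.0 * log_density[i] + log_density[i - 1]) / step**2
    variance = -1.0 / curvature
    # A Gaussian bump with peak height p(MAP) integrates to p(MAP) * sqrt(2 * pi * variance).
    mass = np.exp(log_density[i]) * np.sqrt(2.0 * np.pi * variance)
    return grid[i], variance, mass

def normal_pdf(x, mean, sd):
    return np.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * np.sqrt(2.0 * np.pi))

# Stand-in posterior: 70% of the mass in a mode at 0, 30% in a mode at 4.
grid = np.linspace(-3.0, 7.0, 10_001)
density = 0.7 * normal_pdf(grid, 0.0, 0.5) + 0.3 * normal_pdf(grid, 4.0, 0.5)

map_theta, map_var, map_mass = laplace_at_map(grid, np.log(density))
```

For this well-separated mixture the three outputs recover the dominant mode's location, its variance (0.25), and its weight (0.7). A mass fraction well below one would flag multimodality that a bare point estimate hides.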

@stefanradev93
Contributor

stefanradev93 commented Feb 14, 2024

Summarizing the ideas so far:

  1. Train point estimators instead of generative networks - simply attaches a configurable MLP to the summary network and uses some norm (L1, L2, L-Infinity, https://www.tensorflow.org/api_docs/python/tf/norm) as a loss.
     • Advantage: fast and pleasant to iterate over for larger problems.
     • Disadvantage: no full Bayesian inference (by design).
  2. Train point estimators and use pre-trained summary network in tandem with a full generative network - uses a frozen summary network, possibly in combination with another, smaller and trainable summary network, to train a full workflow.
     • Advantage: the first summary network is trained fast and already captures some important information. May be equivalent to good hand-crafted summary statistics.
     • Disadvantage: summary statistics may not be good, speed gains diminished due to two-step approach.
  3. Train the default approach and distil it using a point estimator - trains the full workflow (i.e., summary and generative net) and attaches a configurable MLP to the pretrained summary network, possibly learning a summary of the approximate posterior (e.g., MAP).
     • Advantage: the summary statistics will already be good for learning the full posterior (assuming everything has converged just fine), so they will also be good for the point estimator. Point estimator can then be trained hyperefficiently.
     • Disadvantage: speed gains diminished due to two-step approach, the utility of the point estimates is unclear if full workflow is already good.
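The two-stage options (2 and 3) share one mechanical pattern: a summary stage that is no longer updated, plus a small head trained on its outputs. The sketch below shows that pattern only, under strong simplifications: a fixed feature function stands in for the pretrained summary network, and a linear least-squares fit stands in for the configurable MLP. All names are hypothetical and nothing here is BayesFlow code.

```python
import numpy as np

def frozen_summary(x):
    """Stand-in for a pretrained, frozen summary network: fixed features per data set."""
    return np.column_stack([x.mean(axis=1), x.std(axis=1)])

def add_bias(features):
    return np.column_stack([np.ones(len(features)), features])

def train_head(summaries, targets):
    """Train only the point-estimation head (here: linear least squares)."""
    weights, *_ = np.linalg.lstsq(add_bias(summaries), targets, rcond=None)
    return weights

def estimate(weights, x):
    """Point estimates for new data: frozen summaries followed by the trained head."""
    return add_bias(frozen_summary(x)) @ weights

rng = np.random.default_rng(1)
theta = rng.normal(0.0, 1.0, size=20_000)                # location parameter
x = rng.normal(theta[:, None], 1.0, size=(20_000, 10))   # 10 observations per data set

weights = train_head(frozen_summary(x), theta)           # the summary stage is never updated
rmse = np.sqrt(np.mean((estimate(weights, x) - theta) ** 2))
```

Because only the head is fitted, the second stage is cheap, but the quality of the estimates is bounded by whatever information the frozen summaries retain.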

Crucially, all three are enabled, once the functionality is there. Does anyone want to implement the blueprints Marvin suggested in the first post? Otherwise, I would self-assign in a week or two.

@andrewzm

andrewzm commented Feb 17, 2024 via email

@stefanradev93 stefanradev93 self-assigned this Feb 26, 2024
@stefanradev93
Contributor

Hi all, I have implemented and tested a first working version of an amortized point estimator, which can be found in the Development branch.

An example use case is:

import bayesflow as bf
from bayesflow.benchmarks import Benchmark

# Trivial Gaussian benchmark or a real model
benchmark = Benchmark('gaussian_linear', mode='posterior')

# An easy-to-use MLP with residual connections
inference_net = bf.helper_networks.ConfigurableMLP(input_dim=10, dropout_rate=0.05)

# Can be any norm in [1, 2, np.inf] and can also use a summary network
amortizer = bf.amortizers.AmortizedPointEstimator(inference_net, norm_ord=2)

# Training can happen as usual
trainer = bf.trainers.Trainer(amortizer=amortizer, configurator=benchmark.configurator)
data = benchmark.generative_model(5000)
h = trainer.train_offline(data, epochs=50, batch_size=32)

# Quick point estimates can be obtained by simply calling the .estimate() method
test_data = benchmark.configurator(benchmark.generative_model(100))
estimates = amortizer.estimate(test_data)

Let me know what you think.

@marvinschmitt
Member Author

Great, thanks! Just adding two pointers to Stefan's post:

@paul-buerkner
Contributor

The workflow looks good to me. How would you generalize the interface to different loss functions for different point estimates? The norm_ord argument is perhaps a bit too restrictive an interface for that(?)

@stefanradev93
Contributor

> The workflow looks good to me. How would you generalize the interface to different loss functions for different point estimates? The norm_ord argument is perhaps a bit too restrictive an interface for that(?)

The current AmortizedPointEstimator supports a loss_fun argument which takes precedence over norm_ord when provided and allows for generic loss functions. We may want to allow some string arguments which default to useful estimators that are not captured by the simple norm_ord or provide a more semantic interface (e.g., loss_fun="mean").

@Jice-Zeng

Great job! Does the current workflow support uncertainty quantification for the estimated parameters? Do we need to apply the bootstrap for UQ, as the reference paper did?

@stefanradev93
Contributor

You would currently need bootstrap or Monte Carlo dropout for UQ, unless a custom loss_fun already entails UQ.

@vpratz
Collaborator

vpratz commented Mar 9, 2024

The interface looks good to me as well. @han-ol and I implemented the illustrative example from the paper (see #145) and the results look sensible. Feel free to take a look and to improve the notebook, maybe it can serve as the foundation of a more detailed tutorial notebook on the method.

@Jice-Zeng

> The interface looks good to me as well. @han-ol and I implemented the illustrative example from the paper (see #145) and the results look sensible. Feel free to take a look and to improve the notebook, maybe it can serve as the foundation of a more detailed tutorial notebook on the method.

Could you share the link to this notebook? I would like to look into it and learn from it.
Thanks!

@marvinschmitt
Member Author

> The interface looks good to me as well. @han-ol and I implemented the illustrative example from the paper (see #145) and the results look sensible. Feel free to take a look and to improve the notebook, maybe it can serve as the foundation of a more detailed tutorial notebook on the method.
>
> Could you share the link to this notebook? I would like to look into it and learn from it.
>
> Thanks!

https://github.com/stefanradev93/BayesFlow/blob/Development/examples/Amortized_Point_Estimation.ipynb

@han-ol
Contributor

han-ol commented Mar 11, 2024

Hi all, the current interface supports loss functions of the form $\mathcal L=f(\hat \theta - \theta)$. This covers many of the interesting loss functions. However, my understanding is

cannot abide this form.

To support losses that are not simply functions of the difference $\hat \theta - \theta$, I suggest we change the current signature of the custom loss_fun from loss_fun(net_output - parameters) to loss_fun(net_output, parameters).

The implementation replaces the tf.norm with a wrapper that takes two arguments, so the convenient functionality of just choosing norm_ord stays untouched and the illustrative notebook runs just as before. Not sure if we want this wrapper and if it is well placed in losses.py.
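A rough sketch of what such a two-argument interface could look like is given below. It is not the code in losses.py: NumPy's `np.linalg.norm` stands in for `tf.norm`, and the names `norm_diff`, `resolve_loss`, and `relative_error` are hypothetical. The relative-error loss is only an illustration of a loss that cannot be written as a function of the difference alone.

```python
import numpy as np

def norm_diff(net_output, parameters, norm_ord=2, axis=-1):
    """Two-argument wrapper: batch mean of the vector norm of (net_output - parameters)."""
    return np.mean(np.linalg.norm(net_output - parameters, ord=norm_ord, axis=axis))

def resolve_loss(loss_fun=None, norm_ord=2):
    """A custom two-argument loss_fun takes precedence; otherwise fall back to the norm."""
    if loss_fun is not None:
        return loss_fun
    return lambda net_output, parameters: norm_diff(net_output, parameters, norm_ord=norm_ord)

def relative_error(net_output, parameters):
    """Depends on the parameters themselves, not only on the difference."""
    return np.mean(np.abs(net_output - parameters) / (np.abs(parameters) + 1e-8))

estimates = np.array([[1.0, 2.0], [3.0, 5.0]])
targets = np.array([[1.0, 1.0], [6.0, 1.0]])

l2 = resolve_loss(norm_ord=2)(estimates, targets)       # row norms 1 and 5
l1 = resolve_loss(norm_ord=1)(estimates, targets)       # row norms 1 and 7
custom = resolve_loss(loss_fun=relative_error)(estimates, targets)
```

With this shape, choosing norm_ord keeps working exactly as before, while any callable of (net_output, parameters) can be swapped in.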

Let me know what you think!

@andrewzm

Sorry for the delay -- this is really nice, thanks @stefanradev93 for the implementation and @vpratz @han-ol for the easy-to-follow notebook, very cool!

I think this is already great as is, from a user point of view I agree with @han-ol that it would be helpful to have a generic loss interface and possibly also some "template" losses one could use. One possibility is to provide some template losses (e.g., normed_difference, quantile_loss, etc.) that take in net_output and parameters and then have the additional arguments specific to that loss (e.g., power of the normed difference, quantile to target, etc.). Users can of course specify their own losses but the templates would cover 95% of use cases.

The last thing that would be really handy for a user is a function like amortizer.bootstrap() that does the bootstrap for you. The easiest would be a parametric bootstrap, where amortizer.estimate() is run on the observed data to give the estimate, the simulator is run at that estimate to give an N-sized bootstrap sample (assuming the simulator is available and that the training data are not fixed beforehand), and then amortizer.estimate() is run again on that sample. This would be useful for all loss functions... e.g., one might want to get a bootstrap sample of 95% credible intervals.
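The parametric bootstrap described above can be sketched in a few lines. This is a sketch under assumptions, not an existing BayesFlow method: `parametric_bootstrap` is a hypothetical helper, the sample mean stands in for amortizer.estimate(), and a Gaussian location model stands in for the simulator.

```python
import numpy as np

def parametric_bootstrap(estimator, simulator, observed, num_boot, rng):
    """Estimate from the observed data, re-simulate at that estimate, and re-estimate."""
    theta_hat = estimator(observed)
    boot = np.array([estimator(simulator(theta_hat, rng)) for _ in range(num_boot)])
    return theta_hat, boot

# Hypothetical stand-ins for the simulator and the trained point estimator.
def simulator(theta, rng, n_obs=50):
    return rng.normal(theta, 1.0, size=n_obs)

def estimator(x):
    return x.mean()

rng = np.random.default_rng(2)
observed = simulator(1.5, rng)

theta_hat, boot = parametric_bootstrap(estimator, simulator, observed, num_boot=2000, rng=rng)
boot_sd = boot.std(ddof=1)                               # bootstrap standard error
ci_low, ci_high = np.quantile(boot, [0.025, 0.975])      # percentile interval
```

Because each bootstrap replicate is one simulator call plus one forward pass, the whole procedure stays cheap once the estimator is trained. The same loop would apply unchanged if the estimator returned quantiles instead of a single point.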

Anyway, really nice stuff, very happy to see it implemented in BayesFlow!

@marvinschmitt
Member Author

Agreed, great job everyone!

@andrewzm

> and possibly also some "template" losses one could use. [...] the templates would cover 95% of use cases.

Based on your expertise and experience with amortized point estimators, would you mind compiling a list of template losses that might cover the vast majority of users' needs?

@andrewzm

I would say some "experience" not "expert", but I'm fairly confident that the vast majority of cases will be covered by one of the following losses:

  1. Normed difference loss (with a user-chosen value for the exponent, default to 2).
  2. Quantile loss (where the user has to specify the quantile $q$) -- we tested this out in a neural-network setting in Sainsbury-Dale et al. (2024) and it works very well.

Other useful losses include:

  1. Weighted squared-error loss: $L(\theta, \hat\theta(Z); w(\cdot)) = w(\theta)(\hat\theta(Z) - \theta)^2$ where $w(\theta)$ is a user-defined weighting function of $\theta$, e.g., $w(\theta) = \frac{1}{\theta^2}$ (Robert, 2007, Corollary 2.5.2).
  2. The loss function $L(\theta, \hat\theta(Z); \tau) = ( \hat\theta(Z) - \theta^\tau)^2$. Such a loss function would be useful for deriving higher order moments of the posterior distribution. For example, setting $\tau = 2$, combining this with an estimator for the posterior mean, one can then compute the posterior variance
    $$\textrm{var}(\theta \mid Z) = \textrm{E}(\theta^2 \mid Z) - \textrm{E}(\theta \mid Z)^2.$$
    This requires training two networks though, one for the posterior expectation of the parameter and one for the posterior expectation of the square of the parameter... this can lead to negative values when estimating the variance in this way. If the loss can be made to ingest outputs from another neural network, then one could use the more "robust" formulation of Fan & Yao (1998) who propose the loss:
    $$L(\theta, \hat\theta(Z)) = ((\theta - \textrm{E}(\theta \mid Z))^2 - \hat\theta(Z))^2,$$
    for which the Bayes estimator is $\hat\theta(Z) = \textrm{var}(\theta \mid Z)$ (so positivity can be easily enforced through the neural network architecture).
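The two routes to the posterior variance can be compared numerically on a toy problem. The sketch below assumes a conjugate Gaussian model (theta ~ N(0, 1), Z_i | theta ~ N(theta, 1)), for which the posterior variance is 1 / (n + 1), and uses least-squares fits on hand-picked features as stand-ins for trained networks. It illustrates the idea only and is not a proposed implementation.

```python
import numpy as np

def least_squares(design, targets):
    """Minimize the squared-error loss over linear functions of the design columns."""
    coef, *_ = np.linalg.lstsq(design, targets, rcond=None)
    return coef

rng = np.random.default_rng(3)
n_sims, n_obs = 100_000, 4
theta = rng.normal(0.0, 1.0, size=n_sims)
z_bar = rng.normal(theta[:, None], 1.0, size=(n_sims, n_obs)).mean(axis=1)

# Estimator of the posterior mean E(theta | Z), linear in the sample mean.
design = np.column_stack([np.ones(n_sims), z_bar])
post_mean = design @ least_squares(design, theta)

# Residual-based route: target the squared residuals (theta - E(theta | Z))^2.
# A constant estimator suffices here because the posterior variance does not depend on Z.
var_residual = np.mean((theta - post_mean) ** 2)

# Moment-based route: a second estimator for E(theta^2 | Z), then subtract.
# E(theta^2 | Z) is quadratic in the sample mean, so a quadratic feature is added.
design_sq = np.column_stack([design, z_bar**2])
second_moment = design_sq @ least_squares(design_sq, theta**2)
var_moment = second_moment - post_mean**2   # per data set; not guaranteed non-negative in general

analytic_var = 1.0 / (n_obs + 1)
```

In this easy example both routes agree with the analytic value. The difference is structural: the residual-based target is itself non-negative, so an estimator with a positive output layer cannot produce negative variances, whereas the moment-based difference of two separately trained estimators can.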

There are other losses one might consider, e.g., Stein's loss for covariance matrices which might be useful (one would need to go through the Cholesky factor to ensure positive definiteness of $\hat\Sigma$ in the NN architecture) and Huber losses, although I don't know without further reading what Bayes estimators these would lead to. For starters I would put in the normed difference, quantile loss and the weighted squared-error loss.
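As a starting point for discussion, the three suggested template losses could look roughly like this. The names, argument defaults, and the element-wise definition of the normed difference are assumptions made for illustration, written with NumPy rather than the backend tensors BayesFlow would use.

```python
import numpy as np

def normed_difference(net_output, parameters, power=2):
    """Mean of |estimate - theta|^power (power=2 targets the posterior mean)."""
    return np.mean(np.abs(net_output - parameters) ** power)

def quantile_loss(net_output, parameters, q=0.5):
    """Pinball loss; its Bayes estimator is the posterior q-quantile."""
    diff = parameters - net_output
    return np.mean(np.maximum(q * diff, (q - 1.0) * diff))

def weighted_squared_error(net_output, parameters, weight_fun=lambda theta: 1.0 / theta**2):
    """w(theta) * (estimate - theta)^2 with a user-supplied weighting function."""
    return np.mean(weight_fun(parameters) * (net_output - parameters) ** 2)

# Sanity check: the constant that minimizes the pinball loss over a set of
# samples should sit at their empirical q-quantile.
rng = np.random.default_rng(4)
samples = rng.normal(0.0, 1.0, size=20_000)
candidates = np.linspace(-3.0, 3.0, 601)
losses = [quantile_loss(c, samples, q=0.9) for c in candidates]
best = candidates[int(np.argmin(losses))]
empirical_q90 = np.quantile(samples, 0.9)
```

Each template takes (net_output, parameters) first and its loss-specific arguments afterwards, so a user could fix those arguments with functools.partial and pass the result as a custom loss.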

Fan, J., & Yao, Q. (1998). Efficient estimation of conditional variance functions in stochastic regression. Biometrika, 85(3), 645-660.
Robert, C. P. (2007). The Bayesian Choice. New York: Springer.
Sainsbury-Dale, M., Richards, J., Zammit-Mangion, A., & Huser, R. (2023). Neural Bayes estimators for irregular spatial data using graph neural networks. arXiv preprint arXiv:2310.02600.

@marvinschmitt
Member Author

Since this issue contains past discussions and involved people are watching it:

How do we go about implementing amortized point estimators in the new BayesFlow 2.0 release?

@paul-buerkner
Contributor

Naming-wise, I would suggest we implement a PointApproximator class in the approximator module. I don't know how easily portable the existing code is, but perhaps it is not too difficult? Deferring to @stefanradev93 and @LarsKue to judge this.

@han-ol
Contributor

han-ol commented Oct 1, 2024

While the most common usage of this would be to estimate points, it could also be worthwhile to (A) choose a more general name or (B) split it out in different classes. From a technical point of view, this issue relaxes the definition of Approximator for all situations where we don't represent a whole distribution with a generative network, but not all of these cases are strictly point estimators.

I couldn't think of a nice general name that is not already taken (Approximator).

So what do you think about having a

  • Approximator (could also be called DistributionApproximator for specificity, or we keep it short as is),
  • PointApproximator (feed forward net with 1 output node),
  • IntervalApproximator (feed forward net with 2 output nodes), and
  • QuantileApproximator (feed forward net with any number of output nodes; quantile levels are selected by the loss function or with a convenient constructor)

?

I would expect that when we eventually get to the diagnostics API, it will help if the different approximator outputs correspond to well-specified classes. (I have an implementation sketch of a calibration diagnostic for the QuantileApproximator, for example.)
Also parametric bootstrap* could be a method applicable to PointApproximator but not to IntervalApproximator, etc.

(*Point estimate each observed dataset and resimulate a few samples corresponding to its point estimate, then apply the point estimator to get some UQ for the points. This captures stochasticity in the likelihood that can lead to point estimation variance.)

@andrewzm

andrewzm commented Oct 2, 2024 via email

@paul-buerkner
Contributor

In the dev version, we call the main approximator ContinuousApproximator because we approximate a full continuous distribution. For a user interface, it would be better, I think, to have one PointEstimator class to handle all the point estimates, i.e. measures of central tendency such as means or medians, measures of variation such as SD or MAD, and quantiles. Often enough I want to estimate several of them at the same time, say a mean and a bunch of quantiles. If I had different approximators for these different kinds of point estimators I wouldn't be able to learn them efficiently together.

@han-ol
Contributor

han-ol commented Oct 2, 2024

We could, in principle, also have an ApproximatorCollection class that bundles several of the feed forward approximators together and owns shared weights for them. You could pass a list of the point estimates you want, train them together, and the code can internally distinguish the different kinds of point estimates for bootstrap and plotting.
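One way to picture such a collection is as a set of named heads that each own a slice of one shared network output and a matching loss. The sketch below is purely hypothetical (the class names `Head` and `HeadCollection` are made up, and it handles a single scalar parameter), but it shows how a mean and a pair of quantiles could be trained jointly and still be told apart afterwards.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np

@dataclass
class Head:
    name: str
    columns: slice                                   # network outputs owned by this head
    loss: Callable[[np.ndarray, np.ndarray], float]

def squared_error(output, theta):
    return float(np.mean((output - theta[:, None]) ** 2))

def make_quantile_loss(levels):
    levels = np.asarray(levels)
    def loss(output, theta):
        diff = theta[:, None] - output               # shape (batch, num_levels)
        return float(np.mean(np.maximum(levels * diff, (levels - 1.0) * diff)))
    return loss

class HeadCollection:
    """Hypothetical bundle of point-estimate heads sharing one network output."""

    def __init__(self, heads):
        self.heads = heads

    def total_loss(self, net_output, theta):
        # Joint training objective: sum of the per-head losses.
        return sum(h.loss(net_output[:, h.columns], theta) for h in self.heads)

    def split(self, net_output):
        # Named outputs, so that diagnostics or bootstrap code can treat each kind differently.
        return {h.name: net_output[:, h.columns] for h in self.heads}

heads = HeadCollection([
    Head("mean", slice(0, 1), squared_error),
    Head("quantiles", slice(1, 3), make_quantile_loss([0.05, 0.95])),
])

net_output = np.array([[0.0, -1.0, 1.0], [2.0, 1.0, 3.0]])   # fake outputs for two data sets
theta = np.array([0.5, 2.0])
parts = heads.split(net_output)
loss = heads.total_loss(net_output, theta)
```

The named split is what would let a bootstrap or plotting routine know which outputs are points and which are interval bounds, without the user writing configuration code.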

@andrewzm right, interesting point regarding bootstrap on intervals. I guess it still needs custom code to deal with the differing shapes and select which parameters to plug in the simulator, so for that to be handled without the user writing some configuration code for a hook in a parametric_bootstrap function, we would need well specified IntervalApproximator, etc... classes.
For interval approximation specifically, would you plug both of the interval boundaries in the simulator or sample from the interval? Is there an established way of doing this?

Stepping back: I don't know whether this ApproximatorCollection thing would be feature creep and we thus should focus on a little less tidy, but highly flexible PointApproximator class that the user customizes themselves. I just want to raise that, I think, if we have a most general PointApproximator class the implementation of things like automatic diagnostic plotting and bootstrap gets a little messy and will probably require "expert" knowledge by the user.

@andrewzm

andrewzm commented Oct 3, 2024

@han-ol Assuming we're talking about a parametric bootstrap setting, I think what would be useful to the user is getting a feel for the (frequentist) validity of these quantile intervals around a notional parameter value related to the available data, which could be the posterior mean or median. The bootstrapped intervals could be used to quantify the expected frequentist coverage probability, by computing the proportion of times the "notional" value falls into the bootstrapped intervals. This is similar to the procedure Hermans et al. describe in "A Trust Crisis In Simulation-Based Inference? Your Posterior Approximations Can Be Unfaithful" (https://arxiv.org/pdf/2110.06581), Section 2.2, but for a notional value of theta (e.g., the posterior mean) instead of averaging over the whole parameter space. From a software point of view you're right that things are a bit different, but not much: with point estimators the notional value will be the estimate itself, while in the interval case it would need to be specified (it could also default to the parameter value in the middle of the interval, which is crude but probably OK as a default).
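A minimal sketch of this coverage check, under assumptions: the function name `bootstrap_coverage` is hypothetical, and an analytic interval for a Gaussian location model with known unit variance (flat prior) stands in for a trained interval estimator.

```python
import numpy as np

def bootstrap_coverage(interval_estimator, simulator, notional_theta, num_boot, rng):
    """Fraction of intervals, estimated from data simulated at the notional value, that contain it."""
    hits = 0
    for _ in range(num_boot):
        low, high = interval_estimator(simulator(notional_theta, rng))
        hits += int(low <= notional_theta <= high)
    return hits / num_boot

# Hypothetical stand-ins for the simulator and the amortized interval estimator.
def simulator(theta, rng, n_obs=25):
    return rng.normal(theta, 1.0, size=n_obs)

def interval_estimator(x):
    # Central 95% interval for the location under a flat prior and unit variance.
    half_width = 1.96 / np.sqrt(len(x))
    return x.mean() - half_width, x.mean() + half_width

rng = np.random.default_rng(5)
coverage = bootstrap_coverage(
    interval_estimator, simulator, notional_theta=0.3, num_boot=4000, rng=rng
)
```

Here the estimated coverage should land close to the nominal 95%. For a trained interval estimator, a clear shortfall at the chosen notional value would indicate overconfident intervals in that region of parameter space.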


8 participants