Comparing different nested samplers with jaxns #58
-
Hi @Joshuaalbert, connected with the question above, I have another technical one. As I mentioned, I would like to run jaxns with some cosmology data. The likelihood evaluation requires running an external simulator, which has to be interfaced. The thing is that it is a C++ code (OpenMP-parallelized, since it is computationally expensive), but it can be run through Python. The question is whether it would be possible to use it in jaxns. Thanks!
-
@rruizdeaustri on the first question: yes, those are stored in the results:

```python
# Create the nested sampler class. In this case without any tuning.
ns = NestedSampler(log_likelihood, prior_chain)
results = ns(random.PRNGKey(42))

# weights
weights = results.log_dp_mean

# samples (weighted by weights)
samples = results.samples

# to generate effective-sample-size equally weighted samples with replacement
from jaxns import resample
uniformly_weighted_samples = resample(random.PRNGKey(43), samples, weights, S=int(results.ESS), replace=True)
```
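If the goal is to overplot these against chains from other nested samplers, one option is simply to dump the weights and sample coordinates to a text file. The following is only a sketch, assuming `samples` is a dict of parameter arrays and `weights` holds log-weights as in the snippet above; the file layout (weight column first) is just one convention.

```python
import numpy as np

# Hypothetical export of the weighted chain for comparison with other samplers.
log_w = np.asarray(weights)
w = np.exp(log_w - log_w.max())  # exponentiate relative to the max to avoid underflow
# One column block per parameter, flattened to 2D; dict keys sorted for a stable order.
cols = [np.asarray(samples[name]).reshape(len(w), -1) for name in sorted(samples)]
chain = np.column_stack([w] + cols)  # first column: weight, remaining columns: parameters
np.savetxt("jaxns_chain.txt", chain)
```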
On the second question: this is a really great and important question. Jaxns is built on JAX, and from that we get good performance and compatibility with a number of machine learning projects which are also based on JAX. However, there will always be situations where likelihoods are computed through expensive means that require special software/hardware. There are two options here.

Option 1) is to use `from jax import disable_jit`:
```python
from jax import disable_jit

with disable_jit():
    ns = NestedSampler(log_likelihood, prior_chain)
    results = ns(random.PRNGKey(42))
```
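On the C++ simulator specifically: with `disable_jit` the likelihood executes as ordinary Python, so in principle the external call can sit directly inside it. The following is only a minimal sketch of what that could look like; `run_simulator`, the data file, the Gaussian noise model, and the likelihood signature are all assumptions, not jaxns API.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical wrapper around the external C++/OpenMP code; the real call depends
# on how it is exposed to Python (pybind11, ctypes, subprocess, ...).
def run_simulator(theta):
    raise NotImplementedError("replace with the actual simulator call")

data = np.loadtxt("data.txt")  # assumed observed data vector
sigma = 1.0                    # assumed Gaussian noise level

# Assumes a single parameter vector named `theta` matching a prior of that name.
def log_likelihood(theta, **unused):
    # Outside of jit, `theta` is a concrete array, so it can be converted to
    # NumPy and handed to non-JAX code.
    model = run_simulator(np.asarray(theta))
    return -0.5 * jnp.sum(((data - model) / sigma) ** 2)
```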
Option 2) is to take a data-driven approach and learn the likelihood from a data set of simulations using a universal function approximator, e.g. a neural network or a Gaussian process, and then use this learned likelihood as a surrogate likelihood built from JAX operations (a minimal sketch of this idea follows below). This is not always satisfactory, but it can be really interesting for some applications. There is an exoplanet team who is also potentially doing this with Jaxns. If you are interested in this approach I'd love to be involved, because it's quite novel, so let me know.

A potential solution I have been thinking about is to abstract out the backend so that JAX can be replaced with numpy, and then everything would just work at Cython speed. This problem of non-JAX likelihoods is definitely coming up a lot in physics, biology, and other sciences.
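To make Option 2 concrete, here is a minimal, self-contained sketch (not part of jaxns) of fitting a small MLP in pure JAX to a precomputed table of (parameter, log-likelihood) pairs; the training data below is a stand-in for real simulator evaluations, and every name in it is hypothetical.

```python
import jax
import jax.numpy as jnp

# Stand-in training set: in practice these would be parameter vectors and the
# log-likelihoods evaluated once with the expensive simulator.
key = jax.random.PRNGKey(0)
thetas = jax.random.uniform(key, (1000, 2))
logls = -0.5 * jnp.sum((thetas - 0.5) ** 2, axis=1)

def init_params(key, sizes=(2, 64, 64, 1)):
    # One (weight, bias) pair per layer, with simple 1/sqrt(fan_in) scaling.
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return (x @ W + b).squeeze(-1)

def loss(params):
    # Mean squared error against the tabulated log-likelihoods.
    return jnp.mean((mlp(params, thetas) - logls) ** 2)

@jax.jit
def step(params):
    # Plain gradient descent; a real application would likely use optax etc.
    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)

params = init_params(jax.random.PRNGKey(1))
for _ in range(2000):
    params = step(params)

# The trained network is an ordinary JAX function, so it can stand in as a
# surrogate log-likelihood (signature adapted to the prior names used in jaxns).
def surrogate_log_likelihood(theta, **unused):
    return mlp(params, jnp.atleast_2d(theta))[0]
```

The trained `surrogate_log_likelihood` could then be handed to `NestedSampler` in place of the expensive likelihood, at the cost of whatever approximation error remains in the fit.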
-
You might be interested in #59
-
@rruizdeaustri release 1.2 is being worked on now, which should make it possible to use non-JAX likelihoods.
-
I'm trying to compare the results obtained from different Bayesian samplers based on nested sampling, in the framework of physics (i.e. cosmology) and some, say, nasty analytical likelihoods.
Because of that, I would like to somehow store the output chains (i.e. the weights and coordinates of the model points) so I can plot them together with the output of the other samplers. Is that possible?