Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Aug 9, 2023
1 parent b3cf5b6 commit 28f5c25
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 77 deletions.
3 changes: 1 addition & 2 deletions examples/imaging_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,12 @@ end

function instrument(θ, metadata)
(; lgamp, gphase) = θ
(; gcache, gcachep,) = metadata
(; gcache, gcachep) = metadata
## Now form our instrument model
gvis = exp.(lgamp)
gphase = exp.(1im.*gphase)
jgamp = jonesStokes(gvis, gcache)
jgphase = jonesStokes(gphase, gcachep)
# jgphase0= jonesStokes(gphase0, gcachep0)
return JonesModel(jgamp*jgphase)
end

Expand Down
1 change: 0 additions & 1 deletion ext/ComradePyehtimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ function getcoherency(obs)
"Do not use\n obs.switch_polrep(\"circ\")\nsince missing hands will not be handled correctly."
)

obs = obs.switch_polrep("circ")

# get (u,v) coordinates
u = pyconvert(Vector, obs.data["u"])
Expand Down
154 changes: 80 additions & 74 deletions playground/imaging_vis_standardize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
# instrument effects, such as time variable gains.

# To get started we load Comrade.
using Comrade


using Pkg #hide
Pkg.activate(joinpath(@__DIR__, "..", "examples")) #hide
using Comrade
Pkg.activate(joinpath(dirname(pathof(Comrade)), "..", "examples")) #hide
#-

using Pyehtim
using LinearAlgebra

# For reproducibility we use a stable random number genreator
using StableRNGs
rng = StableRNG(124)
rng = StableRNG(42)



Expand All @@ -34,7 +35,7 @@ obs = ehtim.obsdata.load_uvfits(joinpath(dirname(pathof(Comrade)), "..", "exampl
# - Scan average the data since the data have been preprocessed so that the gain phases
# coherent.
# - Add 1% systematic noise to deal with calibration issues that cause 1% non-closing errors.
obs = scan_average(obs.add_fractional_noise(0.02))
obs = scan_average(obs.add_fractional_noise(0.01))

# Now we extract our complex visibilities.
dvis = extract_table(obs, ComplexVisibilities())
Expand All @@ -48,11 +49,11 @@ dvis = extract_table(obs, ComplexVisibilities())


function sky(θ, metadata)
(;fg, c, λ, σ, ν) = θ
(;trf, meanpr, ftot, grid, cache) = metadata
(;fg, c, σimg, λ, ν) = θ
(;ftot, trf, K, meanpr, grid, cache) = metadata
## Construct the image model we fix the flux to 0.6 Jy in this case
cp = trf(c, meanpr, σ, λ, ν)
rast = (ftot*(1-fg))*(to_simplex(CenteredLR(), cp))
cp = trf(c, meanpr, σimg, λ, ν)
rast = (ftot*(1-fg))*K(to_simplex(CenteredLR(), cp))
img = IntensityMap(rast, grid)
m = ContinuousImage(img, cache)
g = modify(Gaussian(), Stretch(μas2rad(250.0), μas2rad(250.0)), Renormalize(ftot*fg))
Expand All @@ -71,9 +72,9 @@ function instrument(θ, metadata)
(; gcache, gcachep) = metadata
## Now form our instrument model
gvis = exp.(lgamp)
gphase = exp.(1im.*gphase)
gphase = 1im.*gphase
jgamp = jonesStokes(gvis, gcache)
jgphase = jonesStokes(gphase, gcachep)
jgphase = jonesStokes(exp, gphase, gcachep)
return JonesModel(jgamp*jgphase)
end

Expand All @@ -96,7 +97,7 @@ end
# the EHT is not very sensitive to a larger field of view. Typically 60-80 μas is enough to
# describe the compact flux of M87. Given this, we only need to use a small number of pixels
# to describe our image.
npix = 48
npix = 32
fovx = μas2rad(150.0)
fovy = μas2rad(150.0)

Expand All @@ -117,12 +118,25 @@ cache = create_cache(NFFTAlg(dvis), buffer, DeltaPulse())
# the timescale we expect them to vary. For the phases we use a station specific scheme where
# we set AA to be fixed to unit gain because it will function as a reference station.
gcache = jonescache(dvis, ScanSeg())
gcachep = jonescache(dvis, ScanSeg(); autoref=RandomReference(FixedSeg(complex(1.0))))
gcachep = jonescache(dvis, ScanSeg{true}(); autoref=SEFDReference((complex(0.0))))

using VLBIImagePriors
# Now we can form our metadata we need to fully define our model. First we
# also construct a `K` matrix or kernel that automatically centers the image.
K = CenterImage(grid)
# Now we need to specify our image prior. For this work we will use a Gaussian Markov
# Random field prior
# Since we are using a Gaussian Markov random field prior we need to first specify our `mean`
# image. This behaves somewhat similary to a entropy regularizer in that it will
# start with an initial guess for the image structure. For this tutorial we will use a
# a symmetric Gaussian with a FWHM of 60 μas
fwhmfac = 2*sqrt(2*log(2))
mpr = modify(Gaussian(), Stretch(μas2rad(50.0)./fwhmfac))
imgpr = intensitymap(mpr, grid)

# Now since we are actually modeling our image on the simplex we need to ensure that
# our mean image has unit flux
imgpr ./= flux(imgpr)
# and since our prior is not on the simplex we need to convert it to `unconstrained or real space`.
meanpr = to_real(CenteredLR(), Comrade.baseimage(imgpr))

# We will also fix the total flux to be the observed value 1.1. This is because
# total flux is degenerate with a global shift in the gain amplitudes making the problem
# degenerate. To fix this we use the observed total flux as our value.
Expand All @@ -143,7 +157,7 @@ distamp = station_tuple(dvis, Normal(0.0, 0.1); LM = Normal(1.0))
# This means that rather than the parameters
# being directly the gains, we fit the first gain for each site, and then
# the other parameters are the segmented gains compared to the previous time. To model this
#, we break the gain phase prior into two parts. The first is the prior
# we break the gain phase prior into two parts. The first is the prior
# for the first observing timestamp of each site, `distphase0`, and the second is the
# prior for segmented gain ϵₜ from time i to i+1, given by `distphase`. For the EHT, we are
# dealing with pre-2*rand(rng, ndim) .- 1.5calibrated data, so often, the gain phase jumps from scan to scan are
Expand All @@ -154,25 +168,11 @@ distamp = station_tuple(dvis, Normal(0.0, 0.1); LM = Normal(1.0))
distphase = station_tuple(dvis, DiagonalVonMises(0.0, inv^2)))


# Now we need to specify our image prior. For this work we will use a Gaussian Markov
# Random field prior
# Since we are using a Gaussian Markov random field prior we need to first specify our `mean`
# image. For this work we will use a symmetric Gaussian with a FWHM of 40 μas
fwhmfac = 2*sqrt(2*log(2))
mpr = modify(Gaussian(), Stretch(μas2rad(80.0)./fwhmfac))
imgpr = intensitymap(mpr, grid)

# Now since we are actually modeling our image on the simplex we need to ensure that
# our mean image has unit flux
imgpr ./= flux(imgpr)

meanpr = to_real(CenteredLR(), Comrade.baseimage(imgpr))

# In addition we want a reasonable guess for what the resolution of our image should be.
# For radio astronomy this is given by roughly the longest baseline in the image. To put this
# into pixel space we then divide by the pixel size.
hh(x) = hypot(x...)
beam = inv(maximum(hh.(uvpositions.(extract_table(obs, ComplexVisibilities()).data))))
beam = beamsize(dvis)
rat = (beam/(step(grid.X)))

# To make the Gaussian Markov random field efficient we first precompute a bunch of quantities
Expand All @@ -188,20 +188,23 @@ crcache = MarkovRandomFieldCache(meanpr)
# - the default prior unit multivariate or standard normal distribution.
# We include this transformation in the metadata since it is needed when forming the actual image or log-ratio of the image.
trf, cprior = standardize(crcache, Normal)
metadata = (;K, trf, meanpr, ftot=1.1, grid, cache, gcache, gcachep)

# Now we can form our metadata we need to fully define our model.
metadata = (;ftot=1.1, trf=trf, K = CenterImage(imgpr), meanpr, grid, cache, gcache, gcachep)



# We can now form our model parameter priors. Like our other imaging examples, we use a
# Dirichlet prior for our image pixels. For the log gain amplitudes, we use the `CalPrior`

# We can now form our model parameter priors. For the log gain amplitudes, we use the `CalPrior`
# which automatically constructs the prior for the given jones cache `gcache`.
prior = (
prior = NamedDist(
fg = Uniform(0.0, 1.0),
c = cprior,
λ = truncated(Normal(0.0, 0.1/rat); lower=2/npix),
σ = truncated(Normal(0.0, 0.1); lower=0.0),
ν = InverseGamma(10.0, 50.0),
σimg = truncated(Normal(0.0, 0.5); lower=0.01),
λ = truncated(Normal(0.0, 0.25*inv(rat)); lower=2/npix),
ν = InverseGamma(5.0, 10.0),
lgamp = CalPrior(distamp, gcache),
gphase = CalPrior(distphase, gcachep),
gphase = CalPrior(distphase, station_tuple(dvis, DiagonalVonMises(0.0, inv(0.1^2))),gcachep)
)


Expand All @@ -218,13 +221,13 @@ post = Posterior(lklhd, prior)
tpost = asflat(post)
ndim = dimension(tpost)

# Our Posterior and TransformedPosterior objects satisfy the `LogDensityProblems` interface.
# Our `Posterior` and `TransformedPosterior` objects satisfy the `LogDensityProblems` interface.
# This allows us to easily switch between different AD backends and many of Julia's statistical
# inference packages use this interface as well.
using LogDensityProblemsAD
using Zygote
gtpost = ADgradient(Val(:Zygote), tpost)
x0 = randn(ndim)
x0 = randn(rng, ndim)
LogDensityProblemsAD.logdensity_and_gradient(gtpost, x0)

# We can now also find the dimension of our posterior or the number of parameters we are going to sample.
Expand All @@ -238,81 +241,85 @@ LogDensityProblemsAD.logdensity_and_gradient(gtpost, x0)
using ComradeOptimization
using OptimizationOptimJL
f = OptimizationFunction(tpost, Optimization.AutoZygote())

sols = map(1:10) do i
prob = Optimization.OptimizationProblem(f, prior_sample(rng, tpost), nothing)
= logdensityof(tpost)
sol = solve(prob, LBFGS(), maxiters=5_000, g_tol=1e-1)
@info sol.minimum
return sol
end

mins = getproperty.(sols, :minimum)
sols = sols[sortperm(mins)]


prob = Optimization.OptimizationProblem(f, rand(rng, ndim) .- 0.5, nothing)
= logdensityof(tpost)
sol = solve(prob, LBFGS(), maxiters=1_000, g_tol=1e-1);

# Now transform back to parameter space
xopts = transform.(Ref(tpost), sols)
xopt = transform(tpost, sol.u)

# !!! warning
# Fitting gains tends to be very difficult, meaning that optimization can take a lot longer.
# The upside is that we usually get nicer images.
#-
# First we will evaluate our fit by plotting the residuals
using Plots
residual(vlbimodel(post, xopts[1]), dvis)
residual(vlbimodel(post, xopt), dvis)

# These look reasonable, although there may be some minor overfitting. This could be
# improved in a few ways, but that is beyond the goal of this quick tutorial.
# Plotting the image, we see that we have a much cleaner version of the closure-only image from
# [Imaging a Black Hole using only Closure Quantities](@ref).
using CairoMakie
img = intensitymap(skymodel(post, xopts[1]), fovx, fovy, 128, 128)
image(img, title="MAP Image", axis=(xreversed=true, aspect=1), colormap=:afmhot)
img = intensitymap(skymodel(post, xopt), fovx, fovy, 128, 128)
plot(img, title="MAP Image")


# Because we also fit the instrument model, we can inspect their parameters.
# To do this, `Comrade` provides a `caltable` function that converts the flattened gain parameters
# to a tabular format based on the time and its segmentation.
gt = Comrade.caltable(gcachep, xopts[1].gphase)
Plots.plot(gt, layout=(3,3), size=(600,500))
gt = Comrade.caltable(gcachep, xopt.gphase)
plot(gt, layout=(3,3), size=(600,500))

# The gain phases are pretty random, although much of this is due to us picking a random
# reference station for each scan.

# Moving onto the gain amplitudes, we see that most of the gain variation is within 10% as expected
# except LMT, which has massive variations.
gt = Comrade.caltable(gcache, exp.(xopts[2].lgamp))
Plots.plot(gt, layout=(3,3), size=(600,500))
gt = Comrade.caltable(gcache, exp.(xopt.lgamp))
plot(gt, layout=(3,3), size=(600,500))


# To sample from the posterior, we will use HMC, specifically the NUTS algorithm. For information about NUTS,
# To sample from the posterior, we will use HMC, specifically the NUTS algorithm. For
# information about NUTS,
# see Michael Betancourt's [notes](https://arxiv.org/abs/1701.02434).
# !!! note
# For our `metric,` we use a diagonal matrix due to easier tuning
#-
# However, due to the need to sample a large number of gain parameters, constructing the posterior
# is rather time-consuming. Therefore, for this tutorial, we will only do a quick preliminary run, and any posterior
# is rather time-consuming. Therefore, for this tutorial, we will only do a quick preliminary
# run, and any posterior
# inferences should be appropriately skeptical.
#-
using ComradeAHMC
metric = DiagEuclideanMetric(ndim)
trace1 = sample(rng, post, AHMC(;metric, autodiff=Val(:Zygote)), 2_000; saveto=DiskStore("ResolveMapT1"), nadapts = 1_000, init_params=xopts[1])
chain, stats = sample(rng, post, AHMC(;metric, autodiff=Val(:Zygote), init_buffer=200, ), 3000; nadapts=2000, init_params=xopt)

#-
# !!! note
# The above sampler will store the samples in memory, i.e. RAM. For large models this
# can lead to out-of-memory issues. To fix that you can include the keyword argument
# `saveto = DiskStore()` which periodically saves the samples to disk limiting memory
# useage. You can load the chain using `load_table(diskout)` where `diskout` is
# the object returned from sample. For more information please see [ComradeAHMC](@ref).
#-

# Now we prune the adaptation phase
chainsub = chain[2001:end]

#-
# !!! warning
# This should be run for likely an order of magnitude more steps to properly estimate expectations of the posterior
#-
chain = load_table(trace1)


# Now that we have our posterior, we can put error bars on all of our plots above.
# Let's start by finding the mean and standard deviation of the gain phases
gphase = hcat(chain.gphase...)
gphase = hcat(chainsub.gphase...)
mgphase = mean(gphase, dims=2)
sgphase = std(gphase, dims=2)

# and now the gain amplitudes
gamp = exp.(hcat(chain.lgamp...))
gamp = exp.(hcat(chainsub.lgamp...))
mgamp = mean(gamp, dims=2)
sgamp = std(gamp, dims=2)

Expand All @@ -332,10 +339,8 @@ plot(ctable_ph, layout=(3,3), size=(600,500))
plot(ctable_am, layout=(3,3), size=(600,500))

# Finally let's construct some representative image reconstructions.
samples = skymodel.(Ref(post), chain[1001:10:end])
imgs = intensitymap.(samples, fovx, fovy, 128, 128);

image(imgs[27], axis=(xreversed=true, aspect=1), colormap=:afmhot)
samples = skymodel.(Ref(post), chainsub[begin:5:end])
imgs = (intensitymap.(samples, fovx, fovy, 128, 128))

mimg = mean(imgs)
simg = std(imgs)
Expand All @@ -346,8 +351,9 @@ p4 = plot(imgs[end], title="Draw 2", clims = (0.0, maximum(mimg)));
plot(p1,p2,p3,p4, layout=(2,2), size=(800,800))

# Now let's check the residuals

p = plot();
for s in sample(chain, 10)
for s in sample(chainsub, 10)
residual!(p, vlbimodel(post, s), dvis)
end
p
Expand Down

0 comments on commit 28f5c25

Please sign in to comment.