Skip to content

Commit

Permalink
add some additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Aug 19, 2024
1 parent 7ca9376 commit e50cfe9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/instrument/priors/array_priors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct ArrayPrior{D, A, R, C<:Union{NTuple{2, Symbol}, Nothing}}
struct ArrayPrior{D, A, R, C}
default_dist::D
override_dist::A
refant::R
Expand Down Expand Up @@ -30,9 +30,13 @@ means that every site has a normal prior with mean 0 and 0.1 std. dev. except LM
zero and unit std. dev. Finally the refant is using the [`SEFDReference`](@ref) scheme.
"""
function ArrayPrior(dist; refant=NoReference(), phase=false, centroid_station=nothing, kwargs...)
if centroid_station isa Tuple{<:Symbol, <:Symbol}
centroid_station = NamedTuple{centroid_station}((0.0, 0.0))
end
return ArrayPrior(dist, kwargs, refant, phase, centroid_station)
end


function site_priors(d::ArrayPrior, array)
return site_tuple(array, d.default_dist; d.override_dist...)
end
Expand Down Expand Up @@ -183,13 +187,14 @@ function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroi
# fs = smap.frequencies
fixedinds, vals = reference_indices(array, smap, refants)


if !(centroid_station isa Nothing)
centroid1 = findfirst(==(centroid_station[1]), ss)
centroid2 = findfirst(==(centroid_station[2]), ss)
centstat = keys(centroid_station)
vals = values(centroid_station)
centroid1 = findfirst(==(centstat[1]), ss)
centroid2 = findfirst(==(centstat[2]), ss)
centroid === nothing && throw(ArgumentError("Centroid station not found in site list"))
append!(fixedinds, [centroid1, centroid2])
vals = append!(collect(vals), fill(0.0, 2))
vals = append!(collect(vals), [vals[1], vals[2]])
end

variateinds = setdiff(eachindex(ts), fixedinds)
Expand Down
5 changes: 4 additions & 1 deletion src/mrf_image.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ function _apply_fluctuations(f, mimg::AbstractArray, δ::AbstractArray)
return mimg.*f.(δ)
end

_checknorm(m::AbstractArray) = isapprox(sum(m), 1, atol=1e-6)
Enzyme.EnzymeRules.inactive(::typeof(_checknorm), args...) = nothing

function _apply_fluctuations(t::VLBIImagePriors.LogRatioTransform, mimg::AbstractArray, δ::AbstractArray)
@argcheck isapprox(sum(parent(mimg)), 1, atol=1e-6) "Mean image must have unit flux when using log-ratio transformations in apply_fluctuations"
@argcheck _checknorm(mimg) "Mean image must have unit flux when using log-ratio transformations in apply_fluctuations"
r = to_simplex(t, δ)
r .= r.*parent(mimg)
r .= r./sum(r)
Expand Down
5 changes: 5 additions & 0 deletions src/posterior/vlbiposterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ function combine_prior(skymodel, ::Tuple{})
return NamedDist((sky=skymodel,))
end

function combine_prior(skymodel::NamedDist{()}, intmodel::Tuple{})
return NamedDist()
end


function combine_prior(skymodel, ::NamedDist{()})
return NamedDist((sky=skymodel,))
end
Expand Down

0 comments on commit e50cfe9

Please sign in to comment.