From e50cfe9e32b12562167d664eeea6419ca033cb73 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 18 Aug 2024 23:16:57 -0400 Subject: [PATCH] add some additional tests --- src/instrument/priors/array_priors.jl | 15 ++++++++++----- src/mrf_image.jl | 5 ++++- src/posterior/vlbiposterior.jl | 5 +++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index b7dab6c5..2625ae8e 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -1,4 +1,4 @@ -struct ArrayPrior{D, A, R, C<:Union{NTuple{2, Symbol}, Nothing}} +struct ArrayPrior{D, A, R, C} default_dist::D override_dist::A refant::R @@ -30,9 +30,13 @@ means that every site has a normal prior with mean 0 and 0.1 std. dev. except LM zero and unit std. dev. Finally the refant is using the [`SEFDReference`](@ref) scheme. """ function ArrayPrior(dist; refant=NoReference(), phase=false, centroid_station=nothing, kwargs...) + if centroid_station isa Tuple{<:Symbol, <:Symbol} + centroid_station = NamedTuple{centroid_station}((0.0, 0.0)) + end return ArrayPrior(dist, kwargs, refant, phase, centroid_station) end + function site_priors(d::ArrayPrior, array) return site_tuple(array, d.default_dist; d.override_dist...) end @@ -183,13 +187,14 @@ function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroi # fs = smap.frequencies fixedinds, vals = reference_indices(array, smap, refants) - if !(centroid_station isa Nothing) - centroid1 = findfirst(==(centroid_station[1]), ss) - centroid2 = findfirst(==(centroid_station[2]), ss) + centstat = keys(centroid_station) + vals = values(centroid_station) + centroid1 = findfirst(==(centstat[1]), ss) + centroid2 = findfirst(==(centstat[2]), ss) centroid === nothing && throw(ArgumentError("Centroid station not found in site list")) append!(fixedinds, [centroid1, centroid2]) - vals = append!(collect(vals), fill(0.0, 2)) + vals = append!(collect(vals), [vals[1], vals[2]]) end variateinds = setdiff(eachindex(ts), fixedinds) diff --git a/src/mrf_image.jl b/src/mrf_image.jl index 0d3af12c..44cf94bf 100644 --- a/src/mrf_image.jl +++ b/src/mrf_image.jl @@ -31,8 +31,11 @@ function _apply_fluctuations(f, mimg::AbstractArray, δ::AbstractArray) return mimg.*f.(δ) end +_checknorm(m::AbstractArray) = isapprox(sum(m), 1, atol=1e-6) +Enzyme.EnzymeRules.inactive(::typeof(_checknorm), args...) = nothing + function _apply_fluctuations(t::VLBIImagePriors.LogRatioTransform, mimg::AbstractArray, δ::AbstractArray) - @argcheck isapprox(sum(parent(mimg)), 1, atol=1e-6) "Mean image must have unit flux when using log-ratio transformations in apply_fluctuations" + @argcheck _checknorm(mimg) "Mean image must have unit flux when using log-ratio transformations in apply_fluctuations" r = to_simplex(t, δ) r .= r.*parent(mimg) r .= r./sum(r) diff --git a/src/posterior/vlbiposterior.jl b/src/posterior/vlbiposterior.jl index 5b6512f0..93acbca1 100644 --- a/src/posterior/vlbiposterior.jl +++ b/src/posterior/vlbiposterior.jl @@ -87,6 +87,11 @@ function combine_prior(skymodel, ::Tuple{}) return NamedDist((sky=skymodel,)) end +function combine_prior(skymodel::NamedDist{()}, intmodel::Tuple{}) + return NamedDist() +end + + function combine_prior(skymodel, ::NamedDist{()}) return NamedDist((sky=skymodel,)) end