diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 1373a4bbd..a350d2908 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -40,11 +40,17 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo -getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f)) +function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) + return getvarinfo(LogDensityProblemsAD.parent(f)) +end setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo -function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo) - return Accessors.@set f.ℓ = setvarinfo(f.ℓ, varinfo) +function setvarinfo( + f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType +) + return LogDensityProblemsAD.ADgradient( + adtype, setvarinfo(LogDensityProblemsAD.parent(f), varinfo) + ) end """ @@ -120,7 +126,7 @@ function AbstractMCMC.step( varinfo = DynamicPPL.link(varinfo, model) end end - f = setvarinfo(f, varinfo) + f = setvarinfo(f, varinfo, alg.adtype) # Then just call `AdvancedHMC.step` with the right arguments. if initial_state === nothing diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 64a5e95df..15ec6149c 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -2,6 +2,7 @@ module InferenceTests using ..Models: gdemo_d, gdemo_default using ..NumericalTests: check_gdemo, check_numerical +import ..ADUtils using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample import DynamicPPL @@ -14,7 +15,9 @@ import ReverseDiff using Test: @test, @test_throws, @testset using Turing -@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) +ADUtils.install_tapir && import Tapir + +@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 10d202da3..43e3966a9 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -1,5 +1,6 @@ module AbstractMCMCTests +import ..ADUtils using AdvancedMH: AdvancedMH using Distributions: sample using Distributions.FillArrays: Zeros @@ -15,6 +16,8 @@ using Test: @test, @test_throws, @testset using Turing using Turing.Inference: AdvancedHMC +ADUtils.install_tapir && import Tapir + function initialize_nuts(model::Turing.Model) # Create a log-density function with an implementation of the # gradient so we ensure that we're using the same AD backend as in Turing. @@ -22,7 +25,9 @@ function initialize_nuts(model::Turing.Model) # Link the varinfo. f = Turing.Inference.setvarinfo( - f, DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model) + f, + DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model), + Turing.Inference.getADType(DynamicPPL.getcontext(LogDensityProblemsAD.parent(f))), ) # Choose parameter dimensionality and initial parameter value @@ -112,7 +117,9 @@ end @testset "External samplers" begin @testset "AdvancedHMC.jl" begin - # Try a few different AD backends. + # TODO(mhauru) The below tests fail with Tapir, see + # https://github.com/TuringLang/Turing.jl/pull/2289. + # Once that is fixed, this should say `for adtype in ADUtils.adbackends`. @testset "adtype=$adtype" for adtype in [AutoForwardDiff(), AutoReverseDiff()] @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Need some functionality to initialize the sampler. diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f30dc0f77..6868cb5e8 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -2,6 +2,7 @@ module GibbsTests using ..Models: MoGtest_default, gdemo, gdemo_default using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical +import ..ADUtils using Distributions: InverseGamma, Normal using Distributions: sample using ForwardDiff: ForwardDiff @@ -12,9 +13,9 @@ using Turing using Turing: Inference using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess -@testset "Testing gibbs.jl with $adbackend" for adbackend in ( - AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false) -) +ADUtils.install_tapir && import Tapir + +@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "gibbs constructor" begin N = 500 s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index 6110bbdb1..3f02c7594 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -2,6 +2,7 @@ module GibbsConditionalTests using ..Models: gdemo, gdemo_default using ..NumericalTests: check_gdemo, check_numerical +import ..ADUtils using Clustering: Clustering using Distributions: Categorical, InverseGamma, Normal, sample using ForwardDiff: ForwardDiff @@ -14,9 +15,9 @@ using StatsFuns: StatsFuns using Test: @test, @testset using Turing -@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ( - AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false) -) +ADUtils.install_tapir && import Tapir + +@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends Random.seed!(1000) rng = StableRNG(123) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index b589d4687..dde977a6f 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -4,6 +4,7 @@ using ..Models: gdemo_default using ..ADUtils: ADTypeCheckContext #using ..Models: gdemo using ..NumericalTests: check_gdemo, check_numerical +import ..ADUtils using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample import DynamicPPL using DynamicPPL: Sampler @@ -17,7 +18,9 @@ using StatsFuns: logistic using Test: @test, @test_logs, @testset using Turing -@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) +ADUtils.install_tapir && import Tapir + +@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index a5829eb18..95b3bc543 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -2,6 +2,7 @@ module SGHMCTests using ..Models: gdemo_default using ..NumericalTests: check_gdemo +import ..ADUtils using Distributions: sample import ForwardDiff using LinearAlgebra: dot @@ -10,7 +11,9 @@ using StableRNGs: StableRNG using Test: @test, @testset using Turing -@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) +ADUtils.install_tapir && import Tapir + +@testset "Testing sghmc.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) @test alg isa SGHMC @@ -36,7 +39,7 @@ using Turing end end -@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) +@testset "Testing sgld.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) @test alg isa SGLD diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 27b469ed7..e900a8f69 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -1,6 +1,7 @@ module ADUtils using ForwardDiff: ForwardDiff +using Pkg: Pkg using Random: Random using ReverseDiff: ReverseDiff using Test: Test @@ -9,7 +10,10 @@ using Turing: Turing using Turing: DynamicPPL using Zygote: Zygote -export ADTypeCheckContext +export ADTypeCheckContext, adbackends + +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# Stuff for checking that the right AD backend is being used. """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) @@ -270,4 +274,24 @@ Test.@testset "ADTypeCheckContext" begin end end +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# List of AD backends to test. + +""" +All the ADTypes on which we want to run the tests. +""" +adbackends = [ + Turing.AutoForwardDiff(; chunksize=0), Turing.AutoReverseDiff(; compile=false) +] + +# Tapir isn't supported for older Julia versions, hence the check. +install_tapir = isdefined(Turing, :AutoTapir) +if install_tapir + # TODO(mhauru) Is there a better way to install optional dependencies like this? + Pkg.add("Tapir") + using Tapir + push!(adbackends, Turing.AutoTapir(false)) + push!(eltypes_by_adtype, Turing.AutoTapir => (Tapir.CoDual,)) +end + end