diff --git a/NEWS.md b/NEWS.md index ba95817..4229e69 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# 0.7.1 + +time_exp has been removed in favour of assuming that whichever AD library is being used can +successfully AD through the matrix exponential. Guard rails to prevent mis-use of previous +rule have been removed. + # 0.7 Mooncake.jl (and probably Enzyme.jl) is now able to differentiate everything in diff --git a/Project.toml b/Project.toml index cc4fd34..5ceb77e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["Will Tebbutt and contributors"] -version = "0.7.1" +version = "0.7.2" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" @@ -14,12 +14,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -[weakdeps] -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" - -[extensions] -TemporalGPsMooncakeExt = "Mooncake" - [compat] AbstractGPs = "0.5.17" BenchmarkTools = "1" @@ -28,7 +22,7 @@ BlockDiagonals = "0.1.7" FillArrays = "0.13.0 - 0.13.7, 1" JET = "0.9" KernelFunctions = "0.9, 0.10.1" -Mooncake = "0.4.3" +Mooncake = "0.4.41" StaticArrays = "1" StructArrays = "0.5, 0.6" julia = "1.6" diff --git a/ext/TemporalGPsMooncakeExt.jl b/ext/TemporalGPsMooncakeExt.jl deleted file mode 100644 index e87c56e..0000000 --- a/ext/TemporalGPsMooncakeExt.jl +++ /dev/null @@ -1,22 +0,0 @@ -module TemporalGPsMooncakeExt - -using Mooncake, TemporalGPs -import Mooncake: - rrule!!, - CoDual, - primal, - @is_primitive, - zero_fcodual, - MinimalCtx - -@is_primitive MinimalCtx Tuple{typeof(TemporalGPs.time_exp), Matrix{<:Real}, Real} -function rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) - _A = primal(A) - B_dB = zero_fcodual(TemporalGPs.time_exp(_A, primal(t))) - B = primal(B_dB) - dB = tangent(B_dB) - time_exp_pb(::NoRData) = NoRData(), 
NoRData(), sum(dB .* (_A * B)) - return B_dB, time_exp_pb -end - -end diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index fba5d96..b37b3f2 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -130,9 +130,6 @@ function add_proj_mean(hs::AbstractVector, m) return map((h, m) -> h + vcat(m, Zeros(length(h) - 1)), hs, m) end -# Really just a hook for AD. -time_exp(A, t) = exp(A * t) - # Generic constructors for base kernels. function broadcast_components( @@ -140,7 +137,7 @@ function broadcast_components( ) where {T} P = Symmetric(x0.P) t = vcat([first(t) - 1], t) - As = map(Δt -> time_exp(F, T(Δt)), diff(t)) + As = map(Δt -> exp(F * T(Δt)), diff(t)) as = Fill(Zeros{T}(size(first(As), 1)), length(As)) Qs = map(A -> P - A * P * A', As) Hs = Fill(H, length(As)) @@ -152,7 +149,7 @@ function broadcast_components( (F, q, H)::Tuple, x0::Gaussian, t::Union{StepRangeLen, RegularSpacing}, ::StorageType{T} ) where {T} P = Symmetric(x0.P) - A = time_exp(F, T(step(t))) + A = exp(F * T(step(t))) As = Fill(A, length(t)) as = Fill(Zeros{T}(size(F, 1)), length(t)) Q = Symmetric(P) - A * Symmetric(P) * A' @@ -187,8 +184,6 @@ function stationary_distribution(k::SimpleKernel, ::ArrayStorage{T}) where {T<:R return Gaussian(collect(x.m), collect(x.P)) end -safe_to_product(::Kernel) = false - # Matern-1/2 function to_sde(::Matern12Kernel, ::SArrayStorage{T}) where {T<:Real} @@ -205,8 +200,6 @@ function stationary_distribution(::Matern12Kernel, ::SArrayStorage{T}) where {T< ) end -safe_to_product(::Matern12Kernel) = true - # Matern - 3/2 function to_sde(::Matern32Kernel, ::SArrayStorage{T}) where {T<:Real} @@ -224,8 +217,6 @@ function stationary_distribution(::Matern32Kernel, ::SArrayStorage{T}) where {T< ) end -safe_to_product(::Matern32Kernel) = true - # Matern - 5/2 function to_sde(::Matern52Kernel, ::SArrayStorage{T}) where {T<:Real} @@ -243,8 +234,6 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T< return Gaussian(m, P) end 
-safe_to_product(::Matern52Kernel) = true - # Cosine function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T} @@ -260,8 +249,6 @@ function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:R return Gaussian(m, P) end -safe_to_product(::CosineKernel) = true - # ApproxPeriodicKernel # The periodic kernel is approximated by a sum of cosine kernels with different frequencies. @@ -319,8 +306,6 @@ function stationary_distribution(kernel::ApproxPeriodicKernel{N}, storage::Array return Gaussian(m, P) end -safe_to_product(::ApproxPeriodicKernel) = true - # Constant function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} @@ -334,9 +319,6 @@ function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{ return TemporalGPs.Gaussian(SVector{1, T}(0), SMatrix{1, 1, T}(T(only(k.c)))) end -safe_to_product(::ConstantKernel) = true - - # Scaled function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real} @@ -349,8 +331,6 @@ function stationary_distribution(k::ScaledKernel, storage::StorageType) return stationary_distribution(k.kernel, storage) end -safe_to_product(k::ScaledKernel) = safe_to_product(k.kernel) - function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType) As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type) σ = sqrt(convert(eltype(storage_type), only(k.σ²))) @@ -378,8 +358,6 @@ function stationary_distribution( return stationary_distribution(k.kernel, storage) end -safe_to_product(::TransformedKernel{<:Kernel, <:ScaleTransform}) = false - function lgssm_components( k::TransformedKernel{<:Kernel, <:ScaleTransform}, ts::AbstractVector, @@ -396,12 +374,8 @@ apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts. 
# Product -safe_to_product(k::KernelProduct) = all(safe_to_product, k.kernels) - function lgssm_components(k::KernelProduct, ts::AbstractVector, storage::StorageType) - safe_to_product(k) || throw(ArgumentError("Not all kernels in k are safe to product.")) - sde_kernels = to_sde.(k.kernels, Ref(storage)) F_kernels = getindex.(sde_kernels, 1) F = foldl(_kron_add, F_kernels) diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index ad9cdfb..6fdf7f9 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -13,7 +13,6 @@ end println("lti_sde:") @testset "lti_sde" begin - @testset "ApproxPeriodicKernel" begin k = ApproxPeriodicKernel() @test k isa ApproxPeriodicKernel{7} @@ -207,10 +206,4 @@ println("lti_sde:") ) end end - @testset "time_exp AD" begin - test_rule( - Xoshiro(123), t -> TemporalGPs.time_exp([1.0 2.0; 3.0 4.0], t), rand(); - is_primitive=false, - ) - end end diff --git a/test/models/linear_gaussian_conditionals.jl b/test/models/linear_gaussian_conditionals.jl index d6fb055..306bf06 100644 --- a/test/models/linear_gaussian_conditionals.jl +++ b/test/models/linear_gaussian_conditionals.jl @@ -178,7 +178,7 @@ println("linear_gaussian_conditionals:") x_post_large, lml_large = posterior_and_lml(x, model, y_missing) # Check that they give roughly the same answer. - @test x_post_vanilla ≈ x_post_large rtol=1e-8 atol=1e-8 + @test x_post_vanilla ≈ x_post_large rtol=1e-5 atol=1e-5 @test lml_vanilla ≈ lml_large rtol=1e-8 atol=1e-8 # Check that everything infers and AD gives the right answer. 
diff --git a/test/space_time/to_gauss_markov.jl b/test/space_time/to_gauss_markov.jl index 18de0ee..b8c6bfc 100644 --- a/test/space_time/to_gauss_markov.jl +++ b/test/space_time/to_gauss_markov.jl @@ -81,7 +81,10 @@ @test std.(marginals(fx_post_naive)) ≈ std.(marginals(fx_post_sde)) y_post = rand(rng, fx_post_naive) - @test logpdf(fx_post_naive, y_post) ≈ logpdf(fx_post_sde, y_post) + @test isapprox( + logpdf(fx_post_naive, y_post), logpdf(fx_post_sde, y_post); + atol=1e-6, rtol=1e-6, + ) # No statistical tests run on `rand`, which seems somewhat dangerous, but there's # not a lot to be done about it unfortunately.