Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove time_exp rule #134

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 0.7.1

time_exp has been removed in favour of assuming that whichever AD library is being used can
successfully AD through the matrix exponential. Guard rails to prevent mis-use of previous
rule have been remved.

# 0.7

Mooncake.jl (and probably Enzyme.jl) is now able to differentiate everything in
Expand Down
10 changes: 2 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["Will Tebbutt and contributors"]
version = "0.7.1"
version = "0.7.2"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand All @@ -14,12 +14,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[weakdeps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
TemporalGPsMooncakeExt = "Mooncake"

[compat]
AbstractGPs = "0.5.17"
BenchmarkTools = "1"
Expand All @@ -28,7 +22,7 @@ BlockDiagonals = "0.1.7"
FillArrays = "0.13.0 - 0.13.7, 1"
JET = "0.9"
KernelFunctions = "0.9, 0.10.1"
Mooncake = "0.4.3"
Mooncake = "0.4.41"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
julia = "1.6"
Expand Down
22 changes: 0 additions & 22 deletions ext/TemporalGPsMooncakeExt.jl

This file was deleted.

30 changes: 2 additions & 28 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,14 @@ function add_proj_mean(hs::AbstractVector, m)
return map((h, m) -> h + vcat(m, Zeros(length(h) - 1)), hs, m)
end

# Really just a hook for AD.
time_exp(A, t) = exp(A * t)

# Generic constructors for base kernels.

function broadcast_components(
(F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}
) where {T}
P = Symmetric(x0.P)
t = vcat([first(t) - 1], t)
As = map(Δt -> time_exp(F, T(Δt)), diff(t))
As = map(Δt -> exp(F * T(Δt)), diff(t))
as = Fill(Zeros{T}(size(first(As), 1)), length(As))
Qs = map(A -> P - A * P * A', As)
Hs = Fill(H, length(As))
Expand All @@ -152,7 +149,7 @@ function broadcast_components(
(F, q, H)::Tuple, x0::Gaussian, t::Union{StepRangeLen, RegularSpacing}, ::StorageType{T}
) where {T}
P = Symmetric(x0.P)
A = time_exp(F, T(step(t)))
A = exp(F * T(step(t)))
As = Fill(A, length(t))
as = Fill(Zeros{T}(size(F, 1)), length(t))
Q = Symmetric(P) - A * Symmetric(P) * A'
Expand Down Expand Up @@ -187,8 +184,6 @@ function stationary_distribution(k::SimpleKernel, ::ArrayStorage{T}) where {T<:R
return Gaussian(collect(x.m), collect(x.P))
end

safe_to_product(::Kernel) = false

# Matern-1/2

function to_sde(::Matern12Kernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -205,8 +200,6 @@ function stationary_distribution(::Matern12Kernel, ::SArrayStorage{T}) where {T<
)
end

safe_to_product(::Matern12Kernel) = true

# Matern - 3/2

function to_sde(::Matern32Kernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -224,8 +217,6 @@ function stationary_distribution(::Matern32Kernel, ::SArrayStorage{T}) where {T<
)
end

safe_to_product(::Matern32Kernel) = true

# Matern - 5/2

function to_sde(::Matern52Kernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -243,8 +234,6 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T<
return Gaussian(m, P)
end

safe_to_product(::Matern52Kernel) = true

# Cosine

function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T}
Expand All @@ -260,8 +249,6 @@ function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:R
return Gaussian(m, P)
end

safe_to_product(::CosineKernel) = true

# ApproxPeriodicKernel

# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
Expand Down Expand Up @@ -319,8 +306,6 @@ function stationary_distribution(kernel::ApproxPeriodicKernel{N}, storage::Array
return Gaussian(m, P)
end

safe_to_product(::ApproxPeriodicKernel) = true

# Constant

function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -334,9 +319,6 @@ function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{
return TemporalGPs.Gaussian(SVector{1, T}(0), SMatrix{1, 1, T}(T(only(k.c))))
end

safe_to_product(::ConstantKernel) = true


# Scaled

function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real}
Expand All @@ -349,8 +331,6 @@ function stationary_distribution(k::ScaledKernel, storage::StorageType)
return stationary_distribution(k.kernel, storage)
end

safe_to_product(k::ScaledKernel) = safe_to_product(k.kernel)

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(k.σ²)))
Expand Down Expand Up @@ -378,8 +358,6 @@ function stationary_distribution(
return stationary_distribution(k.kernel, storage)
end

safe_to_product(::TransformedKernel{<:Kernel, <:ScaleTransform}) = false

function lgssm_components(
k::TransformedKernel{<:Kernel, <:ScaleTransform},
ts::AbstractVector,
Expand All @@ -396,12 +374,8 @@ apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts.

# Product

safe_to_product(k::KernelProduct) = all(safe_to_product, k.kernels)

function lgssm_components(k::KernelProduct, ts::AbstractVector, storage::StorageType)

safe_to_product(k) || throw(ArgumentError("Not all kernels in k are safe to product."))

sde_kernels = to_sde.(k.kernels, Ref(storage))
F_kernels = getindex.(sde_kernels, 1)
F = foldl(_kron_add, F_kernels)
Expand Down
7 changes: 0 additions & 7 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ end

println("lti_sde:")
@testset "lti_sde" begin

@testset "ApproxPeriodicKernel" begin
k = ApproxPeriodicKernel()
@test k isa ApproxPeriodicKernel{7}
Expand Down Expand Up @@ -207,10 +206,4 @@ println("lti_sde:")
)
end
end
@testset "time_exp AD" begin
test_rule(
Xoshiro(123), t -> TemporalGPs.time_exp([1.0 2.0; 3.0 4.0], t), rand();
is_primitive=false,
)
end
end
2 changes: 1 addition & 1 deletion test/models/linear_gaussian_conditionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ println("linear_gaussian_conditionals:")
x_post_large, lml_large = posterior_and_lml(x, model, y_missing)

# Check that they give roughly the same answer.
@test x_post_vanilla ≈ x_post_large rtol=1e-8 atol=1e-8
@test x_post_vanilla ≈ x_post_large rtol=1e-5 atol=1e-5
@test lml_vanilla ≈ lml_large rtol=1e-8 atol=1e-8

# Check that everything infers and AD gives the right answer.
Expand Down
5 changes: 4 additions & 1 deletion test/space_time/to_gauss_markov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@
@test std.(marginals(fx_post_naive)) ≈ std.(marginals(fx_post_sde))

y_post = rand(rng, fx_post_naive)
@test logpdf(fx_post_naive, y_post) ≈ logpdf(fx_post_sde, y_post)
@test isapprox(
logpdf(fx_post_naive, y_post), logpdf(fx_post_sde, y_post);
atol=1e-6, rtol=1e-6,
)

# No statistical tests run on `rand`, which seems somewhat dangerous, but there's
# not a lot to be done about it unfortunately.
Expand Down
Loading