diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1963f83e8..becaaa8f9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - 'integration_testing/array' - 'integration_testing/turing' - 'integration_testing/temporalgps' + - 'interface' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index fdff8a624..97adf6111 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.2" +version = "0.2.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interface.jl b/src/interface.jl index 3e35018fb..30f7bbb0c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -8,10 +8,10 @@ ensure that you zero-out the tangent fields of `x` each time. """ function value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, N, T} out, pb!! = rule(fx...) - @assert _typeof(tangent(out)) == T - ty = increment!!(tangent(out), ȳ) + @assert _typeof(tangent(out)) == fdata_type(T) + increment!!(tangent(out), fdata(ȳ)) v = copy(primal(out)) - return v, pb!!(ty, map(tangent, fx)...) + return v, pb!!(rdata(ȳ)) end """ @@ -48,7 +48,7 @@ use-case, consider pre-allocating the `CoDual`s and calling the other method of function. """ function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any, N}) where {R, N} - return value_and_pullback!!(rule, ȳ, map(zero_codual, fx)...) + return value_and_pullback!!(rule, ȳ, map(zero_fcodual, fx)...) end """ @@ -57,5 +57,5 @@ end Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`. """ function value_and_gradient!!(rule::R, fx::Vararg{Any, N}) where {R, N} - return value_and_gradient!!(rule, map(zero_codual, fx)...) + return value_and_gradient!!(rule, map(zero_fcodual, fx)...) end diff --git a/test/runtests.jl b/test/runtests.jl index 42541ad66..359a699ec 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,6 +58,8 @@ include("front_matter.jl") include(joinpath("integration_testing", "turing.jl")) elseif test_group == "integration_testing/temporalgps" include(joinpath("integration_testing", "temporalgps.jl")) + elseif test_group == "interface" + include("interface.jl") else throw(error("test_group=$(test_group) is not recognised")) end