diff --git a/Project.toml b/Project.toml
index 97adf6111..691caead5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Tapir"
 uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
 authors = ["Will Tebbutt, Hong Ge, and contributors"]
-version = "0.2.3"
+version = "0.2.4"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/interface.jl b/src/interface.jl
index 30f7bbb0c..b40637c5c 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -1,32 +1,33 @@
 """
-    value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)
+    __value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)
 
 In-place version of `value_and_pullback!!` in which the arguments have been wrapped in
 `CoDual`s. Note that any mutable data in `f` and `x` will be incremented in-place. As such,
 if calling this function multiple times with different values of `x`, you should be careful
 to ensure that you zero-out the tangent fields of `x` each time.
 """
-function value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, N, T}
-    out, pb!! = rule(fx...)
+function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, N, T}
+    out, pb!! = rule(tuple_map(to_fwds, fx)...)
     @assert _typeof(tangent(out)) == fdata_type(T)
     increment!!(tangent(out), fdata(ȳ))
     v = copy(primal(out))
-    return v, pb!!(rdata(ȳ))
+    return v, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(rdata(ȳ)))
 end
 
 """
-    value_and_gradient!!(rule, f::CoDual, x::CoDual...)
+    __value_and_gradient!!(rule, f::CoDual, x::CoDual...)
 
 Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`.
 """
-function value_and_gradient!!(rule::R, fx::Vararg{CoDual, N}) where {R, N}
-    return value_and_pullback!!(rule, 1.0, fx...)
+function __value_and_gradient!!(rule::R, fx::Vararg{CoDual, N}) where {R, N}
+    return __value_and_pullback!!(rule, 1.0, fx...)
 end
 
 """
     value_and_pullback!!(rule, ȳ, f, x...)
 
-Compute the value and pullback of `f(x...)`.
+Compute the value and pullback of `f(x...)`. `ȳ` must be a valid tangent for the primal
+returned by `f(x...)`.
 
 `rule` should be constructed using `build_rrule`.
@@ -45,10 +46,10 @@ will yield the wrong result.
 *Note:* This method of `value_and_pullback!!` has to first call `zero_codual` on all of its
 arguments. This may cause some additional allocations. If this is a problem in your
 use-case, consider pre-allocating the `CoDual`s and calling the other method of this
-function.
+function. The `CoDual`s should be primal-tangent pairs (as opposed to primal-fdata pairs).
 """
 function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any, N}) where {R, N}
-    return value_and_pullback!!(rule, ȳ, map(zero_fcodual, fx)...)
+    return __value_and_pullback!!(rule, ȳ, tuple_map(zero_codual, fx)...)
 end
 
 """
@@ -57,5 +58,5 @@ end
 Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`.
 """
 function value_and_gradient!!(rule::R, fx::Vararg{Any, N}) where {R, N}
-    return value_and_gradient!!(rule, map(zero_fcodual, fx)...)
+    return __value_and_gradient!!(rule, tuple_map(zero_codual, fx)...)
 end
diff --git a/test/interface.jl b/test/interface.jl
index 94c35d659..745194507 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -2,9 +2,14 @@
     @testset "$(typeof((f, x...)))" for (ȳ, f, x...) in Any[
         (1.0, (x, y) -> x * y + sin(x) * cos(y), 5.0, 4.0),
         ([1.0, 1.0], x -> [sin(x), sin(2x)], 3.0),
+        (1.0, x -> sum(5x), [5.0, 2.0]),
     ]
         rule = build_rrule(f, x...)
-        v, grad2 = value_and_pullback!!(rule, ȳ, f, x...)
+        v, (df, dx...) = value_and_pullback!!(rule, ȳ, f, x...)
         @test v ≈ f(x...)
+        @test df isa tangent_type(typeof(f))
+        for (_dx, _x) in zip(dx, x)
+            @test _dx isa tangent_type(typeof(_x))
+        end
     end
 end
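Usage sketch (not part of the diff): the snippet below exercises the entry points as they read after this change. It assumes the `Tapir.`-qualified names resolve as defined above; `f` and its inputs are made up for illustration.

```julia
import Tapir

f(x, y) = x * y + sin(x) * cos(y)

# Build the rule once; reuse it for repeated calls with arguments of the same type.
rule = Tapir.build_rrule(f, 5.0, 4.0)

# `1.0` seeds the (Float64) output tangent. The result is the primal value plus
# one tangent per input, starting with a tangent for `f` itself.
v, (df, dx, dy) = Tapir.value_and_pullback!!(rule, 1.0, f, 5.0, 4.0)

# For `Float64`-valued functions, this is shorthand for seeding with 1.0.
v2, grads = Tapir.value_and_gradient!!(rule, f, 5.0, 4.0)

# To avoid the per-call `zero_codual` allocations noted in the docstring,
# pre-allocate primal-tangent `CoDual`s and call the internal method directly.
fx = map(Tapir.zero_codual, (f, 5.0, 4.0))
v3, grads3 = Tapir.__value_and_pullback!!(rule, 1.0, fx...)
```

The pre-allocated variant mirrors the docstring's note: `zero_codual` produces primal-tangent pairs, which is what `__value_and_pullback!!` now expects, converting them to forwards form itself via `to_fwds`.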