From 2f29f95eeda5782746657441e453e194c5b3f00c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 24 Oct 2024 01:30:36 +0100 Subject: [PATCH] Tapir -> Mooncake (???) --- .github/workflows/AD.yml | 4 +--- Project.toml | 11 +++++------ ...jectorsTapirExt.jl => BijectorsMooncakeExt.jl} | 15 ++++++++------- test/ad/chainrules.jl | 14 +++++++------- test/ad/utils.jl | 10 +++++----- test/runtests.jl | 8 ++++---- 6 files changed, 30 insertions(+), 32 deletions(-) rename ext/{BijectorsTapirExt.jl => BijectorsMooncakeExt.jl} (77%) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 47ef8549..b0ea13a5 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -23,13 +23,11 @@ jobs: AD: - Enzyme - ForwardDiff - - Tapir + - Mooncake - Tracker - ReverseDiff - Zygote exclude: - - version: 1.6 - AD: Tapir # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - version: 1.6 diff --git a/Project.toml b/Project.toml index e7498ac7..c976ab59 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.19" +version = "0.14.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -25,12 +25,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -39,8 +38,8 @@ BijectorsEnzymeExt = "Enzyme" BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" +BijectorsMooncakeExt = "Mooncake" BijectorsTrackerExt = "Tracker" -BijectorsTapirExt = "Tapir" BijectorsZygoteExt = "Zygote" [compat] @@ -65,7 +64,7 @@ Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.4, 2" Statistics = "1" -Tapir = "0.2.23" +Mooncake = "0.4.19" Tracker = "0.2" Zygote = "0.6.63" julia = "1.6" @@ -76,6 +75,6 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/BijectorsTapirExt.jl b/ext/BijectorsMooncakeExt.jl similarity index 77% rename from ext/BijectorsTapirExt.jl rename to ext/BijectorsMooncakeExt.jl index 70805a82..d7285bf6 100644 --- a/ext/BijectorsTapirExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,10 +1,11 @@ -module BijectorsTapirExt +module BijectorsMooncakeExt if isdefined(Base, :get_extension) - using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule + using Mooncake: + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore else - using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule + using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule using ..Bijectors: find_alpha, ChainRulesCore end @@ -19,20 +20,20 @@ end # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) -function Tapir.rrule!!( +function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} # Require that the integer is non-differentiable. - if tangent_type(I) != Tapir.NoTangent + if tangent_type(I) != Mooncake.NoTangent msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." throw(ArgumentError(msg)) end out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) function find_alpha_pb(dout::P) _, dx, dy, _ = pb(dout) - return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() + return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData() end - return Tapir.zero_fcodual(out), find_alpha_pb + return Mooncake.zero_fcodual(out), find_alpha_pb end end diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index bcdb9523..a2c13df1 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -27,9 +27,9 @@ end test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) - if @isdefined Tapir + if @isdefined Mooncake rng = Xoshiro(123456) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -37,9 +37,9 @@ end z; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -47,9 +47,9 @@ end 3; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -57,7 +57,7 @@ end UInt32(3); is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 3e21e693..c6c21144 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) b in ( :ForwardDiff, :Zygote, - :Tapir, + :Mooncake, :ReverseDiff, :Enzyme, :EnzymeForward, @@ -78,12 +78,12 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" - rule = Tapir.build_rrule(f, x; safety_on=false) + if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10" + rule = Mooncake.build_rrule(f, x; safety_on=false) if :tapir in broken @test_broken( isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], + Mooncake.value_and_gradient!!(rule, f, x)[2][2], finitediff; rtol=rtol, atol=atol, @@ -92,7 +92,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) else @test( isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], + Mooncake.value_and_gradient!!(rule, f, x)[2][2], finitediff; rtol=rtol, atol=atol, diff --git a/test/runtests.jl b/test/runtests.jl index 914c0e32..7ee539ed 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,12 +34,12 @@ if VERSION < v"1.9" using Compat: stack end -# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing -# on at least version 1.10. +# Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're +# testing on at least version 1.10. if VERSION >= v"1.10" using Pkg - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + using Mooncake end const GROUP = get(ENV, "GROUP", "All")