From 0adf963616fd926a5d839065f515d5c91e564bc9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 30 Oct 2024 19:41:27 +0000 Subject: [PATCH] More SpecialFunctions Tests (#328) * zero_adjoint for irrational conversions to floats * Add more rules from SpecialFunctions * Bump patch version * Fix typo in docstring * Add more rules and tests for SpecialFunctions --- Project.toml | 2 +- ext/MooncakeSpecialFunctionsExt.jl | 44 ++++++++++++-- src/interpreter/s2s_reverse_mode_ad.jl | 2 +- .../avoiding_non_differentiable_code.jl | 9 +++ .../special_functions/special_functions.jl | 58 ++++++++++++++++--- 5 files changed, 100 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index c495f3020..ed8a12000 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.26" +version = "0.4.27" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/MooncakeSpecialFunctionsExt.jl b/ext/MooncakeSpecialFunctionsExt.jl index eca33b5e0..dc6fd1b0f 100644 --- a/ext/MooncakeSpecialFunctionsExt.jl +++ b/ext/MooncakeSpecialFunctionsExt.jl @@ -1,12 +1,46 @@ module MooncakeSpecialFunctionsExt using SpecialFunctions, Mooncake +using Base: IEEEFloat -import Mooncake: @from_rrule, DefaultCtx +import Mooncake: @from_rrule, DefaultCtx, @zero_adjoint -@from_rrule DefaultCtx Tuple{typeof(airyai), Float64} -@from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} -@from_rrule DefaultCtx Tuple{typeof(erfc), Float64} -@from_rrule DefaultCtx Tuple{typeof(erfcx), Float64} +@from_rrule DefaultCtx Tuple{typeof(airyai), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airyaix), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airyaiprime), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airybi), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airybiprime), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(besselj0), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(besselj1), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(bessely0), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(bessely1), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(dawson), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(digamma), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erf), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erf), IEEEFloat, IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfc), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logerfc), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfcinv), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfcx), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logerfcx), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfi), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfinv), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(gamma), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(invdigamma), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(trigamma), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(polygamma), Integer, IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(beta), IEEEFloat, IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logbeta), IEEEFloat, IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logabsgamma), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(loggamma), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(expint), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(expintx), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(expinti), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(sinint), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(cosint), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(ellipk), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(ellipe), IEEEFloat} + +@zero_adjoint DefaultCtx Tuple{typeof(logfactorial), Integer} end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index ff815cf41..56bce43cf 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -297,7 +297,7 @@ function comms_channel(info::ADStmtInfo) end #= - make_ad_stmts(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo + make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index d3a1800ae..df4fe7b8f 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -9,6 +9,9 @@ end @zero_adjoint MinimalCtx Tuple{typeof(randn), AbstractRNG, Vararg} @zero_adjoint MinimalCtx Tuple{typeof(string), Vararg} @zero_adjoint MinimalCtx Tuple{Type{Symbol}, Vararg} +@zero_adjoint MinimalCtx Tuple{Type{Float64}, Any, RoundingMode} +@zero_adjoint MinimalCtx Tuple{Type{Float32}, Any, RoundingMode} +@zero_adjoint MinimalCtx Tuple{Type{Float16}, Any, RoundingMode} function generate_hand_written_rrule!!_test_cases( rng_ctor, ::Val{:avoiding_non_differentiable_code} @@ -47,6 +50,12 @@ function generate_hand_written_rrule!!_test_cases( # Rules to make Symbol-related functionality work properly. (false, :stability_and_allocs, nothing, Symbol, "hello"), (false, :stability_and_allocs, nothing, Symbol, UInt8[1, 2]), + (false, :stability_and_allocs, nothing, Float64, π, RoundDown), + (false, :stability_and_allocs, nothing, Float64, π, RoundUp), + (true, :stability_and_allocs, nothing, Float32, π, RoundDown), + (true, :stability_and_allocs, nothing, Float32, π, RoundUp), + (true, :stability_and_allocs, nothing, Float16, π, RoundDown), + (true, :stability_and_allocs, nothing, Float16, π, RoundUp), ) memory = Any[_x, _dx] return test_cases, memory diff --git a/test/ext/special_functions/special_functions.jl b/test/ext/special_functions/special_functions.jl index 3621672f5..2d4e4a582 100644 --- a/test/ext/special_functions/special_functions.jl +++ b/test/ext/special_functions/special_functions.jl @@ -4,21 +4,63 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using Mooncake, SpecialFunctions, Test +# Rules in this file are only lightly tester, because they are all just @from_rrule rules. @testset "special_functions" begin @testset for (perf_flag, f, x...) in [ (:stability, airyai, 0.1), - (:stability, airyai, 0.0), - (:stability, airyai, -0.5), (:stability, airyaix, 0.1), - (:stability, airyaix, 0.05), - (:stability, airyaix, 0.9), + (:stability, airyaiprime, 0.1), + (:stability, airybi, 0.1), + (:stability, airybiprime, 0.1), + (:stability_and_allocs, besselj0, 0.1), + (:stability_and_allocs, besselj1, 0.1), + (:stability_and_allocs, bessely0, 0.1), + (:stability_and_allocs, bessely1, 0.1), + (:stability_and_allocs, dawson, 0.1), + (:stability_and_allocs, digamma, 0.1), + (:stability_and_allocs, erf, 0.1), + (:stability_and_allocs, erf, 0.1, 0.5), (:stability_and_allocs, erfc, 0.1), - (:stability_and_allocs, erfc, 0.0), - (:stability_and_allocs, erfc, -0.5), + (:stability_and_allocs, logerfc, 0.1), + (:stability_and_allocs, erfcinv, 0.1), (:stability_and_allocs, erfcx, 0.1), - (:stability_and_allocs, erfcx, 0.0), - (:stability_and_allocs, erfcx, -0.5), + (:stability_and_allocs, logerfcx, 0.1), + (:stability_and_allocs, erfi, 0.1), + (:stability_and_allocs, erfinv, 0.1), + (:stability_and_allocs, gamma, 0.1), + (:stability_and_allocs, invdigamma, 0.1), + (:stability_and_allocs, trigamma, 0.1), + (:stability_and_allocs, polygamma, 3, 0.1), + (:stability_and_allocs, beta, 0.3, 0.1), + (:stability_and_allocs, logbeta, 0.3, 0.1), + # (:stability_and_allocs, logabsgamma, 0.3), + (:stability_and_allocs, loggamma, 0.3), + (:stability_and_allocs, expint, 0.3), + (:stability_and_allocs, expintx, 0.3), + (:stability_and_allocs, expinti, 0.3), + (:stability_and_allocs, sinint, 0.3), + (:stability_and_allocs, cosint, 0.3), + (:stability_and_allocs, ellipk, 0.3), + (:stability_and_allocs, ellipe, 0.3), + (:stability_and_allocs, logfactorial, 3), ] test_rule(Xoshiro(123456), f, x...; perf_flag) end + @testset for (perf_flag, f, x...) in [ + (:allocs, logerf, 0.3, 0.5), # first branch + (:allocs, logerf, 1.1, 1.2), # second branch + (:allocs, logerf, -1.2, -1.1), # third branch + (:allocs, logerf, 0.3, 1.1), # fourth branch + (:allocs, SpecialFunctions.loggammadiv, 1.0, 9.0), + (:allocs, SpecialFunctions.gammax, 1.0), + (:allocs, SpecialFunctions.rgammax, 3.0, 6.0), + (:allocs, SpecialFunctions.rgamma1pm1, 0.1), + # (:allocs, SpecialFunctions.auxgam, 0.1), # allocations + # (:allocs, logabsbeta, 0.3, 0.1), # logabsgamma needs to work for this to work + # (:allocs, SpecialFunctions.loggamma1p, 0.3), # allocations + # (:allocs, SpecialFunctions.loggamma1p, -0.3), # allocations + # (:allocs, SpecialFunctions.lambdaeta, 5.0), # a genuine bug! + ] + test_rule(Xoshiro(123456), f, x...; perf_flag, is_primitive=false) + end end