-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test against Enzyme #318
Test against Enzyme #318
Changes from all commits
175f382
89a3267
961aeb6
f33da48
28beac5
1239040
94023a1
7860701
3cf5e79
2163ec4
8654475
e43d71e
deec0a6
cd3000a
5ac4583
44bd739
976874b
f5fd835
41d643c
5124bd7
73711d6
04d98e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
module BijectorsEnzymeExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using Enzyme: @import_frule, @import_rrule | ||
using Bijectors: find_alpha | ||
else | ||
using ..Enzyme: @import_frule, @import_rrule | ||
using ..Bijectors: find_alpha | ||
end | ||
|
||
@import_rrule typeof(find_alpha) Real Real Real | ||
@import_frule typeof(find_alpha) Real Real Real | ||
|
||
end |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -2,7 +2,24 @@ | |||||||
const AD = get(ENV, "AD", "All") | ||||||||
|
||||||||
function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) | ||||||||
for b in broken | ||||||||
if !( | ||||||||
b in ( | ||||||||
:ForwardDiff, | ||||||||
:Zygote, | ||||||||
:Tapir, | ||||||||
:ReverseDiff, | ||||||||
:Enzyme, | ||||||||
:EnzymeForward, | ||||||||
:EnzymeReverse, | ||||||||
) | ||||||||
) | ||||||||
error("Unknown broken AD backend: $b") | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] | ||||||||
et = eltype(finitediff) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? Shouldn't Enzyme return the correct types automatically? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In forward mode it returns tuples, and if the gradient is empty, the result is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But the gradient should never be empty? Such a test would be quite useless, so I assume we don't run into this special case here? So maybe a simple There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a corner case, but we ran into it here, when Bijectors.jl/test/bijectors/corr.jl Line 5 in 8654475
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd suggest not testing AD in the case Bijectors.jl/test/bijectors/corr.jl Lines 32 to 33 in 8654475
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It passes now though, and I don't really see a downside to testing it? Good to know for instance that nothing crashes even if you hit this corner case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, I also would prefer to not remove the test completely (even though I think it's of very limited use)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's what the current test is effectively doing, because when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant something differently - removing the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that's how I understood you, but that would require something like adding another argument to |
||||||||
|
||||||||
if AD == "All" || AD == "ForwardDiff" | ||||||||
if :ForwardDiff in broken | ||||||||
|
@@ -30,6 +47,37 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) | |||||||
end | ||||||||
end | ||||||||
|
||||||||
# TODO(mhauru) The version bound should be relaxed once some Enzyme issues get | ||||||||
# sorted out. I think forward mode will remain broken for versions <= 1.6 due to | ||||||||
# some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and | ||||||||
# discussion in https://github.com/TuringLang/Bijectors.jl/pull/318. | ||||||||
if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10" | ||||||||
forward_broken = :EnzymeForward in broken || :Enzyme in broken | ||||||||
reverse_broken = :EnzymeReverse in broken || :Enzyme in broken | ||||||||
if forward_broken | ||||||||
@test_broken( | ||||||||
collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, | ||||||||
rtol = rtol, | ||||||||
atol = atol | ||||||||
) | ||||||||
else | ||||||||
@test( | ||||||||
collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, | ||||||||
rtol = rtol, | ||||||||
atol = atol | ||||||||
) | ||||||||
end | ||||||||
if reverse_broken | ||||||||
@test_broken( | ||||||||
Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol | ||||||||
) | ||||||||
else | ||||||||
@test( | ||||||||
Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol | ||||||||
) | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" | ||||||||
rule = Tapir.build_rrule(f, x; safety_on=false) | ||||||||
if :tapir in broken | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change has nothing to do with this PR, I just spotted the typo while working on this PR and didn't feel like making a separate one character PR.