diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 77c0b021..781a8b6f 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -35,7 +35,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if :EnzymeReverse in broken @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test_broken( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol @@ -43,7 +44,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) elseif :EnzymeForward in broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol @@ -51,7 +53,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) elseif :Enzyme in broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test_broken( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol @@ -59,7 +62,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) else @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol