Benchmarking some very simple Flux models #361
Comments
Thanks for this -- not premature at all. My guess is that Zygote (I assume that's what Flux is using by default still?) has rules for everything here, but Mooncake is deriving its own rule in some performance-critical case. I'll have to take a look at a profile and see what rule(s) I'm missing that are causing poor performance.
Yes to Zygote. The first should not need very exotic rules, matrix multiplication and Maybe worth including an even simpler Flux-free example:
julia> @btime Flux.gradient(x -> sum(abs2, x), $img)[1][1:3] # really Zygote
min 20.875 μs, mean 48.610 μs (4 allocations, 392.19 KiB)
3-element Vector{Float32}:
1.7552974
1.0280211
0.57361186
julia> let f = x -> sum(abs2, x)
backend = DifferentiationInterface.AutoMooncake(; config=nothing)
prep = @btime DifferentiationInterface.prepare_gradient($f, $backend, $img)
grad = @btime DifferentiationInterface.gradient($f, $prep, $backend, $img)
grad[1:3]
end
min 2.475 ms, mean 3.286 ms (6687 allocations, 12.59 MiB)
min 812.959 μs, mean 866.172 μs (5 allocations, 392.44 KiB)
3-element Vector{Float32}:
1.7552974
1.0280211
0.57361186
and Enzyme, without/with pre-allocating space for the gradient:
julia> @btime Enzyme.gradient(Reverse, x -> sum(abs2, x), $img)[1][1:3]
min 77.083 μs, mean 103.718 μs (4 allocations, 392.19 KiB)
3-element Vector{Float32}:
1.7552974
1.0280211
0.57361186
julia> @btime Enzyme.autodiff(Reverse, x -> sum(abs2, x), Active, $(Duplicated(img, zero.(img))));
min 75.625 μs, mean 82.092 μs (0 allocations)
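For reference, the three values printed above are just twice the first three entries of the input, since the gradient of sum(abs2, x) is 2x. A minimal sketch of the pre-allocated pattern in plain Julia, with no AD package involved; `grad_sum_abs2!` is a hypothetical helper name used only for illustration:

```julia
# Hand-written gradient of sum(abs2, x), accumulated into a buffer that the
# caller allocates once. `grad_sum_abs2!` is a hypothetical helper, not part
# of Zygote, Enzyme, or Mooncake.
function grad_sum_abs2!(dx::AbstractArray, x::AbstractArray)
    dx .+= 2 .* x        # d/dx_i of sum_j x_j^2 is 2 * x_i
    return dx
end

img = rand(Float32, 28, 28, 1, 128)   # same shape as in the thread
dx = zero(img)                         # gradient storage, allocated up front
grad_sum_abs2!(dx, img)

dx[1:3]                                # equals 2 .* img[1:3]
```

Because the helper accumulates with `.+=`, the buffer has to be zeroed before reuse, which mirrors how the `Duplicated(img, zero.(img))` call above supplies a zeroed shadow array.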
Yeah, I agree that it ought not to require anything particularly exotic. Mooncake has rules for In terms of the performance you're seeing from Mooncake here, there are two things happening in this example:
To get a sense of the amount of time taken allocating memory:
julia> using BenchmarkTools, Mooncake
julia> img = rand(Float32, 28, 28, 1, 128);
julia> f(x) = sum(abs2, x)
f (generic function with 1 method)
julia> rule = build_rrule(f, img);
julia> @benchmark Mooncake.value_and_gradient!!($rule, f, $img)
BenchmarkTools.Trial: 6059 samples with 1 evaluation.
Range (min … max): 694.000 μs … 20.780 ms ┊ GC (min … max): 0.00% … 95.46%
Time (median): 750.042 μs ┊ GC (median): 0.00%
Time (mean ± σ): 823.888 μs ± 549.142 μs ┊ GC (mean ± σ): 2.50% ± 3.75%
▆▆█▂ ▂▁
▆▆████▇▅▄▃▃▃▂▂▂▂▂▃▅▅███▇▅▄▄▃▃▂▂▂▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▁▂▂ ▃
694 μs Histogram: frequency by time 1.17 ms <
Memory estimate: 392.41 KiB, allocs estimate: 5.
julia> @benchmark Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)))
BenchmarkTools.Trial: 6520 samples with 1 evaluation.
Range (min … max): 689.917 μs … 22.167 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 709.292 μs ┊ GC (median): 0.00%
Time (mean ± σ): 765.979 μs ± 278.923 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
█▆▆▅▅▄▂▁ ▁▄▅▅▅▄▄▃▃▃▂▁ ▂
████████▇▆▅▃▄▄▃▅▄▄▄▂▃▃▅▂▅▅▅▆██████████████▇▇▇▆▅▆▅▆▆▆▆▇▆▇▆▇▆▇▆ █
690 μs Histogram: log(frequency) by time 972 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
If I add the following rule to Mooncake:
julia> function Mooncake.rrule!!(::CoDual{typeof(sum)}, ::CoDual{typeof(abs2)}, x::CoDual{<:Array{P}}) where {P<:IEEEFloat}
function sum_abs2_pb(dy::P)
x.dx .+= (2 * dy) .* x.x
return NoRData(), NoRData(), NoRData()
end
return zero_fcodual(sum(abs2, x.x)), sum_abs2_pb
end
julia> Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(sum), typeof(abs2), Array{<:IEEEFloat}}
and recompute
julia> @benchmark Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 44.166 μs … 370.417 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 45.750 μs ┊ GC (median): 0.00%
Time (mean ± σ): 49.663 μs ± 9.479 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▅▅█▆▃ ▁▂▂▃▄▄▄▄▃▂▁ ▁
██████▇▇▇▇▇▆██████████████▇▇█▇▇▆▆▆▆▇▆▆▆▆▆▅▆▅▆▆▅▅▆▆▆▅▄▅▅▅▅▅▄▅ █
44.2 μs Histogram: log(frequency) by time 76.4 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
which shows that Mooncake is leaving quite a lot on the table in this case, and highlights just how fast Enzyme is in this situation -- getting to within a factor of two of hand-written performance without writing a rule is quite impressive. That being said, the primal only takes a few microseconds:
julia> @benchmark f($img)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
Range (min … max): 6.892 μs … 31.825 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 6.917 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.236 μs ± 855.013 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█ ▄▂ ▂▁▁▁▁▂▁▁▁ ▁
█▆██▅▅▅▆▆▅▅▅▃▄▅▅▇▇▇▇▇▆▇▆▇▇▇████████████▅▆▅▆▆▅▆▄▄▅▄▃▄▃▅▃▅▅▅▄ █
6.89 μs Histogram: log(frequency) by time 9.49 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
so I suspect that even my hand-written rule isn't optimal. I wonder if
What's the vision here? Do you think that, long-term, operations like For
julia> using BenchmarkTools, Mooncake, Zygote
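For readers following the hand-written rule in this comment, the pullback line `x.dx .+= (2 * dy) .* x.x` is the usual reverse-mode update for a sum of squares. Writing the incoming cotangent `dy` as \bar{y} and the accumulated gradient `x.dx` as \bar{x} (notation assumed here, not taken from the thread):

```latex
y = \sum_i x_i^2,
\qquad
\frac{\partial y}{\partial x_i} = 2 x_i,
\qquad
\bar{x}_i \mathrel{+}= 2\,\bar{y}\,x_i .
```

The update reads every element of x once and writes every element of the gradient once, so it does strictly more memory traffic than the primal, which only reads x.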
julia> img = rand(Float32, 28, 28, 1, 128);
julia> f(x) = sum(abs2, x)
f (generic function with 1 method)
julia> @btime f($img);
5.667 μs (0 allocations: 0 bytes)
julia> @btime Zygote.gradient(f, $img);
21.292 μs (3 allocations: 392.09 KiB)
julia> rule = build_rrule(f, img);
julia> @btime Mooncake.value_and_gradient!!($rule, f, $img);
630.542 μs (5 allocations: 392.41 KiB)
julia> @btime Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)));
621.708 μs (0 allocations: 0 bytes)
julia> function Mooncake.rrule!!(::Mooncake.CoDual{typeof(sum)}, ::Mooncake.CoDual{typeof(abs2)}, x::Mooncake.CoDual{<:Array{P}}) where {P<:Mooncake.IEEEFloat}
function sum_abs2_pb(dy::P)
x.dx .+= (2 * dy) .* x.x
return NoRData(), NoRData(), NoRData()
end
return Mooncake.zero_fcodual(sum(abs2, x.x)), sum_abs2_pb
end
julia> Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(sum), typeof(abs2), Array{<:Mooncake.IEEEFloat}}
julia> rule2 = build_rrule(f, img);
julia> @btime Mooncake.__value_and_gradient!!($rule2, Mooncake.zero_codual(f), $(Mooncake.zero_codual(img)));
37.708 μs (0 allocations: 0 bytes) # on Julia 1.10, I got 22.375 μs
(jl_MaAIiw) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_MaAIiw/Project.toml`
[da2b9cff] Mooncake v0.4.36
julia> VERSION
v"1.11.0"
Ok, I tried it. With a slightly harder function than Here I think Zygote is using ForwardDiff, because that's what it does. (I don't know whether Enzyme & Mooncake switch to forward mode for broadcasting.)
using Zygote, Enzyme, Mooncake
img = rand(Float32, 28, 28, 1, 128);
red(x) = sum(x -> atan(x, 2f0), x); # sum(f,x) which is certain not to have an existing rule
red2(x) = sum(atan.(x, 2f0)); # implementation using broadcast instead
let
g1 = @btime Zygote.gradient(red, $img) # red(x)=sum(sin,x) allocates less memory here, 1.15 MiB
g2 = @btime Zygote.gradient(red2, $img) # also here, 1.53 MiB
g2[1][1:3]
end
# min 510.625 μs, mean 647.758 μs (39 allocations, 1.92 MiB)
# min 645.250 μs, mean 855.522 μs (52 allocations, 2.30 MiB)
let
g1 = @btime Enzyme.gradient(Reverse, red, $img) # this is the only large change with red(x)=sum(sin,x) instead, -> 290.916 μs
g2 = @btime Enzyme.gradient(Reverse, red2, $img)
g2[1][1:3]
end
# min 77.625 μs, mean 104.632 μs (3 allocations, 392.11 KiB)
# min 939.375 μs, mean 1.051 ms (8 allocations, 1.15 MiB)
let
rule = build_rrule(red, img); # this step min 551.914 ns, mean 603.759 ns (10 allocations, 488 bytes)
g1 = @btime Mooncake.value_and_gradient!!($rule, $red, $img);
rule2 = build_rrule(red2, img);
g2 = @btime Mooncake.value_and_gradient!!($rule2, $red2, $img);
g2[2][2][1:3]
end
# min 926.291 μs, mean 1.023 ms (5 allocations, 392.44 KiB)
# min 2.521 ms, mean 2.672 ms (13 allocations, 1.15 MiB)
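As a sanity check on what all three packages should return for these two functions, here is a small sketch in plain Julia. It relies on the fact that the derivative of atan(x, y) with respect to x is y / (x^2 + y^2); `grad_red` and `fd_entry` are hypothetical helpers introduced only for this check:

```julia
# The two reductions from the benchmark above.
red(x)  = sum(x -> atan(x, 2f0), x)   # sum(f, x) with a closure
red2(x) = sum(atan.(x, 2f0))          # broadcast, then sum

# Hypothetical hand-written reference gradient:
# d/dx atan(x, 2) = 2 / (x^2 + 4), applied elementwise.
grad_red(x) = 2f0 ./ (x .^ 2 .+ 4f0)

# Central finite difference for a single entry, done in Float64 to limit
# rounding error. Used only to cross-check the analytic gradient.
function fd_entry(f, x, i; h = 1e-4)
    xp = Float64.(x)
    xm = Float64.(x)
    xp[i] += h
    xm[i] -= h
    return (f(xp) - f(xm)) / (2h)
end

x = rand(Float32, 28, 28, 1, 128)
g = grad_red(x)
fd = fd_entry(red, x, 1)   # should be close to g[1]
```

Both `red` and `red2` compute the same value, so they share this gradient; the timing differences in the thread come from how each package handles the closure versus the broadcast.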
Agree with the points you make here. Regarding rules -- yes, I definitely anticipate requiring many fewer rules in the short to medium term. Also agree re tracking performance -- if you take a look at a Mooncake PR (e.g. #362) you'll see a table of benchmark numbers that get automatically printed into the PR -- they track the AD time / primal time for Mooncake and some others for a range of functions. In particular you'll see results for various variants of
Long-term I would hope not, but in the short-term they're going to continue to need them (sadly).
Regarding the numbers you're seeing: I see similar numbers for Mooncake locally. Enzyme is doing something truly impressive to get
For context, Mooncake doesn't have any rules for higher order functions, and doesn't have any special tricks to make use of forwards-mode AD in situations where it might be advantageous to do so, so what you're seeing here is just Mooncake doing reverse-mode AD on Julia IR. That is, it's ADing this just with a rule for
Moving forwards, is there any chance of the Flux maintainers publishing a set of functions / models that they say "this is the core of the library, and is the bit that really needs to be fast"? In the short term I am quite interested in ensuring people can get good performance using Mooncake for fairly standard DL models, even if the performance drops off a bit for less standard stuff.
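The "AD time / primal time" ratio mentioned in this comment can be approximated by hand. Taking the minimum times quoted earlier in the thread, 689.917 μs over 6.892 μs gives a ratio of roughly 100 for Mooncake without the extra rule, and 44.166 μs over 6.892 μs gives roughly 6.4 with it. Below is a rough sketch of measuring such a ratio using only Base timing; `min_time` is a hypothetical helper, and `grad_f!` is a hand-written stand-in for whatever gradient function an AD package would produce, not Mooncake's actual benchmark code:

```julia
f(x) = sum(abs2, x)

# Hypothetical stand-in for an AD-generated gradient of `f`.
grad_f!(dx, x) = (dx .= 2 .* x; dx)

# Minimum wall-clock time over `n` runs, after one warm-up call so that
# compilation is not counted. `min_time` is an illustrative helper.
function min_time(g, args...; n = 100)
    g(args...)
    best = Inf
    for _ in 1:n
        t = @elapsed g(args...)
        best = min(best, t)
    end
    return best
end

x = rand(Float32, 28, 28, 1, 128)
dx = similar(x)

t_primal = min_time(f, x)
t_grad = min_time(grad_f!, dx, x)
ratio = t_grad / t_primal    # the AD time / primal time figure
```

`@elapsed` is much coarser than BenchmarkTools, so this is only suitable for a quick look; the numbers in the thread were all taken with proper benchmarking macros.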
Indeed, quicker than the primal for me too. Very clever but probably not the number we want.
We're not very systematic, but these two are the simplest cases in the model zoo. I have not tried any GPU models. FluxML/Flux.jl#2471 wants to make it super-easy to try Enzyme instead of Zygote. It would be nice to do something similar for Mooncake. Since both (I believe) are happiest when the gradient is allocated up front, making some comparably easy way to do that might be nice.
Correct -- preallocation of gradient storage will tend to give the best performance.
So the ideal Flux would be for something like Enzyme's Duplicated to be available:
Thinking about extensions etc, it would actually be simplest for Flux to own this. But perhaps that's weird. Maybe this is the wrong thread for such interface discussion, though. Maybe FluxML/Flux.jl#2471 is the right one?
I don't know whether it's premature to do so, but since we're thinking about how Flux interacts with AD, I tried out a few very simple cases. Exactly the same cases as in EnzymeAD/Enzyme.jl#2069 : CPU only, Float32, one model without calling NNlib.conv and one with.
Versions: