All I am trying to do is to calculate the hessian of some function.
I've identified three problems:
1) A "Current scope: ; Function Attrs: alwaysinline mustprogress willreturn" error when trying to run F1, which goes away when the loop is replaced with a nested one (F2).
F1 (fails):
for i = 1:400
    ...
F2 (nested loop, works):
for i = 1:200
    for j = 1:2
        ...
The problem also doesn't happen with fewer iterations (fewer than about 40; see F1_less).
I also tested the nested version with more than 10000 outer iterations and it worked; however, if I increase the inner loop above about 20 iterations I get the same error.
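As a sanity check on the primal side, here is a minimal sketch in plain Julia (no Enzyme involved) showing that the single 400-iteration loop and the nested 200x2 loop perform the same sequence of operations, so any difference in behaviour should come from differentiation or compilation rather than from the math. The names `f_single` and `f_nested` are made up for illustration and are not in the gist.

```julia
# Primal-only comparison of the two loop shapes; Base Julia only.
# `f_single` and `f_nested` are hypothetical names, not from the gist.

# Single loop, as in F1: cos(x) followed by n more applications.
function f_single(x; n = 400)
    z1, z2 = cos(x[1]), cos(x[2])
    for _ in 1:n
        z1, z2 = cos(z1), cos(z2)
    end
    return z1 * z1 + z2 * z2
end

# Nested loop, as in F2: outer * inner applications after the first cos.
function f_nested(x; outer = 200, inner = 2)
    z1, z2 = cos(x[1]), cos(x[2])
    for _ in 1:outer
        for _ in 1:inner
            z1, z2 = cos(z1), cos(z2)
        end
    end
    return z1 * z1 + z2 * z2
end

x = [1.0, 1.0]
@show f_single(x)
@show f_nested(x)
@show f_single(x) == f_nested(x)
```

Both functions apply cos 401 times to each component in the same order, so the primal values should agree.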
2) The F3 definition gives the wrong result. It applies cos.(x) in a for loop. This doesn't happen if I manually nest the function calls (F4) or manually unroll the loop (F5)*.
Additionally, I found that if I run F2 (which gives the correct result; this is not the F2 from problem (1), I mixed up the naming, but I think it is clear in the code) right after F3 (wrong) WITHOUT zeroing dx, dy and hess, F2 also gives the wrong result (numerically the same as F3). This doesn't happen if I zero the derivative arrays after F3, or if I run F2 twice in a row (without zeroing the arrays).
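Since the result depends on the state of the derivative arrays, one way to rule out stale values between runs is to reset them through a single helper before every call. This is only a sketch: `reset_shadows!` is a hypothetical name, and it assumes the same layout as the code in this issue (dx zeroed, dy seeded with 1.0, hess a tuple of vectors).

```julia
# Hypothetical helper (not from the gist): put the derivative buffers back
# into the state used before each autodiff call in this issue.
function reset_shadows!(dx, dy, hess)
    fill!(dx, 0.0)                      # gradient accumulator starts at zero
    fill!(dy, 1.0)                      # reverse-mode seed for the output y
    foreach(h -> fill!(h, 0.0), hess)   # each Hessian row buffer starts at zero
    return nothing
end

dx = [0.3, -0.1]                        # pretend these are left over from a previous run
dy = [0.0]
hess = ([1.0, 2.0], [3.0, 4.0])

reset_shadows!(dx, dy, hess)
@show dx dy hess
```

Calling this before each variant (F2, F3, ...) makes the runs independent of the order in which they are executed.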
*F5 has its own problem:
3) F5 (loop unrolled) doesn't compile if I include both z[1] and z[2] in the output. F52, which doesn't compile, has y[1] = z[1]*z[1] + z[2]*z[2]; F5, which does compile, has y[1] = z[1]*z[1].
Also F5 takes much longer to compile than anything else.
I'll also throw this question in here: note the number of allocations when using cos.(x) (F3 and F5), and to a lesser extent in F4. Is this expected, a bug, or user error? Would the correct way be to do it as in F1, with something like SVector{N,T}(cos(zz) for zz in z)?
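For context on where the allocations can come from on the primal side: `z = cos.(z)` on a Vector creates a new array at every step, while an in-place broadcast or a tuple-based `map` does not need a fresh array each time. The sketch below is plain Julia and has not been run through Enzyme, so it says nothing about how each form differentiates; the function names are made up for illustration.

```julia
# Three primal-only ways to iterate cos; hypothetical names, Base Julia only.

# Allocating: every `cos.(z)` builds a fresh Vector.
function iterate_alloc(x, n)
    z = cos.(x)
    for _ in 1:n
        z = cos.(z)
    end
    return z
end

# In place: a single buffer `z` is reused by the fused broadcast.
function iterate_inplace!(z, x, n)
    z .= cos.(x)
    for _ in 1:n
        z .= cos.(z)
    end
    return z
end

# Tuple version: `map` over a tuple returns a tuple.
function iterate_tuple(x::NTuple{N,T}, n) where {N,T}
    z = map(cos, x)
    for _ in 1:n
        z = map(cos, z)
    end
    return z
end

x = [1.0, 1.0]
@show iterate_alloc(x, 20)
@show iterate_inplace!(zeros(2), x, 20)
@show iterate_tuple((1.0, 1.0), 20)
```

All three compute the same values; they differ only in how intermediate results are stored.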
To verify the values I've used Casadi in Python:
import casadi as ca

x = ca.SX.sym("x",2,1)
def f(x):
    y = ca.cos(x)
    for i in range(20):
        y = ca.cos(y)
    return y[0]**2 + y[1]**2
f = ca.Function("f",[x],ca.cse(ca.hessian(f(x),x)),["x"],["y","y2"],{"jit":True,"jit_options":{"flags":["-O3", "-march=native"]}})
#time_function(f,s=3)
f(1)
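A second, dependency-free cross-check is a central-difference Hessian in plain Julia. The sketch below mirrors the Casadi snippet (cos(x) plus 20 loop iterations, so 21 applications in total) and does not use Enzyme at all. `f_ref` and `fd_hessian` are hypothetical names, and the step size h = 1e-4 is an assumption that trades truncation error against round-off.

```julia
# Finite-difference reference; Base Julia only, names are hypothetical.

# Objective with cos(x) followed by n further applications.
function f_ref(x; n = 20)
    z1, z2 = cos(x[1]), cos(x[2])
    for _ in 1:n
        z1, z2 = cos(z1), cos(z2)
    end
    return z1 * z1 + z2 * z2
end

# Central-difference approximation of the Hessian of f at x.
function fd_hessian(f, x; h = 1e-4)
    n = length(x)
    H = zeros(n, n)
    for i in 1:n, j in 1:n
        ei = zeros(n); ei[i] = h   # step along coordinate i
        ej = zeros(n); ej[j] = h   # step along coordinate j
        H[i, j] = (f(x + ei + ej) - f(x + ei - ej) -
                   f(x - ei + ej) + f(x - ei - ej)) / (4 * h^2)
    end
    return H
end

H = fd_hessian(f_ref, [1.0, 1.0])
@show H
```

Because the objective is separable in x[1] and x[2], the off-diagonal entries should come out close to zero, which gives a quick plausibility check on any Hessian reported by an AD tool.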
Just for fun, I benchmarked the above Casadi function (but with a 400-iteration loop): mean 21.562 us; standard deviation 1.006 us; min 21.219 us; max 54.193 us,
vs ~9.3 us with Enzyme. Pretty cool.
For completeness here's the code again:
using Enzyme
using StaticArrays
using BenchmarkTools
function cos_vec(z::SVector{N,T}) where {N,T}
    return SVector{N,T}(cos(zz) for zz in z)
end
function cos_vec(z::NTuple{N,T}) where {N,T}
    return NTuple{N,T}(cos(zz) for zz in z)
end
function cos_vec(z::AbstractArray{T}) where {T}
    return T[cos(zz) for zz in z]
end
function f1!(x, y)
    z1,z2 = cos_vec((x[1],x[2]))
    for i = 1:400
        z1,z2 = cos_vec((z1,z2))
    end
    y[1] = z1*z1 + z2*z2
    return nothing
end
function f1_less!(x, y)
    z1,z2 = cos_vec((x[1],x[2]))
    for i = 1:20
        z1,z2 = cos_vec((z1,z2))
    end
    y[1] = z1*z1 + z2*z2
    return nothing
end
function f2!(x, y)
    z1,z2 = cos_vec((x[1],x[2]))
    for i = 1:200
        for j=1:2
            z1,z2 = cos_vec((z1,z2))
        end
    end
    y[1] = z1*z1 + z2*z2
    return nothing
end
function f2_less!(x, y)
    z1,z2 = cos_vec((x[1],x[2]))
    for i = 1:10
        for j=1:2
            z1,z2 = cos_vec((z1,z2))
        end
    end
    y[1] = z1*z1 + z2*z2
    return nothing
end
function grad!(x, dx, y, dy,f)
    Enzyme.autodiff(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy),)
    nothing
end
x = [1.0, 1.0]
y = Vector{Float64}(undef, 1)
dx = [0.0, 0.0]
dy = [1.0]
hess = ([0.0, 0.0], [0.0, 0.0])
F5 (note I changed the output to be only z[1]^2, works):
function f5!(x, y)
z = cos.(x)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
y[1] = z[1]*z[1]# + z[2]*z[2]
return nothing
end
dx = [0.0, 0.0]
dy = [1.0]
vx = ([1.0, 0.0], [0.0, 1.0])
hess = ([0.0, 0.0], [0.0, 0.0])
Enzyme.autodiff(Enzyme.Forward, grad!, Enzyme.BatchDuplicated(x, vx), Enzyme.BatchDuplicated(dx, hess), Const(y), Const(dy),Const(f5!));
@show hess;
@show dx;
hess = ([0.000270064862399231, 0.0], [0.0, 0.0])
dx = [-0.00036669251092078, 0.0]
BenchmarkTools.Trial: 10000 samples with 8 evaluations.
Range (min … max): 3.996 μs … 2.585 ms ┊ GC (min … max): 0.00% … 99.46%
Time (median): 4.877 μs ┊ GC (median): 0.00%
Time (mean ± σ): 6.813 μs ± 38.121 μs ┊ GC (mean ± σ): 10.71% ± 1.98%
▆ █▃
▄█▄▃▄███▅▂▂▂▂▁▁▁▁▁▁▂▂▁▂▂▂▂▃▃▃▃▃▃▃▂▂▃▄▅▅▅▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
4 μs Histogram: frequency by time 10.3 μs <
Memory estimate: 9.75 KiB, allocs estimate: 250.
F5_2 (I changed back the output to z[1]^2 + z[2]^2, compilation fails):
function f52!(x, y)
z = cos.(x)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
z = cos.(z)
y[1] = z[1]*z[1] + z[2]*z[2]
return nothing
end
dx = [0.0, 0.0]
dy = [1.0]
vx = ([1.0, 0.0], [0.0, 1.0])
hess = ([0.0, 0.0], [0.0, 0.0])
Enzyme.autodiff(Enzyme.Forward, grad!, Enzyme.BatchDuplicated(x, vx), Enzyme.BatchDuplicated(dx, hess), Const(y), Const(dy),Const(f52!));
@show hess;
@show dx;
EDIT:
Julia 1.11.1
Enzyme v0.13.13
StaticArrays v1.9.8
6.11.6-2-cachyos
AMD Ryzen 7 8845hs
Newbie in Julia and Enzyme here so forgive me if it's a skill issue.
First here's the gist with the problem documented: https://gist.github.com/ernestds/8dd881253ad7b49617da57aa22e61548
F1:
F1 with fewer iterations (F1_less): works, with the correct result
Ground truth from Casadi:
Problem 2:
F2a (not the same as the F2 from problem 1, because I mixed up the naming along the way): result is correct
F3: result is wrong
F2 after F3: interestingly, the result of F2 also goes wrong (coinciding with F3) if I run it directly after F3 without zeroing dx, dy and hess
Some other tries:
F4: works fine
F5 (output changed to only z[1]^2): works
F5_2 (output changed back to z[1]^2 + z[2]^2): compilation fails