Investigate fixing AD issues #151
base: new_API
Conversation
Definitely `AutoZygote` is required on the outside instead of `AutoReverseDiff`. Investigating what's going on with Enzyme internally:

```julia
using PyCall
certifi = PyCall.pyimport("certifi")
ENV["SSL_CERT_FILE"] = certifi.where()

using ODINN

working_dir = joinpath(homedir(), "OGGM/ODINN_tests")
MB = true
fast = true
atol = 2.0

params = Parameters(OGGM = OGGMparameters(working_dir=working_dir,
                                          multiprocessing=true),
                    simulation = SimulationParameters(working_dir=working_dir,
                                                      use_MB=MB,
                                                      velocities=true,
                                                      tspan=(2010.0, 2015.0),
                                                      multiprocessing=false,
                                                      workers=5,
                                                      test_mode=true),
                    hyper = Hyperparameters(batch_size=4,
                                            epochs=4,
                                            optimizer=ODINN.ADAM(0.01)),
                    UDE = UDEparameters(target = "A",
                                        sensealg = ODINN.GaussAdjoint(autojacvec=ODINN.EnzymeVJP())))

rgi_ids = ["RGI60-11.03638"]

model = Model(iceflow = SIA2Dmodel(params),
              mass_balance = TImodel1(params; DDF=6.0/1000.0, acc_factor=1.2/1000.0),
              machine_learning = NN(params))

glaciers = initialize_glaciers(rgi_ids, params)

functional_inversion = FunctionalInversion(model, glaciers, params)

@time run!(functional_inversion)
```
SciML/SciMLSensitivity.jl#1060 is required to use GaussAdjoint for this.
Requires ODINN-SciML/Sleipnir.jl#55, ODINN-SciML/Huginn.jl#53, and Enzyme#main (which in turn requires SciML/OptimizationBase.jl#53). With that the callback differentiates, but it hits:

```
ERROR: AssertionError: Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Float64}} has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information
```

which is an Enzyme bug in broadcasting against a scalar. So we're close now.
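For reference, the failing pattern per the error message is a broadcast of a differentiated matrix against a plain scalar. A minimal sketch of the kind of rewrite that sidesteps the fused `Broadcasted` object (the function and variable names here are illustrative, not from the ODINN code):

```julia
# The problematic shape, per the error message, is roughly:
#     B .= A ./ n    # Broadcasted{..., typeof(/), Tuple{Matrix{Float64}, Float64}}
# Replacing the fused broadcast with an explicit loop avoids handing Enzyme
# a Broadcasted object with mixed (active matrix / inactive scalar) arguments.
function scale_into!(B::Matrix{Float64}, A::Matrix{Float64}, n::Float64)
    inv_n = inv(n)                    # hoist the scalar out of the kernel
    @inbounds for i in eachindex(A, B)
        B[i] = A[i] * inv_n           # plain scalar loop, no Broadcasted
    end
    return B
end
```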
```diff
@@ -83,12 +82,12 @@ function build_D_features(H::Matrix, temp, ∇S)
     ∇S_flat = ∇S[inn1(H) .!= 0.0] # flatten
     H_flat = H[H .!= 0.0] # flatten
     T_flat = repeat(temp,length(H_flat))
-    X = Flux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix
+    X = Lux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix
```
`Lux` doesn't have a `normalise` function.
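For anyone hitting this, a minimal stand-in is easy to hand-roll. This sketch mimics what `Flux.normalise` does (standardise along a dimension, with a small `ϵ` for numerical stability), though Flux's exact defaults may differ:

```julia
using Statistics  # mean, std

# Drop-in replacement sketch: zero mean, unit std along `dims`.
function normalise(x::AbstractArray; dims=1, ϵ=1e-5)
    μ = mean(x; dims=dims)
    σ = std(x; dims=dims, mean=μ, corrected=false)
    return (x .- μ) ./ (σ .+ ϵ)
end
```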
```diff
 - `θ`: Neural network parameters
 """
 function NN(params::Sleipnir.Parameters;
-            architecture::Union{Flux.Chain, Nothing} = nothing,
+            architecture::Union{Lux.Chain, Nothing} = nothing,
             θ::Union{Vector{F}, Nothing} = nothing) where {F <: AbstractFloat}
```
Does this need to be a `Chain`, won't `AbstractExplicitLayer` work?
It will, I'm just making small changes to isolate the real AD issues though.
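For context, the suggested relaxation would look something like the sketch below (assuming Lux is loaded; `Chain` is a subtype of `AbstractExplicitLayer`, so any Lux layer would then be accepted):

```julia
function NN(params::Sleipnir.Parameters;
            architecture::Union{Lux.AbstractExplicitLayer, Nothing} = nothing,
            θ::Union{Vector{F}, Nothing} = nothing) where {F <: AbstractFloat}
    # ... body unchanged ...
end
```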
Upstreaming issues found here:
The last issue is the real showstopper for now; it's not clear to me that there is a good workaround, and it seems to just be a straight Enzyme bug to fix. The other two look a bit bespoke to PyCall and TimerOutputs: the latter is easy to work around, while the former is odd because the workaround currently doesn't work, seemingly due to some internals of PyCall.
OK great, thanks for pushing this through @ChrisRackauckas! 👍🏻 Any idea how far off a fix for the last Enzyme issue is? And what is the issue with PyCall? As far as I understood, we were bypassing the differentiation of all the Python code in the callback for now.
As discussed in the thread, that issue really isn't an Enzyme issue in its own right, but rather that a type instability exists around the broadcast, which makes it illegal for certain code to be represented properly in Julia (if the type instability weren't there, Enzyme could write the derivative update totally fine). Recopying the backtrace here, since I think the better solution would just be fixing the type instability.
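A toy illustration of this class of problem (the struct names are hypothetical, not from ODINN): if a parameter struct has an untyped field, the element type of the broadcast cannot be inferred, and making the field concrete fixes it.

```julia
using Test  # for @inferred

struct BadParams
    scale            # untyped (::Any) field ⇒ x ./ p.scale cannot be inferred
end

struct GoodParams{T<:Real}
    scale::T         # concrete once constructed ⇒ the broadcast infers
end

f(x, p) = x ./ p.scale

@inferred f([1.0, 2.0], GoodParams(2.0))   # fine: return type is Vector{Float64}
# @inferred f([1.0, 2.0], BadParams(2.0))  # would throw: inferred return type is Any
```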
@ChrisRackauckas @JordiBolibar where is the inactive function you have for PyCall? Indeed, I would've expected an inactive marking to handle that. We could also teach PyCall object construction how to properly deal with allocations in an Enzyme rule, but the fact that you say there's an inactive marking that isn't triggering seems like the bigger issue.
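For readers following along, marking a function inactive for Enzyme goes through the `EnzymeRules.inactive` mechanism; a sketch (the wrapper name here is hypothetical, standing in for the real PyCall invocation):

```julia
import Enzyme  # assumes an Enzyme.jl version with the EnzymeRules API

# Hypothetical wrapper around the Python call we never want differentiated.
run_python_callback(args...) = nothing  # stand-in for the real PyCall invocation

# With this method defined, Enzyme treats calls to `run_python_callback`
# as constant and does not attempt to differentiate through them.
Enzyme.EnzymeRules.inactive(::typeof(run_python_callback), args...) = nothing
```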
Another error I ran into was Tullio support. @wsmoses is that known?
All tullio examples that folks have sent us are functional last we checked.
Open an issue?
As I mentioned in the Huginn PR, Tullio should be easily avoidable here, since we only use it in the out-of-place version of the function, which was needed for Zygote and ReverseDiff. Migrating to the in-place version of the function should avoid that.
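To make that migration concrete, here is a hedged sketch (the diffusivity expression is illustrative, not the actual SIA2D formula): an out-of-place `@tullio` definition becomes a plain mutating loop, which needs no Tullio-specific AD support.

```julia
# Out-of-place version (requires Tullio and its AD rules):
#     @tullio D[i, j] := Γ * H[i, j]^n * ∇S[i, j]
# In-place equivalent, friendly to the Enzyme/adjoint path:
function diffusivity!(D, H, ∇S, Γ, n)
    @inbounds for j in axes(H, 2), i in axes(H, 1)
        D[i, j] = Γ * H[i, j]^n * ∇S[i, j]
    end
    return D
end
```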
Training iceflow UDE...
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 1
└ @ MLUtils ~/.julia/packages/MLUtils/LmmaQ/src/batchview.jl:95
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 1
└ @ MLUtils ~/.julia/packages/MLUtils/LmmaQ/src/batchview.jl:95
Before solving ODE problem
ODE problem solved for 1
over here for 1
simulation finished for 1
Batch 1 finished!
All batches finished
Loss computed: 13.506396213311314
┌ Warning: Automatic dt set the starting dt as NaN, causing instability. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/F67Rp/src/solve.jl:591
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/integrator_interface.jl:593
ERROR: DimensionMismatch: array could not be broadcast to match destination
Stacktrace:
[1] check_broadcast_shape
@ ./broadcast.jl:579 [inlined]
[2] check_broadcast_axes
@ ./broadcast.jl:582 [inlined]
[3] instantiate
@ ./broadcast.jl:309 [inlined]
[4] materialize!
@ ./broadcast.jl:914 [inlined]
[5] materialize!
@ ./broadcast.jl:911 [inlined]
[6] vec_pjac!(out::ComponentArrays.ComponentVector{…}, λ::Vector{…}, y::Matrix{…}, t::Float64, S::SciMLSensitivity.AdjointSensitivityIntegrand{…})
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:288
[7] AdjointSensitivityIntegrand
@ ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:309 [inlined]
[8] (::SciMLSensitivity.AdjointSensitivityIntegrand{…})(t::Float64)
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:321
[9] evalrule(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, a::Float64, b::Float64, x::Vector{…}, w::Vector{…}, gw::Vector{…}, nrm::typeof(LinearAlgebra.norm))
@ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/evalrule.jl:0
[10] #6
@ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:15 [inlined]
[11] ntuple
@ ./ntuple.jl:48 [inlined]
[12] do_quadgk(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…}, n::Int64, atol::Float64, rtol::Float64, maxevals::Int64, nrm::typeof(LinearAlgebra.norm), segbuf::Nothing)
@ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:13
[13] #50
@ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:253 [inlined]
[14] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…})
@ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:145
[15] quadgk(::SciMLSensitivity.AdjointSensitivityIntegrand{…}, ::Float64, ::Vararg{…}; atol::Float64, rtol::Float64, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing)
@ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:252
[16] _adjoint_sensitivities(sol::SciMLBase.ODESolution{…}, sensealg::SciMLSensitivity.QuadratureAdjoint{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::SciMLBase.CallbackSet{…}, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:382
[17] _adjoint_sensitivities
@ ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:324 [inlined]
[18] #adjoint_sensitivities#63
@ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:383 [inlined]
[19] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::SciMLBase.ODESolution{…})
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:556
[20] ZBack
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
[21] (::Zygote.var"#kw_zpullback#53"{…})(dy::SciMLBase.ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
[22] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[23] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[24] #solve#51
@ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003 [inlined]
[25] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[26] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[27] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[28] solve
@ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:993 [inlined]
[29] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[30] #simulate_iceflow_UDE!#32
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:187 [inlined]
[31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[32] simulate_iceflow_UDE!
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:171 [inlined]
[33] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[34] batch_iceflow_UDE
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:146 [inlined]
[35] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[36] #28
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
[37] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[38] #680
@ ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201 [inlined]
[39] #235
@ ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157 [inlined]
[40] (::Base.var"#1023#1028"{Distributed.var"#235#236"{…}})(r::Base.RefValue{Any}, args::Tuple{Tuple{…}})
@ Base ./asyncmap.jl:94
[41] (::Base.var"#1039#1040"{Base.var"#1023#1028"{Distributed.var"#235#236"{Zygote.var"#680#685"}}, Channel{Any}, Nothing})()
@ Base ./asyncmap.jl:228
Stacktrace:
[1] (::Base.var"#1033#1035")(x::Task)
@ Base ./asyncmap.jl:171
[2] foreach(f::Base.var"#1033#1035", itr::Vector{Any})
@ Base ./abstractarray.jl:3097
[3] maptwice(wrapped_f::Function, chnl::Channel{Any}, worker_tasks::Vector{Any}, c::Base.Iterators.Zip{Tuple{…}})
@ Base ./asyncmap.jl:171
[4] wrap_n_exec_twice
@ ./asyncmap.jl:147 [inlined]
[5] #async_usemap#1018
@ ./asyncmap.jl:97 [inlined]
[6] async_usemap
@ ./asyncmap.jl:78 [inlined]
[7] #asyncmap#1017
@ ./asyncmap.jl:75 [inlined]
[8] asyncmap
@ ./asyncmap.jl:74 [inlined]
[9] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{…}; distributed::Bool, batch_size::Int64, on_error::Nothing, retry_delays::Vector{…}, retry_check::Nothing)
@ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:126
[10] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}})
@ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:99
[11] pmap(f::Function, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}}; kwargs::@Kwargs{})
@ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156
[12] pmap
@ ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156 [inlined]
[13] pmap(f::Function, c1::Vector{Tuple{…}}, c::Vector{@NamedTuple{…}})
@ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157
[14] (::Zygote.var"#map_back#682"{ODINN.var"#28#29"{…}, 1, Tuple{…}, Tuple{…}, Vector{…}})(Δ::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201
[15] (::Zygote.var"#2861#back#688"{Zygote.var"#map_back#682"{…}})(Δ::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[16] predict_iceflow!
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
[17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[18] loss_iceflow
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:58 [inlined]
[19] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[20] #25
@ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:31 [inlined]
[21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[22] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[23] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[24] OptimizationFunction
@ ~/.julia/packages/SciMLBase/sakPO/src/scimlfunctions.jl:3762 [inlined]
[25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[26] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[27] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[28] #37
@ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:90 [inlined]
[29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[30] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[31] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[32] #39
@ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93 [inlined]
[33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[34] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
[35] gradient(f::Function, args::ComponentArrays.ComponentVector{Float64, Vector{Float64}, Tuple{ComponentArrays.Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
[36] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentArrays.ComponentVector{…}, ::ComponentArrays.ComponentVector{…}, ::Vector{…}, ::Vararg{…})
@ OptimizationZygoteExt ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93
[37] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[38] macro expansion
@ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
[39] (::OptimizationOptimisers.var"#12#13"{OptimizationBase.OptimizationCache{…}, ComponentArrays.ComponentVector{…}})()
@ OptimizationOptimisers ~/.julia/packages/Optimization/jWtfU/src/utils.jl:29
[40] maybe_with_logger(f::OptimizationOptimisers.var"#12#13"{…}, logger::Nothing)
@ Optimization ~/.julia/packages/Optimization/jWtfU/src/utils.jl:7
[41] macro expansion
@ ~/.julia/packages/Optimization/jWtfU/src/utils.jl:28 [inlined]
[42] __solve(cache::OptimizationBase.OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[43] solve!(cache::OptimizationBase.OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:188
[44] solve(prob::SciMLBase.OptimizationProblem{…}, alg::Optimisers.Adam, args::IterTools.NCycle{…}; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:96
[45] train_UDE!(simulation::FunctionalInversion)
@ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:43
[46] run!(simulation::FunctionalInversion)
@ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:11
[47] macro expansion
@ ./timing.jl:279 [inlined]
[48] top-level scope
@ ~/Desktop/test.jl:76
Some type information was truncated. Use `show(err)` to see complete types.

Okay, the autodiff is working through `SIA2D!`, but in the adjoint the `dt` goes to `NaN`, which crashes it, and I'll need to investigate that.
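Before reaching for an AD-level checker, a cheap way to localize where the first `NaN` appears in the state is a small guard dropped into the RHS (a generic sketch, not ODINN code):

```julia
# Returns the array unchanged if all entries are finite; otherwise reports
# the first offending index so the NaN can be traced to its source.
function assert_finite(label::AbstractString, A::AbstractArray)
    idx = findfirst(x -> !isfinite(x), A)
    idx === nothing && return A
    error("non-finite value at index $idx in $label: $(A[idx])")
end
```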
Oh, if the issue is a NaN: Enzyme (on CPU only, presently) has a NaN checker which will throw a backtrace the first time a NaN is generated for a derivative; there is a flag to set to enable it. Clearly this needs more docs (PRs welcome, of course!).