Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Investigate fixing AD issues #151

Open
wants to merge 5 commits into
base: new_API
Choose a base branch
from

Conversation

ChrisRackauckas
Copy link
Member

Definitely AutoZygote is required on the outside instead of AutoReverseDiff. Investigating what's going on with Enzyme in ternally

using PyCall
certifi = PyCall.pyimport("certifi")
ENV["SSL_CERT_FILE"] = certifi.where()

using ODINN

working_dir = joinpath(homedir(), "OGGM/ODINN_tests")
MB = true
fast = true
atol = 2.0

params = Parameters(OGGM = OGGMparameters(working_dir=working_dir,
                                              multiprocessing=true),
                        simulation = SimulationParameters(working_dir=working_dir,
                                                        use_MB=MB,
                                                        velocities=true,
                                                        tspan=(2010.0, 2015.0),
                                                        multiprocessing=false,
                                                        workers=5,
                                                        test_mode=true),
                        hyper = Hyperparameters(batch_size=4,
                                                epochs=4,
                                                optimizer=ODINN.ADAM(0.01)),
                        UDE = UDEparameters(target = "A",
                                            sensealg = ODINN.GaussAdjoint(autojacvec=ODINN.EnzymeVJP()))
    )

rgi_ids = ["RGI60-11.03638"]

model = Model(iceflow = SIA2Dmodel(params),
                mass_balance = mass_balance = TImodel1(params; DDF=6.0/1000.0, acc_factor=1.2/1000.0),
                machine_learning = NN(params))
glaciers = initialize_glaciers(rgi_ids, params)
functional_inversion = FunctionalInversion(model, glaciers, params)

@time run!(functional_inversion)

Definitely AutoZygote is required on the outside instead of AutoReverseDiff. Investigating what's going on with Enzyme in ternally

```julia
using PyCall
certifi = PyCall.pyimport("certifi")
ENV["SSL_CERT_FILE"] = certifi.where()

using ODINN

working_dir = joinpath(homedir(), "OGGM/ODINN_tests")
MB = true
fast = true
atol = 2.0

params = Parameters(OGGM = OGGMparameters(working_dir=working_dir,
                                              multiprocessing=true),
                        simulation = SimulationParameters(working_dir=working_dir,
                                                        use_MB=MB,
                                                        velocities=true,
                                                        tspan=(2010.0, 2015.0),
                                                        multiprocessing=false,
                                                        workers=5,
                                                        test_mode=true),
                        hyper = Hyperparameters(batch_size=4,
                                                epochs=4,
                                                optimizer=ODINN.ADAM(0.01)),
                        UDE = UDEparameters(target = "A",
                                            sensealg = ODINN.GaussAdjoint(autojacvec=ODINN.EnzymeVJP()))
    )

rgi_ids = ["RGI60-11.03638"]

model = Model(iceflow = SIA2Dmodel(params),
                mass_balance = mass_balance = TImodel1(params; DDF=6.0/1000.0, acc_factor=1.2/1000.0),
                machine_learning = NN(params))
glaciers = initialize_glaciers(rgi_ids, params)
functional_inversion = FunctionalInversion(model, glaciers, params)

@time run!(functional_inversion)
```
@ChrisRackauckas
Copy link
Member Author

SciML/SciMLSensitivity.jl#1060 is required to use GaussAdjoint for this.

@ChrisRackauckas
Copy link
Member Author

Requires ODINN-SciML/Sleipnir.jl#55, ODINN-SciML/Huginn.jl#53, and Enzyme#main (which in turn requires SciML/OptimizationBase.jl#53). With that the callback differentiates, but it hits:

ERROR: AssertionError: Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Float64}} has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information

which is an Enzyme bug in broadcast against a scalar. So we're close now,

@@ -83,12 +82,12 @@ function build_D_features(H::Matrix, temp, ∇S)
∇S_flat = ∇S[inn1(H) .!= 0.0] # flatten
H_flat = H[H .!= 0.0] # flatten
T_flat = repeat(temp,length(H_flat))
X = Flux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix
X = Lux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux doesn't have a normalise function

- `θ`: Neural network parameters
"""
function NN(params::Sleipnir.Parameters;
architecture::Union{Flux.Chain, Nothing} = nothing,
θ::Union{Vector{F}, Nothing} = nothing) where {F <: AbstractFloat}
architecture::Union{Lux.Chain, Nothing} = nothing,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a Chain, won't AbstractExplicitLayer work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will, I'm just making small changes to isolate the real AD issues though.

@ChrisRackauckas
Copy link
Member Author

Upstreaming issues found here:

The last issue is the real showstopper for now, it's not clear to me that there is a good workaround to that and it seems to me to just be a straight Enzyme bug to fix. The two look a bit bespoke to PyCall and TimerOutputs, the latter is easy to workaround, the former is weird because the workaround currently doesn't work seemingly due to some internals of PyCall.

@JordiBolibar
Copy link
Member

OK great, thanks for pushing this through @ChrisRackauckas! 👍🏻 Any idea how far is the solution to the last Enzyme issue?

What is the issue with PyCall? As far as I understood, I thought we were bypassing the differentiation of all Python code for now which is present in the callback.

@wsmoses
Copy link

wsmoses commented May 31, 2024

As discussed in the thread, that issue really isn't an Enzyme issue in its own right, but rather that a type unstability exists around the broadcast which makes it illegal for certain code to be represented properly in Julia (if the type instability weren't there Enzyme can write the derivative update totally fine).

recopying the backtrace here, since I think the better solution would just be fixing the type instability.

ERROR: AssertionError: Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Float64}} has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information
Stacktrace:
  [1] active_reg
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:525 [inlined]
  [2] active_reg
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:516 [inlined]
  [3] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Base.Broadcast.materialize), df::Nothing, primal_1::Base.Broadcast.Broadcasted{…}, shadow_1_1::Base.Broadcast.Broadcasted{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/F71IJ/src/rules/jitrules.jl:66
  [4] #SIA2D#9
    @ ~/.julia/dev/Huginn/src/models/iceflow/SIA2D/SIA2D_utils.jl:120
  [5] SIA2D
    @ ~/.julia/dev/Huginn/src/models/iceflow/SIA2D/SIA2D_utils.jl:89 [inlined]
  [6] SIA2D_UDE
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:231 [inlined]
  [7] SIA2D_UDE_closure
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:182 [inlined]
  [8] ODEFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:2296 [inlined]
  [9] #138
    @ ~/.julia/dev/SciMLSensitivity/src/adjoint_common.jl:450 [inlined]
 [10] diffejulia__138_21210_inner_1wrap
    @ ~/.julia/dev/SciMLSensitivity/src/adjoint_common.jl:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:5916 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:5566 [inlined]
 [13] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:5443 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/F71IJ/src/Enzyme.jl:291 [inlined]
 [15] _vecjacobian!(dλ::Vector{…}, y::Matrix{…}, λ::Vector{…}, p::ComponentArrays.ComponentVector{…}, t::Float64, S::SciMLSensitivity.ODEGaussAdjointSensitivityFunction{…}, isautojacvec::SciMLSensitivity.EnzymeVJP, dgrad::Nothing, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:714
 [16] #vecjacobian!#18
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:231 [inlined]
 [17] vecjacobian!
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:228 [inlined]
 [18] (::SciMLSensitivity.ODEGaussAdjointSensitivityFunction{…})(du::Vector{…}, u::Vector{…}, p::ComponentArrays.ComponentVector{…}, t::Float64)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/gauss_adjoint.jl:102
 [19] ODEFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:2296 [inlined]
 [20] ode_determine_initdt(u0::Vector{…}, t::Float64, tdir::Float64, dtmax::Float64, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::SciMLBase.ODEProblem{…}, integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/initdt.jl:53
 [21] auto_dt_reset!(integrator::OrdinaryDiffEq.ODEIntegrator)
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/integrators/integrator_interface.jl:474 [inlined]
 [22] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:580
 [23] __init(prob::SciMLBase.ODEProblem{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::SciMLBase.CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:533
 [24] __init(prob::Union{…}, alg::Union{…}, timeseries_init::Any, ts_init::Any, ks_init::Any, recompile::Type{…}) where recompile_flag (repeats 5 times)
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [25] #__solve#805
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:6 [inlined]
 [26] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:1 [inlined]
 [27] solve_call(_prob::SciMLBase.ODEProblem{…}, args::OrdinaryDiffEq.RDPK3Sp35{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:612
 [28] solve_call
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:569 [inlined]
 [29] #solve_up#53
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1080 [inlined]
 [30] solve_up
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1066 [inlined]
 [31] #solve#51
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1003 [inlined]
 [32] _adjoint_sensitivities(sol::SciMLBase.ODESolution{…}, sensealg::SciMLSensitivity.GaussAdjoint{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::SciMLBase.CallbackSet{…}, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/gauss_adjoint.jl:540
 [33] _adjoint_sensitivities
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/gauss_adjoint.jl:507 [inlined]
 [34] #adjoint_sensitivities#63
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:383 [inlined]
 [35] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::SciMLBase.ODESolution{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:556
 [36] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [37] (::Zygote.var"#kw_zpullback#53"{…})(dy::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [38] #291
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [39] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [40] #solve#51
    @ ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1003 [inlined]
 [41] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [42] #291
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [43] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [44] solve
    @ ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:993 [inlined]
 [45] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [46] #simulate_iceflow_UDE!#29
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:187 [inlined]
 [47] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [48] simulate_iceflow_UDE!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:171 [inlined]
 [49] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [50] batch_iceflow_UDE
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:146 [inlined]
 [51] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [52] #25
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [53] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [54] #680
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201 [inlined]
 [55] #235
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157 [inlined]
 [56] (::Base.var"#1023#1028"{Distributed.var"#235#236"{…}})(r::Base.RefValue{Any}, args::Tuple{Tuple{…}})
    @ Base ./asyncmap.jl:94
Stacktrace:
  [1] (::Base.var"#1033#1035")(x::Task)
    @ Base ./asyncmap.jl:171
  [2] foreach(f::Base.var"#1033#1035", itr::Vector{Any})
    @ Base ./abstractarray.jl:3094
  [3] maptwice(wrapped_f::Function, chnl::Channel{Any}, worker_tasks::Vector{Any}, c::Base.Iterators.Zip{Tuple{…}})
    @ Base ./asyncmap.jl:171
  [4] wrap_n_exec_twice
    @ ./asyncmap.jl:147 [inlined]
  [5] #async_usemap#1018
    @ ./asyncmap.jl:97 [inlined]
  [6] kwcall(::NamedTuple, ::typeof(Base.async_usemap), f::Any, c::Vararg{Any})
    @ Base ./asyncmap.jl:78 [inlined]
  [7] #asyncmap#1017
    @ ./asyncmap.jl:75 [inlined]
  [8] asyncmap
    @ ./asyncmap.jl:74 [inlined]
  [9] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{…}; distributed::Bool, batch_size::Int64, on_error::Nothing, retry_delays::Vector{…}, retry_check::Nothing)
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:126
 [10] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}})
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:99
 [11] pmap(f::Function, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}}; kwargs::@Kwargs{})
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156
 [12] pmap
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156 [inlined]
 [13] pmap(f::Function, c1::Vector{Tuple{…}}, c::Vector{@NamedTuple{…}})
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157
 [14] (::Zygote.var"#map_back#682"{ODINN.var"#25#26"{…}, 1, Tuple{…}, Tuple{…}, Vector{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201
 [15] (::Zygote.var"#2861#back#688"{Zygote.var"#map_back#682"{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [16] predict_iceflow!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] loss_iceflow
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:58 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] #22
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:31 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [24] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:3762 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] #37
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:90 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [31] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [32] #39
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [35] gradient(f::Function, args::ComponentArrays.ComponentVector{Float64, Vector{Float64}, Tuple{ComponentArrays.Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [36] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentArrays.ComponentVector{…}, ::ComponentArrays.ComponentVector{…}, ::Vector{…}, ::Vararg{…})
    @ OptimizationZygoteExt ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93
 [37] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [38] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [39] (::OptimizationOptimisers.var"#12#13"{OptimizationBase.OptimizationCache{…}, ComponentArrays.ComponentVector{…}})()
    @ OptimizationOptimisers ~/.julia/packages/Optimization/jWtfU/src/utils.jl:29
 [40] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging ./logging.jl:515
 [41] with_logger
    @ Base.CoreLogging ./logging.jl:627 [inlined]
 [42] maybe_with_logger(f::OptimizationOptimisers.var"#12#13"{…}, logger::LoggingExtras.TeeLogger{…})
    @ Optimization ~/.julia/packages/Optimization/jWtfU/src/utils.jl:7
 [43] macro expansion
    @ ~/.julia/packages/Optimization/jWtfU/src/utils.jl:28 [inlined]
 [44] __solve(cache::OptimizationBase.OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [45] solve!(cache::OptimizationBase.OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:188
 [46] solve(prob::SciMLBase.OptimizationProblem{…}, alg::Optimisers.Adam, args::IterTools.NCycle{…}; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:96
 [47] train_UDE!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:43
 [48] run!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:11
 [49] top-level scope
    @ ./timing.jl:279
 [50] top-level scope
    @ none:1
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Copy link

wsmoses commented May 31, 2024

@ChrisRackauckas @JordiBolibar where is the inactive function you have for pycall. Indeed I would've expected a inactive marking to force that to be fine, so if you hvae a MWE of that part I can fix.

We could also teach pycall object construction how to properly deal with allocations in an enzyme rule, but the fact you say there's an inactive marking that's not triggering seems like the bigger issue.

src/ODINN.jl Outdated Show resolved Hide resolved
@ChrisRackauckas
Copy link
Member Author

Another error I ran into was Tullio support. @wsmoses is that known?

@wsmoses
Copy link

wsmoses commented Jun 4, 2024 via email

@JordiBolibar
Copy link
Member

As I mentioned in the Huginn PR, Tullio should be easily avoidable here, since we only use it in the out-of-place version of the function that was needed for Zygote and ReverseDiff. Migrating to the in-place version of the function should avoid that.

src/ODINN.jl Outdated Show resolved Hide resolved
@ChrisRackauckas
Copy link
Member Author

Training iceflow UDE...
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 1
└ @ MLUtils ~/.julia/packages/MLUtils/LmmaQ/src/batchview.jl:95
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 1
└ @ MLUtils ~/.julia/packages/MLUtils/LmmaQ/src/batchview.jl:95
Before solving ODE problem
ODE problem solved for 1
over here for 1
simulation finished for 1
Batch 1 finished!
All batches finished
Loss computed: 13.506396213311314
┌ Warning: Automatic dt set the starting dt as NaN, causing instability. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/F67Rp/src/solve.jl:591
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/integrator_interface.jl:593
ERROR: DimensionMismatch: array could not be broadcast to match destination
Stacktrace:
  [1] check_broadcast_shape
    @ ./broadcast.jl:579 [inlined]
  [2] check_broadcast_axes
    @ ./broadcast.jl:582 [inlined]
  [3] instantiate
    @ ./broadcast.jl:309 [inlined]
  [4] materialize!
    @ ./broadcast.jl:914 [inlined]
  [5] materialize!
    @ ./broadcast.jl:911 [inlined]
  [6] vec_pjac!(out::ComponentArrays.ComponentVector{…}, λ::Vector{…}, y::Matrix{…}, t::Float64, S::SciMLSensitivity.AdjointSensitivityIntegrand{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:288
  [7] AdjointSensitivityIntegrand
    @ ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:309 [inlined]
  [8] (::SciMLSensitivity.AdjointSensitivityIntegrand{…})(t::Float64)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:321
  [9] evalrule(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, a::Float64, b::Float64, x::Vector{…}, w::Vector{…}, gw::Vector{…}, nrm::typeof(LinearAlgebra.norm))
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/evalrule.jl:0
 [10] #6
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:15 [inlined]
 [11] ntuple
    @ ./ntuple.jl:48 [inlined]
 [12] do_quadgk(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…}, n::Int64, atol::Float64, rtol::Float64, maxevals::Int64, nrm::typeof(LinearAlgebra.norm), segbuf::Nothing)
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:13
 [13] #50
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:253 [inlined]
 [14] handle_infinities(workfunc::QuadGK.var"#50#51"{}, f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…})
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:145
 [15] quadgk(::SciMLSensitivity.AdjointSensitivityIntegrand{…}, ::Float64, ::Vararg{…}; atol::Float64, rtol::Float64, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing)
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:252
 [16] _adjoint_sensitivities(sol::SciMLBase.ODESolution{…}, sensealg::SciMLSensitivity.QuadratureAdjoint{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::SciMLBase.CallbackSet{…}, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:382
 [17] _adjoint_sensitivities
    @ ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:324 [inlined]
 [18] #adjoint_sensitivities#63
    @ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:383 [inlined]
 [19] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{})(Δ::SciMLBase.ODESolution{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:556
 [20] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [21] (::Zygote.var"#kw_zpullback#53"{})(dy::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] (::Zygote.var"#2169#back#293"{})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [24] #solve#51
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003 [inlined]
 [25] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] solve
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:993 [inlined]
 [29] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] #simulate_iceflow_UDE!#32
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:187 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] simulate_iceflow_UDE!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:171 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] batch_iceflow_UDE
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:146 [inlined]
 [35] (::Zygote.Pullback{…})(Δ::@NamedTuple{})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] #28
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [37] (::Zygote.Pullback{…})(Δ::@NamedTuple{})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] #680
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201 [inlined]
 [39] #235
    @ ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157 [inlined]
 [40] (::Base.var"#1023#1028"{Distributed.var"#235#236"{}})(r::Base.RefValue{Any}, args::Tuple{Tuple{…}})
    @ Base ./asyncmap.jl:94
 [41] (::Base.var"#1039#1040"{Base.var"#1023#1028"{Distributed.var"#235#236"{Zygote.var"#680#685"}}, Channel{Any}, Nothing})()
    @ Base ./asyncmap.jl:228
Stacktrace:
  [1] (::Base.var"#1033#1035")(x::Task)
    @ Base ./asyncmap.jl:171
  [2] foreach(f::Base.var"#1033#1035", itr::Vector{Any})
    @ Base ./abstractarray.jl:3097
  [3] maptwice(wrapped_f::Function, chnl::Channel{Any}, worker_tasks::Vector{Any}, c::Base.Iterators.Zip{Tuple{…}})
    @ Base ./asyncmap.jl:171
  [4] wrap_n_exec_twice
    @ ./asyncmap.jl:147 [inlined]
  [5] #async_usemap#1018
    @ ./asyncmap.jl:97 [inlined]
  [6] async_usemap
    @ ./asyncmap.jl:78 [inlined]
  [7] #asyncmap#1017
    @ ./asyncmap.jl:75 [inlined]
  [8] asyncmap
    @ ./asyncmap.jl:74 [inlined]
  [9] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{…}; distributed::Bool, batch_size::Int64, on_error::Nothing, retry_delays::Vector{…}, retry_check::Nothing)
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:126
 [10] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{Tuple{Vector{Tuple{}}, Vector{@NamedTuple{}}}})
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:99
 [11] pmap(f::Function, c::Base.Iterators.Zip{Tuple{Vector{Tuple{}}, Vector{@NamedTuple{}}}}; kwargs::@Kwargs{})
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156
 [12] pmap
    @ ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156 [inlined]
 [13] pmap(f::Function, c1::Vector{Tuple{…}}, c::Vector{@NamedTuple{…}})
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157
 [14] (::Zygote.var"#map_back#682"{ODINN.var"#28#29"{}, 1, Tuple{}, Tuple{}, Vector{}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201
 [15] (::Zygote.var"#2861#back#688"{Zygote.var"#map_back#682"{}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [16] predict_iceflow!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] loss_iceflow
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:58 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] #25
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:31 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [24] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/sakPO/src/scimlfunctions.jl:3762 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{}, Zygote.Pullback{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] #37
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:90 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [31] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [32] #39
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [35] gradient(f::Function, args::ComponentArrays.ComponentVector{Float64, Vector{Float64}, Tuple{ComponentArrays.Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [36] (::OptimizationZygoteExt.var"#38#56"{})(::ComponentArrays.ComponentVector{…}, ::ComponentArrays.ComponentVector{…}, ::Vector{…}, ::Vararg{…})
    @ OptimizationZygoteExt ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93
 [37] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [38] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [39] (::OptimizationOptimisers.var"#12#13"{OptimizationBase.OptimizationCache{}, ComponentArrays.ComponentVector{}})()
    @ OptimizationOptimisers ~/.julia/packages/Optimization/jWtfU/src/utils.jl:29
 [40] maybe_with_logger(f::OptimizationOptimisers.var"#12#13"{}, logger::Nothing)
    @ Optimization ~/.julia/packages/Optimization/jWtfU/src/utils.jl:7
 [41] macro expansion
    @ ~/.julia/packages/Optimization/jWtfU/src/utils.jl:28 [inlined]
 [42] __solve(cache::OptimizationBase.OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [43] solve!(cache::OptimizationBase.OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:188
 [44] solve(prob::SciMLBase.OptimizationProblem{…}, alg::Optimisers.Adam, args::IterTools.NCycle{…}; kwargs::@Kwargs{})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:96
 [45] train_UDE!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:43
 [46] run!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:11
 [47] macro expansion
    @ ./timing.jl:279 [inlined]
 [48] top-level scope
    @ ~/Desktop/test.jl:76
Some type information was truncated. Use `show(err)` to see complete types.

Hokay the autodiff is working using the SIA2D!, but in the adjoint the dt goes to NaN which crashes it and I'll need to investigate that.

@wsmoses
Copy link

wsmoses commented Aug 31, 2024

oh if the issue is a nan, enzyme (on cpu only presently) has a nan checker which will throw a backtrace the first time a nan is generated for a derivative.

set Enzyme.Compiler.CheckNan[] = true after importing and the instrumentation will be added to later code.

Clearly this needs more docs [PRs welcome ofc!].

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants