diff --git a/Project.toml b/Project.toml
index 90d9b6f..2b004ba 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,10 +6,12 @@ version = "1.0.0"
 [deps]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
 GeoStats = "dcc97b0b-8ce5-5539-9008-bb190f959ef6"
@@ -20,6 +22,7 @@ Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Muninn = "4b816528-16ba-4e32-9a2e-3c1bc2049d3a"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -48,7 +51,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 BenchmarkTools = "1.3.2"
 ChainRules = "1.50"
 Downloads = "1"
-Flux = "0.13, 0.14"
 GR = "0.71, 0.72, 0.73"
 Huginn = "0.6"
 IJulia = "1.2"
diff --git a/src/ODINN.jl b/src/ODINN.jl
index eba402c..d661eb5 100644
--- a/src/ODINN.jl
+++ b/src/ODINN.jl
@@ -18,7 +18,7 @@ using IterTools: ncycle
 using Zygote
 using ChainRules: @ignore_derivatives
 using Base: @kwdef
-using Flux
+using Flux, Lux, ComponentArrays
 using Tullio
 using Infiltrator, Cthulhu
 using Plots, PlotThemes
@@ -31,6 +31,12 @@ using Downloads
 using TimerOutputs
 using GeoStats
 using ImageFiltering
+using EnzymeCore
+
+# This is equivalent to `@ignore_derivatives`
+EnzymeCore.EnzymeRules.inactive(::typeof(Huginn.define_callback_steps), args...; kwargs...) = nothing
+EnzymeCore.EnzymeRules.inactive(::typeof(Muninn.MB_timestep!), args...; kwargs...) = nothing
+EnzymeCore.EnzymeRules.inactive(::typeof(apply_MB_mask!), args...; kwargs...) = nothing
 
 # ##############################################
 # ############ PARAMETERS ###############
diff --git a/src/helpers/mass_balance.jl b/src/helpers/mass_balance.jl
index c3d6522..59d919a 100644
--- a/src/helpers/mass_balance.jl
+++ b/src/helpers/mass_balance.jl
@@ -43,16 +43,10 @@ end
 function MB_timestep!(MB, mb_model::MB_model, climate, S, S_coords, t, step)
     # First we get the dates of the current time and the previous step
     period = partial_year(Day, t - step):Day(1):partial_year(Day, t)
-    @timeit to "Climate step" begin
     get_cumulative_climate!(climate, period)
-    end
     # Convert climate dataset to 2D based on the glacier's DEM
-    @timeit to "Climate 2D step" begin
     downscale_2D_climate!(climate, S, S_coords)
-    end
-    @timeit to "Compute MB" begin
     MB .= compute_MB(mb_model, climate.climate_2D_step[])
-    end
 end
 
 function apply_MB_mask!(H, MB, MB_total, context::Tuple)
diff --git a/src/helpers/utils.jl b/src/helpers/utils.jl
index b9d144b..59d037c 100644
--- a/src/helpers/utils.jl
+++ b/src/helpers/utils.jl
@@ -217,28 +217,6 @@ function generate_batches(batch_size, UA, gdirs, gdir_refs, tspan::Tuple; shuffl
     return train_loader
 end
 
-
-"""
-    get_NN()
-
-Generates a neural network.
-"""
-function get_NN(θ_trained)
-    UA = Chain(
-        Dense(1,3, x->softplus.(x)),
-        Dense(3,10, x->softplus.(x)),
-        Dense(10,3, x->softplus.(x)),
-        Dense(3,1, sigmoid_A)
-    )
-    UA = Flux.f64(UA)
-    # See if parameters need to be retrained or not
-    θ, UA_f = Flux.destructure(UA)
-    if !isempty(θ_trained)
-        θ = θ_trained
-    end
-    return UA_f, θ
-end
-
 function get_NN_inversion(θ_trained, target)
     if target == "D"
         U, θ = get_NN_inversion_D(θ_trained)
diff --git a/src/models/machine_learning/ML_utils.jl b/src/models/machine_learning/ML_utils.jl
index 062d6f5..ce3b5f6 100644
--- a/src/models/machine_learning/ML_utils.jl
+++ b/src/models/machine_learning/ML_utils.jl
@@ -5,19 +5,18 @@ get_NN()
 Generates a neural network.
 """
 function get_NN(θ_trained)
-    UA = Flux.Chain(
-        Dense(1,3, x->softplus.(x)),
-        Dense(3,10, x->softplus.(x)),
-        Dense(10,3, x->softplus.(x)),
-        Dense(3,1, sigmoid_A)
+    UA = Lux.Chain(
+        Lux.Dense(1,3, x->Lux.softplus.(x)),
+        Lux.Dense(3,10, x->Lux.softplus.(x)),
+        Lux.Dense(10,3, x->Lux.softplus.(x)),
+        Lux.Dense(3,1, sigmoid_A)
     )
-    UA = Flux.f64(UA)
-    # See if parameters need to be retrained or not
-    θ, UA_f = Flux.destructure(UA)
+    θ, st = Lux.setup(Random.default_rng(), UA)
+    θ = ComponentArray{Float64}(θ)
     if !isnothing(θ_trained)
         θ = θ_trained
     end
-    return UA, θ, UA_f
+    return UA, θ, st
 end
 
 """
@@ -26,7 +25,7 @@ end
 Predicts the value of A with a neural network based on the long-term air temperature.
 """
 function predict_A̅(U, temp)
-    return U(temp) .* 1e-18
+    return U(temp)[1] .* 1e-18
 end
 
 function sigmoid_A(x)
@@ -83,12 +82,12 @@ function build_D_features(H::Matrix, temp, ∇S)
     ∇S_flat = ∇S[inn1(H) .!= 0.0] # flatten
     H_flat = H[H .!= 0.0] # flatten
     T_flat = repeat(temp,length(H_flat))
-    X = Flux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix
+    X = Lux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix
     return X
 end
 
 function build_D_features(H::Float64, temp::Float64, ∇S::Float64)
-    X = Flux.normalise(hcat([H],[temp],[∇S]))' # build feature matrix
+    X = Lux.normalise(hcat([H],[temp],[∇S]))' # build feature matrix
     return X
 end
 
diff --git a/src/models/machine_learning/MLmodel.jl b/src/models/machine_learning/MLmodel.jl
index 6f22fcc..71503e9 100644
--- a/src/models/machine_learning/MLmodel.jl
+++ b/src/models/machine_learning/MLmodel.jl
@@ -27,35 +27,36 @@ function Model(;
     return model
 end
 
-mutable struct NN{F <: AbstractFloat} <: MLmodel
-    architecture::Flux.Chain
-    NN_f::Optimisers.Restructure
-    θ::Vector{F}
+mutable struct NN{T1, T2, T3} <: MLmodel
+    architecture::T1
+    st::T2
+    θ::T3
 end
+(f::NN)(u) = f.architecture(u, f.θ, f.st)
 
 """
     NN(params::Parameters;
-       architecture::Union{Flux.Chain, Nothing} = nothing,
+       architecture::Union{Lux.Chain, Nothing} = nothing,
        θ::Union{Vector{AbstractFloat}, Nothing} = nothing)
 
 Feed-forward neural network.
 
 Keyword arguments
 =================
-    - `architecture`: `Flux.Chain` neural network architecture
+    - `architecture`: `Lux.Chain` neural network architecture
     - `θ`: Neural network parameters
 """
 function NN(params::Sleipnir.Parameters;
-            architecture::Union{Flux.Chain, Nothing} = nothing,
-            θ::Union{Vector{F}, Nothing} = nothing) where {F <: AbstractFloat}
+            architecture::Union{Lux.Chain, Nothing} = nothing,
+            θ::Union{ComponentArray{F}, Nothing} = nothing) where {F <: AbstractFloat}
 
     if isnothing(architecture)
-        architecture, θ, NN_f = get_NN(θ)
+        architecture, θ, st = get_NN(θ)
     end
 
     # Build the simulation parameters based on input values
     ft = params.simulation.float_type
-    neural_net = NN{ft}(architecture, NN_f, θ)
+    neural_net = NN(architecture, st, θ)
 
     return neural_net
 end
diff --git a/src/parameters/Hyperparameters.jl b/src/parameters/Hyperparameters.jl
index 19365dd..ade57d7 100644
--- a/src/parameters/Hyperparameters.jl
+++ b/src/parameters/Hyperparameters.jl
@@ -4,7 +4,7 @@ export Hyperparameters
     current_epoch::I
     current_minibatch::I
     loss_history::Vector{F}
-    optimizer::Union{Optim.FirstOrderOptimizer, Flux.Optimise.AbstractOptimiser, Optimisers.AbstractRule}
+    optimizer::Union{Optim.FirstOrderOptimizer, Optimisers.AbstractRule}
     loss_epoch::F
     epochs::I
     batch_size::I
@@ -33,7 +33,7 @@ function Hyperparameters(;
     current_epoch::Int64 = 1,
     current_minibatch::Int64 = 1,
     loss_history::Vector{Float64} = zeros(Float64, 0),
-    optimizer::Union{Optim.FirstOrderOptimizer, Flux.Optimise.AbstractOptimiser, Optimisers.AbstractRule} = BFGS(initial_stepnorm=0.001),
+    optimizer::Union{Optim.FirstOrderOptimizer, Optimisers.AbstractRule} = BFGS(initial_stepnorm=0.001),
     loss_epoch::Float64 = 0.0,
     epochs::Int64 = 50,
     batch_size::Int64 = 15
diff --git a/src/simulations/functional_inversions/functional_inversion_utils.jl b/src/simulations/functional_inversions/functional_inversion_utils.jl
index 35ab25f..20a403f 100644
--- a/src/simulations/functional_inversions/functional_inversion_utils.jl
+++ b/src/simulations/functional_inversions/functional_inversion_utils.jl
@@ -28,7 +28,7 @@ function train_UDE!(simulation::FunctionalInversion)
     train_batches = generate_batches(simulation)
     θ = simulation.model.machine_learning.θ
 
-    optf = OptimizationFunction((θ, _, batch_ids, rgi_ids)->loss_iceflow(θ, batch_ids, simulation), Optimization.AutoReverseDiff())
+    optf = OptimizationFunction((θ, _, batch_ids, rgi_ids)->loss_iceflow(θ, batch_ids, simulation), Optimization.AutoZygote())
     optprob = OptimizationProblem(optf, θ)
 
     if simulation.parameters.UDE.target == "A"
@@ -127,15 +127,13 @@ function batch_iceflow_UDE(θ, simulation::FunctionalInversion, batch_id::I) whe
     # Initialize glacier ice flow model
     initialize_iceflow_model(model.iceflow[batch_id], batch_id, glacier, params)
 
-    params.solver.tstops = @ignore_derivatives Huginn.define_callback_steps(params.simulation.tspan, params.solver.step)
+    params.solver.tstops = Huginn.define_callback_steps(params.simulation.tspan, params.solver.step)
     stop_condition(u,t,integrator) = Sleipnir.stop_condition_tstops(u,t,integrator, params.solver.tstops) #closure
     function action!(integrator)
         if params.simulation.use_MB
             # Compute mass balance
-            @ignore_derivatives begin
-                MB_timestep!(model, glacier, params.solver.step, integrator.t; batch_id = batch_id)
-                apply_MB_mask!(integrator.u, glacier, model.iceflow[batch_id])
-            end
+            # MB_timestep!(model, glacier, params.solver.step, integrator.t; batch_id = batch_id)
+            # apply_MB_mask!(integrator.u, glacier, model.iceflow[batch_id])
         end
         # Apply parametrization
         apply_UDE_parametrization!(θ, simulation, integrator, batch_id)
@@ -144,7 +142,7 @@ function batch_iceflow_UDE(θ, simulation::FunctionalInversion, batch_id::I) whe
     cb_MB = DiscreteCallback(stop_condition, action!)
 
     # Run iceflow UDE for this glacier
-    du = params.simulation.use_iceflow ? Huginn.SIA2D : Huginn.noSIA2D
+    du = params.simulation.use_iceflow ? Huginn.SIA2D! : Huginn.noSIA2D!
     iceflow_sol = simulate_iceflow_UDE!(θ, simulation, model, params, cb_MB, batch_id; du = du)
 
     println("simulation finished for $batch_id")
@@ -218,7 +216,7 @@ end
 
 function apply_UDE_parametrization!(θ, simulation::FunctionalInversion, integrator, batch_id::I) where {I <: Integer}
     # We load the ML model with the parameters
-    U = simulation.model.machine_learning.NN_f(θ)
+    U = NN(simulation.model.machine_learning.architecture, simulation.model.machine_learning.st, convert(typeof(simulation.model.machine_learning.θ),θ))
     # We generate the ML parametrization based on the target
     if simulation.parameters.UDE.target == "A"
         A = predict_A̅(U, [mean(simulation.glaciers[batch_id].climate.longterm_temps)])[1]
@@ -244,7 +242,7 @@ callback_plots_A = function (θ, l, simulation) # callback function to observe t
     p = sortperm(avg_temps)
     avg_temps = avg_temps[p]
     # We load the ML model with the parameters
-    U = simulation.model.machine_learning.NN_f(θ)
+    U = NN(simulation.model.machine_learning.architecture, simulation.model.machine_learning.st, convert(typeof(simulation.model.machine_learning.θ),θ))
     pred_A = predict_A̅(U, collect(-23.0:1.0:0.0)')
     pred_A = Float64[pred_A...] # flatten
     true_A = A_fake(avg_temps, true)
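
The core of the patch is the move from Flux's implicit-parameter API (`destructure`/`Restructure`) to Lux's explicit one: `Lux.setup` returns the parameters `θ` and layer state `st` separately, and a `ComponentArray` turns the parameter tree into a flat array that Optimization.jl can differentiate. This is also why `predict_A̅` now indexes with `[1]`: a Lux model call returns an `(output, state)` tuple. A minimal standalone sketch of that workflow (not ODINN code; layer sizes are illustrative):

```julia
using Lux, ComponentArrays, Random

# Lux layers are stateless: `setup` hands back parameters and state explicitly,
# instead of Flux's `destructure`/`Restructure` round trip.
model = Lux.Chain(Lux.Dense(1, 3, Lux.softplus), Lux.Dense(3, 1))
θ, st = Lux.setup(Random.default_rng(), model)

# Flatten the parameter NamedTuple into a single Float64 array while keeping
# named access (e.g. θ.layer_1.weight) — the same call used in get_NN above.
θ = ComponentArray{Float64}(θ)

# A Lux model call returns an `(output, state)` tuple, hence `U(temp)[1]`
# in the new predict_A̅.
x = reshape([-10.0], 1, 1)   # one feature, one sample
y, _ = model(x, θ, st)
```

The `(f::NN)(u) = f.architecture(u, f.θ, f.st)` method added in MLmodel.jl wraps exactly this calling convention, so downstream code can keep treating the network as a plain callable.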
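
The `EnzymeCore.EnzymeRules.inactive` overloads added in src/ODINN.jl mark the callback and mass-balance helpers as constant for Enzyme, the Enzyme counterpart of wrapping a call in Zygote's `@ignore_derivatives`. A minimal sketch of the mechanism, assuming current Enzyme.jl rule semantics (`record_step` is a hypothetical helper, not an ODINN function):

```julia
using Enzyme, EnzymeCore

# Hypothetical side-effecting helper that should not be differentiated through.
record_step(x) = @info "step at $x"

# Declaring it inactive tells Enzyme to treat the call as constant and skip it
# during the AD transform.
EnzymeCore.EnzymeRules.inactive(::typeof(record_step), args...; kwargs...) = nothing

function loss(x)
    record_step(x)   # ignored by Enzyme thanks to the inactive rule
    return x^2
end

# Reverse-mode gradient of loss at x = 3.0; expected result: ((6.0,),)
Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active, Enzyme.Active(3.0))
```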
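
With `Flux.Optimise.AbstractOptimiser` dropped from the `optimizer` union in Hyperparameters.jl, training runs entirely through Optimization.jl, differentiating with `AutoZygote` and stepping with either an Optim.jl optimizer (BFGS remains the default) or an `Optimisers.AbstractRule`. A hedged sketch of that plumbing over a `ComponentArray`, with a toy quadratic standing in for `loss_iceflow`:

```julia
using Optimization, OptimizationOptimisers, Optimisers, ComponentArrays, Zygote

# Hypothetical stand-in for loss_iceflow: any scalar function of θ works.
loss(θ, p) = sum(abs2, θ)

θ0 = ComponentArray(a = rand(3), b = rand(2))
optf = OptimizationFunction(loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, θ0)

# Any Optimisers.AbstractRule now satisfies the narrowed `optimizer` union.
sol = solve(optprob, Optimisers.Adam(0.01), maxiters = 100)
```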