-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Steady state callback #601
base: main
Are you sure you want to change the base?
Changes from all commits
a475e82
4db105e
bfc3bca
ecb23be
6a8e7e7
8543619
f75a858
f3fe009
c121c37
a1d69d6
0871515
941a758
bc9933f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
""" | ||
SteadyStateCallback(; interval::Integer=0, dt=0.0, interval_size::Integer=10, | ||
abstol=1.0e-8, reltol=1.0e-6) | ||
|
||
Terminates the integration when the residual of the change in kinetic energy | ||
falls below the threshold specified by `abstol + reltol * ekin`, | ||
where `ekin` is the total kinetic energy of the simulation. | ||
|
||
# Keywords | ||
- `interval=0`: Check steady state condition every `interval` time steps. | ||
- `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms | ||
of integration time by adding additional `tstops` | ||
(note that this may change the solution). | ||
- `interval_size`: The interval in which the change of the kinetic energy is considered. | ||
`interval_size` is a (integer) multiple of `interval` or `dt`. | ||
- `abstol`: Absolute tolerance. | ||
- `reltol`: Relative tolerance. | ||
LasNikas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
mutable struct SteadyStateCallback{I, ELTYPE <: Real} | ||
interval :: I | ||
abstol :: ELTYPE | ||
reltol :: ELTYPE | ||
previous_ekin :: Vector{ELTYPE} | ||
interval_size :: Int | ||
end | ||
|
||
function SteadyStateCallback(; interval::Integer=0, dt=0.0, interval_size::Integer=10, | ||
abstol=1.0e-8, reltol=1.0e-6) | ||
abstol, reltol = promote(abstol, reltol) | ||
|
||
if dt > 0 && interval > 0 | ||
throw(ArgumentError("setting both `interval` and `dt` is not supported")) | ||
end | ||
|
||
if dt > 0 | ||
interval = Float64(dt) | ||
end | ||
|
||
steady_state_callback = SteadyStateCallback(interval, abstol, reltol, [Inf64], | ||
interval_size) | ||
|
||
if dt > 0 | ||
return PeriodicCallback(steady_state_callback, dt, save_positions=(false, false), | ||
final_affect=true) | ||
else | ||
return DiscreteCallback(steady_state_callback, steady_state_callback, | ||
save_positions=(false, false)) | ||
end | ||
end | ||
|
||
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SteadyStateCallback}) | ||
@nospecialize cb # reduce precompilation time | ||
|
||
cb_ = cb.affect! | ||
|
||
print(io, "SteadyStateCallback(abstol=", cb_.abstol, ", ", "reltol=", cb_.reltol, ")") | ||
end | ||
|
||
function Base.show(io::IO, | ||
cb::DiscreteCallback{<:Any, | ||
<:PeriodicCallbackAffect{<:SteadyStateCallback}}) | ||
@nospecialize cb # reduce precompilation time | ||
|
||
cb_ = cb.affect!.affect! | ||
|
||
print(io, "SteadyStateCallback(abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") | ||
end | ||
|
||
function Base.show(io::IO, ::MIME"text/plain", | ||
cb::DiscreteCallback{<:Any, <:SteadyStateCallback}) | ||
@nospecialize cb # reduce precompilation time | ||
|
||
if get(io, :compact, false) | ||
show(io, cb) | ||
else | ||
cb_ = cb.affect! | ||
|
||
setup = ["absolute tolerance" => cb_.abstol, | ||
"relative tolerance" => cb_.reltol, | ||
"interval" => cb_.interval, | ||
"interval size" => cb_.interval_size] | ||
summary_box(io, "SteadyStateCallback", setup) | ||
end | ||
end | ||
|
||
function Base.show(io::IO, ::MIME"text/plain", | ||
cb::DiscreteCallback{<:Any, | ||
<:PeriodicCallbackAffect{<:SteadyStateCallback}}) | ||
@nospecialize cb # reduce precompilation time | ||
|
||
if get(io, :compact, false) | ||
show(io, cb) | ||
else | ||
cb_ = cb.affect!.affect! | ||
|
||
setup = ["absolute tolerance" => cb_.abstol, | ||
"relative tolerance" => cb_.reltol, | ||
"interval" => cb_.interval, | ||
"interval_size" => cb_.interval_size] | ||
summary_box(io, "SteadyStateCallback", setup) | ||
end | ||
end | ||
|
||
# `affect!` (`PeriodicCallback`) | ||
function (cb::SteadyStateCallback)(integrator) | ||
steady_state_condition!(cb, integrator) || return nothing | ||
|
||
print_summary(integrator) | ||
|
||
# `terminate!(integrator)` terminates the simulation immediately and might cause an error message. | ||
integrator.opts.maxiters = integrator.iter | ||
Comment on lines
+110
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This causes a warning message. Any idea @efaulhaber and @svchb ? edit: The reason for this is the following.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another (simpler) solution: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess the only way would be to implement a custom solve function which wraps the ODE solve? |
||
end | ||
|
||
# `affect!` (`DiscreteCallback`) | ||
function (cb::SteadyStateCallback{Int})(integrator) | ||
print_summary(integrator) | ||
|
||
# `terminate!(integrator)` terminates the simulation immediately and might cause an error message. | ||
integrator.opts.maxiters = integrator.iter | ||
end | ||
|
||
# `condition` (`DiscreteCallback`) | ||
function (steady_state_callback::SteadyStateCallback)(vu_ode, t, integrator) | ||
return steady_state_condition!(steady_state_callback, integrator) | ||
end | ||
|
||
@inline function steady_state_condition!(cb, integrator) | ||
(; abstol, reltol, previous_ekin, interval_size) = cb | ||
|
||
vu_ode = integrator.u | ||
v_ode, u_ode = vu_ode.x | ||
semi = integrator.p | ||
|
||
# Calculate kinetic energy | ||
ekin = 0.0 | ||
foreach_system(semi) do system | ||
v = wrap_v(v_ode, system, semi) | ||
unused_arg = nothing | ||
|
||
ekin += kinetic_energy(v, unused_arg, unused_arg, system) | ||
end | ||
|
||
if length(previous_ekin) == interval_size | ||
|
||
# Calculate MSE only over the `interval_size` | ||
mse = 0.0 | ||
for index in 1:interval_size | ||
mse += (previous_ekin[index] - ekin)^2 / interval_size | ||
end | ||
|
||
if mse <= abstol + reltol * ekin | ||
return true | ||
end | ||
|
||
# Pop old kinetic energy | ||
popfirst!(previous_ekin) | ||
end | ||
|
||
LasNikas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Add current kinetic energy | ||
push!(previous_ekin, ekin) | ||
|
||
return false | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find the name a bit misleading. In grid-based methods, when simulating a physical instability arising from a slightly perturbed steady state, you sometimes subtract the right-hand side of the steady state to remove errors.
From this name, I would expect something like this here.
I'll think about other naming suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ConvergenceConditionCallBack
ConvergenceTerminationCallBack
TerminationConditionCallBack
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SteadyStateDetector
SteadyStateTerminator