Skip to content

Commit

Permalink
Speedup calculation of F and ρtor (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
bclyons12 authored Jul 25, 2024
1 parent adc0217 commit 37b5879
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 21 deletions.
104 changes: 90 additions & 14 deletions src/fsa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ end
"""
Vprime(shot::F1, ρ::Real; tid=Threads.threadid()) where {F1<:Shot}
Compute dV/ at `ρ` for the equilibrium defined in `shot`
Compute dV/ at `ρ` for the equilibrium defined in `shot`
"""
function Vprime(shot::F1, ρ::Real; tid=Threads.threadid()) where {F1<:Shot}
k, nu_ou, nu_eu, nu_ol, nu_el, D_nu_ou, D_nu_eu, D_nu_ol, D_nu_el = compute_both_bases(shot.ρ, ρ)
Expand Down Expand Up @@ -205,21 +205,54 @@ end

function FE_coeffs!(Y::FE_rep, shot::F1, f::F2; ε::Real=1e-6, derivative::Symbol=:auto) where {F1<:Shot,F2}
@assert derivative in (:auto, :finite)
(derivative === :auto) && (g = x -> f(shot, x))
for (i, x) in enumerate(Y.x)
g = x -> f(shot, x)
Y.coeffs[2i] = f(shot, x)
if derivative === :auto
if x == 0.0
Y.coeffs[2i-1] = ForwardDiff.derivative(g, 1e-12)
else
Y.coeffs[2i-1] = ForwardDiff.derivative(g, x)
end
else
xp = x == Y.x[end] ? x : x + ε * (Y.x[i+1] - Y.x[i])
xm = x == Y.x[1] ? x : x - ε * (Y.x[i] - Y.x[i-1])
Y.coeffs[2i-1] = (f(shot, xp) - f(shot, xm)) / (xp - xm)
if x == Y.x[end]
xp = x
fp = Y.coeffs[2i]
else
xp = x + ε * (Y.x[i+1] - Y.x[i])
fp = f(shot, xp)
end

if x == Y.x[1]
xm = x
fm = Y.coeffs[2i]
else
xm = x - ε * (Y.x[i] - Y.x[i-1])
fm = f(shot, xm)
end

Y.coeffs[2i-1] = (fp - fm) / (xp - xm)
end

Y.coeffs[2i] = g(x)

end
return Y
end

function Fpol_coeffs!(Y::FE_rep, shot::F1; invR=FE_rep(shot, fsa_invR), invR2=FE_rep(shot, fsa_invR2)) where {F1<:Shot}
F = shot.Fbnd
F_dF_dψ= FFprime(shot, shot.F_dF_dψ, shot.Jt_R, shot.Jt; invR, invR2)
FFp = x -> F_dF_dψ(x) * dψ_dρ(shot, x)
for i in reverse(eachindex(Y.x))
ρ = Y.x[i]
if i < length(Y.x)
# just integrate from ρ to Y.x[i+1],
# and use previous F as boundary condition
endpoint = (Y.x[i+1], F)
F = Fpol(shot, ρ, endpoint; invR, invR2)
end
Y.coeffs[2i] = F
Y.coeffs[2i-1] = FFp(ρ) / F
end
return Y
end
Expand Down Expand Up @@ -263,12 +296,10 @@ function toroidal_flux(shot::F1, ρs::AbstractVector{<:Real}) where {F1<:Shot}
return toroidal_flux!(Φ, shot, ρs)
end

function toroidal_flux!::Vector{<:Real}, shot::F1, ρs::AbstractVector{<:Real}) where {F1<:Shot}
function toroidal_flux!::AbstractVector{<:Real}, shot::F1, ρs::AbstractVector{<:Real}; use_cached=true) where {F1<:Shot}
@assert length(Φ) === length(ρs)
invR = FE_rep(shot, fsa_invR)
invR2 = FE_rep(shot, fsa_invR2)
Vp = FE_rep(shot, Vprime)
f = x -> Fpol(shot, x; invR, invR2) * Vp(x) * invR2(x)
Vp, _, invR2, F = get_FEs(shot, use_cached)
f = x -> F(x) * Vp(x) * invR2(x)
for k in eachindex(ρs)[2:end]
Φ[k] = Φ[k-1] + quadgk(f, ρs[k-1], ρs[k])[1] / twopi
end
Expand All @@ -286,7 +317,7 @@ function rho_tor_norm(shot::F1) where {F1<:Shot}
return rho_tor_norm!(ρtor, shot)
end

function rho_tor_norm!(ρtor::Vector{<:Real}, shot::F1) where {F1<:Shot}
function rho_tor_norm!(ρtor::AbstractVector{<:Real}, shot::F1) where {F1<:Shot}
@assert length(ρtor) === length(shot.ρ)
toroidal_flux!(ρtor, shot, shot.ρ)
@. ρtor = sqrt(ρtor / ρtor[end])
Expand All @@ -295,12 +326,57 @@ function rho_tor_norm!(ρtor::Vector{<:Real}, shot::F1) where {F1<:Shot}
return ρtor
end

function get_FEs(shot::Shot, use_cached::Bool)
Vp = use_cached ? shot.Vp : FE_rep(shot, Vprime)
invR = use_cached ? shot.invR : FE_rep(shot, fsa_invR)
invR2 = use_cached ? shot.invR2 : FE_rep(shot, fsa_invR2)
if use_cached
F = shot.F
else
N = length(shot.ρ)
F = FE_rep(shot.ρ, zeros(eltype(shot.ρ, 2N)))
Fpol_coeffs!(F, shot; invR, invR2)
end

return Vp, invR, invR2, F
end

function ρtor_coeffs!(Y::FE_rep, shot::F1; use_cached=true, ε::Real=1e-6) where {F1<:Shot}
@assert length(Y.x) === length(shot.ρ)

Vp, _, invR2, F = get_FEs(shot, use_cached)

# to start, compute Φ
@views Φ = Y.coeffs[2:2:end]
toroidal_flux!(Φ, shot, shot.ρ; use_cached)
Φ0 = Φ[end]

# then compute ρtor
@. Y.coeffs[2:2:end] = sqrt/ Φ0)
Y.coeffs[1] = 0.0
Y.coeffs[end] = 1.0
@views ρtor = Y.coeffs[2:2:end]

# compute derivative analytically, since Φ is an integral
# dρtor_dρ = dΦ_dρ / (2 * Φ0 * ρtor)
f = x -> F(x) * Vp(x) * invR2(x) / twopi
@views ρtor = Y.coeffs[2:2:end]
Y.coeffs[1:2:end] .= f.(Y.x) ./ (2.0 .* Φ0 .* ρtor)

# fix on-axis derivative by taking approximate derivative at δ
# using ρtor(δ) ≈ dρtor_dρ(δ) * δ in equation above
δ = ε * Y.x[2]
Y.coeffs[1] = sqrt(f(δ) / (2.0 * Φ0 * δ))

return Y
end

function set_FSAs!(shot)
FE_coeffs!(shot.Vp, shot, Vprime; derivative=:auto)
FE_coeffs!(shot.invR, shot, fsa_invR; derivative=:auto)
FE_coeffs!(shot.invR2, shot, fsa_invR2; derivative=:auto)
FE_coeffs!(shot.F, shot, (shot, x) -> Fpol(shot, x; shot.invR, shot.invR2); derivative=:finite)
FE_coeffs!(shot.ρtor, shot, rho_tor_norm; derivative=:finite)
Fpol_coeffs!(shot.F, shot; shot.invR, shot.invR2)
ρtor_coeffs!(shot.ρtor, shot)
return shot
end

Expand Down
15 changes: 8 additions & 7 deletions src/shot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ function MXHEquilibrium.pressure(shot::Shot, psi)
throw(ErrorException("Must specify one of the following: P, dP_dψ"))
end

function FFprime(shot::Shot, F_dF_dψ::Nothing, Jt_R::Nothing, Jt::Nothing)
function FFprime(shot::Shot, F_dF_dψ::Nothing, Jt_R::Nothing, Jt::Nothing; invR=nothing, invR2=nothing)
throw(ErrorException("Must specify one of the following: F_dF_dψ, Jt_R, Jt"))
end

Expand All @@ -844,21 +844,22 @@ function Fpol_dFpol_dψ(shot::Shot, ρ::Real; kwargs...)
return ffp(ρ)
end

function Fpol(shot::Shot, ρ::Real; kwargs...)
function Fpol(shot::Shot, ρ::Real, endpoint::Tuple{Real, Real}=(1.0, shot.Fbnd); kwargs...)
return Fpol(shot, FFprime(shot, shot.F_dF_dψ, shot.Jt_R, shot.Jt; kwargs...), ρ)
end

function Fpol(shot::Shot, F_dF_dψ, ρ::Real)
function Fpol(shot::Shot, F_dF_dψ, ρ::Real, endpoint::Tuple{Real, Real}=(1.0, shot.Fbnd))
f(x) = F_dF_dψ(x) * dψ_dρ(shot, x)
half_dF2 = quadgk(f, ρ, 1.0)[1]
F2 = shot.Fbnd^2 - 2.0 * half_dF2
ρend, Fend = endpoint
half_F2 = quadgk(f, ρend, ρ)[1]
F2 = Fend^2 + 2.0 * half_F2
return sign(shot.Fbnd) * sqrt(F2)
end

# Misnomer: "poloidal_current" is actually Fpol = R*Bt, so we'll rename
function MXHEquilibrium.poloidal_current(shot::Shot, psi)
function MXHEquilibrium.poloidal_current(shot::Shot, psi; use_cached::Bool=true)
rho = ρ(shot, psi)
return Fpol(shot, rho)
return use_cached ? shot.F(rho) : Fpol(shot, rho)
end

function dFpol_dψ(shot::Shot, ρ::Real)
Expand Down

0 comments on commit 37b5879

Please sign in to comment.