diff --git a/Project.toml b/Project.toml index 806d0ed..5d55f4d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StableLinearAlgebra" uuid = "7d06c537-f8ff-4c22-91e1-ce4088e3cfd7" authors = ["Benjamin Cohen-Stead "] -version = "1.3.3" +version = "1.3.4" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/exported_functions.jl b/src/exported_functions.jl index 23194f7..4fbc1a2 100644 --- a/src/exported_functions.jl +++ b/src/exported_functions.jl @@ -130,8 +130,10 @@ The numerically stable inverse ``G := [I + UV]^{-1}`` is calculated using the pr \begin{align*} G:= & [I+UV]^{-1}\\ = & [I+L_{u}D_{u}R_{u}L_{v}D_{v}R_{v}]^{-1}\\ -= & R_{v}^{-1}[\overset{M}{\overbrace{L_{u}^{\dagger}R_{v}^{-1}+D_{u}R_{u}L_{v}D_{v}}}]^{-1}L_{u}^{\dagger}\\ -= & R_{v}^{-1}M^{-1}L_{u}^{\dagger}. += & [I+L_{u}D_{u,\max}D_{u,\min}R_{u}L_{v}D_{v,\min}D_{v,\max}R_{v}]^{-1}\\ += & [L_{u}D_{u,\max}(D_{u,\max}^{-1}L_{u}^{\dagger}R_{v}^{-1}D_{v,\max}^{-1}+D_{u,\min}R_{u}L_{v}D_{v,\min})D_{v,\max}R_{v}]^{-1}\\ += & R_{v}^{-1}D_{v,\max}^{-1}[\overset{M}{\overbrace{D_{u,\max}^{-1}L_{u}^{\dagger}R_{v}^{-1}D_{v,\max}^{-1}+D_{u,\min}R_{u}L_{v}D_{v,\min}}}]^{-1}D_{u,\max}^{-1}L_{u}^{\dagger}\\ += & R_{v}^{-1}D_{v,\max}^{-1}M^{-1}D_{u,\max}^{-1}L_{u}^{\dagger} \end{align*} ``` Intermediate matrix inversions and relevant determinant calculations are performed @@ -141,45 +143,76 @@ function inv_IpUV!(G::AbstractMatrix{T}, U::LDR{T,E}, V::LDR{T,E}, ws::LDRWorksp Lᵤ = U.L dᵤ = U.d - Rᵤ = V.R + Rᵤ = U.R Lᵥ = V.L dᵥ = V.d Rᵥ = V.R - # calculate sign(det(Lᵤ)) and log(det(Lᵤ)) + # calculate sign(det(Lᵤ)) and log(|det(Lᵤ)|) copyto!(ws.M, Lᵤ) logdetLᵤ, sgndetLᵤ = det_lu!(ws.M, ws.lu_ws) - # calculate Rᵥ⁻¹, sign(det(Rᵥ⁻¹)) and log(det(Rᵥ⁻¹)) + # calculate Rᵥ⁻¹, sign(det(Rᵥ⁻¹)) and log(|det(Rᵥ⁻¹)|) Rᵥ⁻¹ = ws.M′ copyto!(Rᵥ⁻¹, Rᵥ) logdetRᵥ⁻¹, sgndetRᵥ⁻¹ = inv_lu!(Rᵥ⁻¹, ws.lu_ws) - # represent Lᵤᵀ - Lᵤᵀ = adjoint(Lᵤ) + # calcuate Dᵥ₊ = max(Dᵥ, 1) + dᵥ₊ = ws.v + @. dᵥ₊ = max(dᵥ, 1) - # calculate Dᵤ⋅Rᵤ⋅Lᵥ⋅Dᵥ - mul!(ws.M, Rᵤ, Lᵥ) # Rᵤ⋅Lᵥ - lmul_D!(dᵤ, ws.M) # Dᵤ⋅[Rᵤ⋅Lᵥ] - rmul_D!(ws.M, dᵥ) # [Dᵤ⋅Rᵤ⋅Lᵥ]⋅Dᵥ + # calculate sign(det(Dᵥ₊)) and log(|det(Dᵥ₊)|) + logdetDᵥ₊, sgndetDᵥ₊ = det_D(dᵥ₊) - # calculate Lᵤᵀ⋅Rᵥ⁻¹ - mul!(G, Lᵤᵀ, Rᵥ⁻¹) + # calculate Rᵥ⁻¹⋅Dᵥ₊⁻¹ + rdiv_D!(Rᵥ⁻¹, dᵥ₊) + Rᵥ⁻¹Dᵥ₊⁻¹ = Rᵥ⁻¹ - # calculate M = Lᵤᵀ⋅Rᵥ⁻¹ + Dᵤ⋅Rᵤ⋅Lᵥ⋅Dᵥ - @. G += ws.M + # calcuate Dᵤ₊ = max(Dᵤ, 1) + dᵤ₊ = ws.v + @. dᵤ₊ = max(dᵤ, 1) + + # calculate sign(det(Dᵥ₊)) and log(|det(Dᵥ₊)|) + logdetDᵤ₊, sgndetDᵤ₊ = det_D(dᵤ₊) + + # calcualte Dᵤ₊⁻¹⋅Lᵤᵀ + adjoint!(ws.M, Lᵤ) + ldiv_D!(dᵤ₊, ws.M) + Dᵤ₊⁻¹Lᵤᵀ = ws.M - # calculate M⁻¹ = [Lᵤᵀ⋅Rᵥ⁻¹ + Dᵤ⋅Rᵤ⋅Lᵥ⋅Dᵥ]⁻¹ + # calculate Dᵤ₋ = min(Dᵤ, 1) + dᵤ₋ = ws.v + @. dᵤ₋ = min(dᵤ, 1) + + # calculate Dᵤ₋⋅Rᵤ⋅Lᵥ + mul!(G, Rᵤ, Lᵥ) # Rᵤ⋅Lᵥ + lmul_D!(dᵤ₋, G) # Dᵤ₋⋅Rᵤ⋅Lᵥ + + # calculate Dᵥ₋ = min(Dᵥ, 1) + dᵥ₋ = ws.v + @. dᵥ₋ = min(dᵥ, 1) + + # caluclate Dᵤ₋⋅Rᵤ⋅Lᵥ⋅Dᵥ₋ + rmul_D!(G, dᵥ₋) + + # caluclate Dᵤ₊⁻¹⋅Lᵤᵀ⋅Rᵥ⁻¹⋅Dᵥ₊⁻¹ + mul!(ws.M″, Dᵤ₊⁻¹Lᵤᵀ, Rᵥ⁻¹Dᵥ₊⁻¹) + + # calculate M = Dᵤ₊⁻¹⋅Lᵤᵀ⋅Rᵥ⁻¹⋅Dᵥ₊⁻¹ + Dᵤ₋⋅Rᵤ⋅Lᵥ⋅Dᵥ₋ + M = G + @. M = ws.M″ + M + + # calculate M⁻¹, sign(det(M)) and log(|det(M)|) M⁻¹ = G logdetM⁻¹, sgndetM⁻¹ = inv_lu!(M⁻¹, ws.lu_ws) - # calculate G := Rᵥ⁻¹⋅M⁻¹⋅Lᵤᵀ - mul!(ws.M, M⁻¹, Lᵤᵀ) # M⁻¹⋅Lᵤᵀ - mul!(G, Rᵥ⁻¹, ws.M) # G := Rᵥ⁻¹⋅[M⁻¹⋅Lᵤᵀ] + # calculate G = Rᵥ⁻¹⋅Dᵥ₊⁻¹⋅M⁻¹⋅Dᵤ₊⁻¹⋅Lᵤᵀ + mul!(ws.M″, M⁻¹, Dᵤ₊⁻¹Lᵤᵀ) # M⁻¹⋅Dᵤ₊⁻¹⋅Lᵤᵀ + mul!(G, Rᵥ⁻¹Dᵥ₊⁻¹, ws.M″) # G = Rᵥ⁻¹⋅Dᵥ₊⁻¹⋅M⁻¹⋅Dᵤ₊⁻¹⋅Lᵤᵀ # calculate sign(det(G)) and log(|det(G)|) - sgndetG = sgndetRᵥ⁻¹ * sgndetM⁻¹ * conj(sgndetLᵤ) - logdetG = logdetM⁻¹ + sgndetG = sgndetRᵥ⁻¹ * conj(sgndetDᵥ₊) * sgndetM⁻¹ * conj(sgndetDᵤ₊) * conj(sgndetLᵤ) + logdetG = -logdetDᵥ₊ + logdetM⁻¹ - logdetDᵤ₊ return real(logdetG), sgndetG end @@ -329,7 +362,7 @@ function inv_invUpV!(G::AbstractMatrix{T}, U::LDR{T,E}, V::LDR{T,E}, ws::LDRWork Lᵤ = U.L dᵤ = U.d - Rᵤ = V.R + Rᵤ = U.R Lᵥ = V.L dᵥ = V.d Rᵥ = V.R diff --git a/test/runtests.jl b/test/runtests.jl index 014122e..63f457a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -246,11 +246,16 @@ end @test A ≈ I # testing inv_IpUV! + F′ = ldr(F) copyto!(F, I, ws) + copyto!(F′, I, ws) for i in 1:Nₛ÷2 - lmul!(B̄, F, ws) + rmul!(F, B̄, ws) + end + for i in Nₛ÷2+1:Nₛ + rmul!(F′, B̄, ws) end - logdetG, sgndetG = inv_IpUV!(G, F, F, ws) + logdetG, sgndetG = inv_IpUV!(G, F, F′, ws) @test G ≈ G_0 @test sgndetG ≈ sgndetG_0 @test logdetG ≈ logdetG_0