Skip to content

Commit

Permalink
Merge pull request #20 from albertomercurio:albertomercurio-patch-1
Browse files Browse the repository at this point in the history
Optimized expv!
  • Loading branch information
albertomercurio authored Feb 21, 2024
2 parents 975e2cb + 85a826e commit 448c7a1
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/arnoldi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ export expv!, expv
"""
    ArnoldiSpace(V, H, Hcopy, m)

Container bundling the matrices and the integer `m` used by the Arnoldi routines in this
file.

`V` and `H` are matrices with a `BlasFloat` element type. `Hcopy` is constrained to have
exactly the same type as `H`. `m` may be any `Integer`.
"""
struct ArnoldiSpace{
    VT<:AbstractMatrix{<:BlasFloat},
    HT<:AbstractMatrix{<:BlasFloat},
    mT<:Integer,
}
    # Matrix stored under `V` (allocated as `n × (m + 1)` by `arnoldi`).
    V::VT
    # Matrix stored under `H` (allocated as `(m + 1) × m` zeros by `arnoldi`).
    H::HT
    # Second matrix of the same type as `H`; `arnoldi` initialises it with `copy(H)`.
    Hcopy::HT
    # Integer size parameter `m`.
    m::mT
end

"""
    copy(AS::ArnoldiSpace)

Return a new `ArnoldiSpace` built from `copy(AS.V)`, `copy(AS.H)` and `copy(AS.Hcopy)`;
`AS.m` is passed through unchanged.
"""
function Base.copy(
    AS::ArnoldiSpace{<:AbstractMatrix{T1}, <:AbstractMatrix{T1}}
) where {T1 <: BlasFloat}
    # `ArnoldiSpace` has four fields (V, H, Hcopy, m), so all four must be supplied;
    # the earlier three-argument call has no matching constructor and was removed.
    return ArnoldiSpace(copy(AS.V), copy(AS.H), copy(AS.Hcopy), AS.m)
end

"""
    deepcopy(AS::ArnoldiSpace)

Return a new `ArnoldiSpace` built from `deepcopy(AS.V)`, `deepcopy(AS.H)` and
`deepcopy(AS.Hcopy)`; `AS.m` is passed through unchanged.
"""
function Base.deepcopy(
    AS::ArnoldiSpace{<:AbstractMatrix{T1}, <:AbstractMatrix{T1}}
) where {T1 <: BlasFloat}
    # `ArnoldiSpace` has four fields (V, H, Hcopy, m), so all four must be supplied;
    # the earlier three-argument call has no matching constructor and was removed.
    # NOTE(review): the Julia docs recommend specialising `Base.deepcopy_internal`
    # rather than `deepcopy` itself; kept as-is here to preserve the existing method.
    return ArnoldiSpace(deepcopy(AS.V), deepcopy(AS.H), deepcopy(AS.Hcopy), AS.m)
end

function arnoldi_init!(A, b::AbstractVector{T}, V::AbstractMatrix{T}, H::AbstractMatrix{T}) where T <: BlasFloat
Expand Down Expand Up @@ -62,7 +63,7 @@ function arnoldi(A, b::AbstractVector{T}, m::Integer) where T <: BlasFloat
n = size(A, 2)
V = similar(b, n, m+1)
H = zeros(T, m+1, m)
AS = ArnoldiSpace(V, H, m)
AS = ArnoldiSpace(V, H, copy(H), m)
arnoldi!(AS, A, b)
end

Expand All @@ -72,24 +73,26 @@ function expv!(x::AbstractVector{T1}, AS::ArnoldiSpace{<:AbstractMatrix{T1}, <:A
t::T2, b::AbstractVector{T1}) where {T1 <: BlasFloat, T2 <: Union{BlasFloat, BlasInt}}

H = AS.H
Hcopy = AS.Hcopy
V = AS.V
m = AS.m

Hm = view(H, 1:m, 1:m)
Vm = view(V, :, 1:m)
lmul!(t, Hm)

# expH = LinearAlgebra.exp!(Hm)
expH = exp(Hm)
Hcopym = view(Hcopy, 1:m, 1:m)
copyto!(Hcopym, Hm)
expH = LinearAlgebra.exp!(Hcopym)
# expH = exp(Hm)

β = norm(b)
expHe = view(expH, :, 1)
expHe = expH[:, 1] # the view doesn't work when using copyto! with CUDA
cache = similar(x, m)
cache .= expHe # just in case expHe is different from cache (e.g., with CUDA)
copyto!(cache, expHe) # just in case expHe is different from cache (e.g., with CUDA)
mul!(x, Vm, cache)
lmul!(β, x)


return x
end

Expand Down

0 comments on commit 448c7a1

Please sign in to comment.