Skip to content

Commit

Permalink
Improve orthogonalize. Simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Oct 18, 2024
1 parent d50387b commit d460884
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 92 deletions.
49 changes: 37 additions & 12 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -585,36 +585,61 @@ end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return _orthogonalize_edges(tn, [edge]; kwargs...)
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function _orthogonalize_edges(
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
)
# tn = factorize(tn, edge; kwargs...)
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
# return contract(tn, new_vertex => dst(edge))
tn = copy(tn)
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
tn[src(edge)] = X
tn[dst(edge)] *= Y
for edge in edges
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
tn[src(edge)] = X
tn[dst(edge)] *= Y
end
return tn
end

function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return _orthogonalize_edge(tn, edge; kwargs...)
end

function ITensorMPS.orthogonalize(
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
)
return _orthogonalize_edges(tn, edges; kwargs...)
end

function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return orthogonalize(tn, edgetype(tn)(edge); kwargs...)
end

# Orthogonalize an ITensorNetwork towards a source vertex, treating
function ITensorMPS.orthogonalize(
tn::AbstractITensorNetwork, edges::Vector{Pair}; kwargs...
)
return orthogonalize(tn, edgetype(tn).(edges); kwargs...)
end

# Orthogonalize an ITensorNetwork towards a region, treating
# the network as a tree spanned by a spanning tree.
# TODO: Rename `tree_orthogonalize`.
function ITensorMPS.orthogonalize::AbstractITensorNetwork, source_vertex)
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex)
for e in spanning_tree_edges
ψ = orthogonalize(ψ, e)
end
return ψ
function ITensorMPS.orthogonalize::AbstractITensorNetwork, region::Vector)
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region))
spanning_tree_edges = filter(
e -> !(src(e) region && dst(e) region), spanning_tree_edges
)
return orthogonalize(ψ, spanning_tree_edges)
end

function ITensorMPS.orthogonalize::AbstractITensorNetwork, region)
return orthogonalize(ψ, [region])
end

# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
Expand Down
39 changes: 2 additions & 37 deletions src/solvers/alternating_update/region_update.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,3 @@
#ToDo: generalize beyond 2-site
#ToDo: remove concept of orthogonality center for generality
function current_ortho(sweep_plan, which_region_update)
regions = first.(sweep_plan)
region = regions[which_region_update]
current_verts = support(region)
if !isa(region, AbstractEdge) && length(region) == 1
return only(current_verts)
end
# look forward
other_regions = filter(
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
)
# find the first region that has overlapping support with current region
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
if isnothing(ind)
# look backward
other_regions = reverse(
filter(
x -> !(issetequal(x, current_verts)), support.(regions[1:(which_region_update - 1)])
),
)
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
end
@assert !isnothing(ind)
future_verts = union(support(other_regions[ind]))
# return ortho_ceter as the vertex in current region that does not overlap with following one
overlapping_vertex = intersect(current_verts, future_verts)
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
return nonoverlapping_vertex
end

function region_update(
projected_operator,
state;
Expand All @@ -55,14 +23,13 @@ function region_update(

# ToDo: remove orthogonality center on vertex for generality
# region carries same information
ortho_vertex = current_ortho(sweep_plan, which_region_update)
if !isnothing(transform_operator)
projected_operator = transform_operator(
state, projected_operator; outputlevel, transform_operator_kwargs...
)
end
state, projected_operator, phi = extracter(
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
state, projected_operator, region; extracter_kwargs..., internal_kwargs
)
# create references, in case solver does (out-of-place) modify PH or state
state! = Ref(state)
Expand All @@ -88,9 +55,7 @@ function region_update(
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
# so noiseterm is a solver
#end
state, spec = inserter(
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
)
state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs)
all_kwargs = (;
which_region_update,
sweep_plan,
Expand Down
13 changes: 7 additions & 6 deletions src/solvers/extract/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
# insert_local_tensors takes that tensor and factorizes it back
# apart and puts it back into the network.
#
function default_extracter(state, projected_operator, region, ortho; internal_kwargs)
state = orthogonalize(state, ortho)
function default_extracter(state, projected_operator, region; internal_kwargs)
if isa(region, AbstractEdge)
other_vertex = only(setdiff(support(region), [ortho]))
left_inds = uniqueinds(state[ortho], state[other_vertex])
vsrc, vdst = src(region), dst(region)
state = orthogonalize(state, vsrc)
left_inds = uniqueinds(state[vsrc], state[vdst])
#ToDo: replace with call to factorize
U, S, V = svd(
state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
)
state[ortho] = U
state[vsrc] = U
local_tensor = S * V
else
state = orthogonalize(state, region)
local_tensor = prod(state[v] for v in region)
end
projected_operator = position(projected_operator, state, region)
Expand Down
22 changes: 9 additions & 13 deletions src/solvers/insert/insert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
function default_inserter(
state::AbstractTTN,
phi::ITensor,
region,
ortho_vert;
region;
normalize=false,
maxdim=nothing,
mindim=nothing,
Expand All @@ -16,16 +15,14 @@ function default_inserter(
)
state = copy(state)
spec = nothing
other_vertex = setdiff(support(region), [ortho_vert])
if !isempty(other_vertex)
v = only(other_vertex)
e = edgetype(state)(ortho_vert, v)
indsTe = inds(state[ortho_vert])
if length(region) == 2
v = last(region)
e = edgetype(state)(first(region), last(region))
indsTe = inds(state[first(region)])
L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff)
state[ortho_vert] = L

state[first(region)] = L
else
v = ortho_vert
v = only(region)
end
state[v] = phi
state = set_ortho_region(state, [v])
Expand All @@ -44,8 +41,7 @@ function default_inserter(
normalize=false,
internal_kwargs,
)
v = only(setdiff(support(region), [ortho]))
state[v] *= phi
state = set_ortho_region(state, [v])
state[dst(region)] *= phi
state = set_ortho_region(state, [dst(region)])
return state, nothing
end
40 changes: 16 additions & 24 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using Graphs: has_vertex
using NamedGraphs.GraphsExtensions:
GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
GraphsExtensions,
edge_path,
leaf_vertices,
post_order_dfs_edges,
post_order_dfs_vertices,
a_star
using NamedGraphs: namedgraph_a_star
using IsApprox: IsApprox, Approx
using ITensors: @Algorithm_str, directsum, hasinds, permute, plev
using ITensorMPS: linkind, loginner, lognorm, orthogonalize
Expand Down Expand Up @@ -29,30 +35,16 @@ function set_ortho_region(tn::AbstractTTN, new_region)
return error("Not implemented")
end

#
# Orthogonalization
#

function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...)
if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn))
return tn
end
# TODO: Rewrite this in a more general way.
if isone(length(ortho_region(tn)))
edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center)
else
edge_list = post_order_dfs_edges(tn, ortho_center)
end
for e in edge_list
tn = orthogonalize(tn, e)
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
paths = [
namedgraph_a_star(underlying_graph(ttn), rp, r) for r in region for
rp in ortho_region(ttn)
]
path = unique(reduce(vcat, paths))
if !isempty(path)
ttn = typeof(ttn)(orthogonalize(ITensorNetwork(ttn), path; kwargs...))
end
return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center]))
end

# For ambiguity error

function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...)
return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...))
return set_ortho_region(ttn, region)
end

#
Expand Down

0 comments on commit d460884

Please sign in to comment.