Skip to content

Commit

Permalink
combine complex-proxy hybridpc into utils HybridisationSCPC
Browse files Browse the repository at this point in the history
  • Loading branch information
JHopeCollins committed Mar 26, 2024
1 parent a8cd450 commit a0b8ac3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 165 deletions.
134 changes: 3 additions & 131 deletions case_studies/shallow_water/blockpc/lswe_cpx_hybr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from utils import shallow_water as swe
from utils.planets import earth
from utils import units
from utils.hybridisation import HybridisedSCPC
from utils.hybridisation import HybridisedSCPC # noqa: F401

import numpy as np
from scipy.fft import fft, fftfreq
Expand Down Expand Up @@ -109,150 +109,24 @@ def form_function(u, h, v, q, t=None):
u, h, v, q, t)


# shallow water equation forms with trace variable
def form_mass_tr(u, h, tr, v, q, s):
return form_mass(u, h, v, q)


def form_function_tr(u, h, tr, v, q, dtr, t=None):
K = form_function(u, h, v, q, t)
n = fd.FacetNormal(mesh)
K += (
g*fd.jump(v, n)*tr('+')
)*fd.dS
return K


def form_trace(u, h, tr, v, q, dtr, t=None):
n = fd.FacetNormal(mesh)
K = (
+ fd.jump(u, n)*dtr('+')
)*fd.dS
return K


class OldHybridisedSCPC(fd.PCBase):
def initialize(self, pc):
if pc.getType() != "python":
raise ValueError("Expecting PC type python")

from utils.hybridisation import BrokenHDivProjector
self.projector = BrokenHDivProjector(Wu, Wub)

self.x = fd.Cofunction(W.dual())
self.y = fd.Function(W)

self.xu, self.xh = self.x.subfunctions
self.yu, self.yh = self.y.subfunctions

self.xtr = fd.Cofunction(Wtr.dual()).assign(0)
self.ytr = fd.Function(Wtr)

self.xbu, self.xbh, self.xbt = self.xtr.subfunctions
self.ybu, self.ybh, self.ybt = self.ytr.subfunctions

M = cpx.BilinearForm(Wtr, d1c, form_mass_tr)
K = cpx.BilinearForm(Wtr, d2c, form_function_tr)
Tr = cpx.BilinearForm(Wtr, 1, form_trace)

A = M + K + Tr
L = self.xtr

scpc_params = {
"mat_type": "matfree",
"ksp_type": "preonly",
"pc_type": "python",
"pc_python_type": "firedrake.SCPC",
"pc_sc_eliminate_fields": "0, 1",
"condensed_field": condensed_params
}

problem = fd.LinearVariationalProblem(A, L, self.ytr)
self.solver = fd.LinearVariationalSolver(
problem, solver_parameters=scpc_params)

def apply(self, pc, x, y):
# copy into unbroken vector
with self.x.dat.vec_wo as v:
x.copy(v)

# break each component of velocity
self.projector.project(self.xu, self.xbu)

# depth already broken
self.xbh.assign(self.xh)

# zero trace residual
self.xbt.assign(0)

# eliminate and solve the trace system
self.ytr.assign(0)
self.solver.solve()

# mend each component of velocity
self.projector.project(self.ybu, self.yu)

# depth already mended
self.yh.assign(self.ybh)

# copy out to petsc
with self.y.dat.vec_ro as v:
v.copy(y)

def update(self, pc):
pass

def applyTranspose(self, pc, x, y):
raise NotImplementedError


# random rhs
L = fd.Cofunction(W.dual())

# PETSc solver parameters
lu_params = {
'ksp_type': 'preonly',
'pc_type': 'lu',
# 'pc_factor_mat_ordering_type': 'rcm',
'pc_factor_mat_solver_type': 'mumps'
}

ilu_params = {
'ksp_type': 'preonly',
'pc_type': 'ilu',
}

gamg_params = {
'ksp_type': 'richardson',
# 'ksp_view': None,
'ksp_rtol': 1e-12,
'ksp_monitor': ':trace_monitor.log',
'ksp_converged_rate': None,
'pc_type': 'gamg',
'pc_gamg_threshold': 0.1,
'pc_gamg_agg_nsmooths': 0,
'pc_gamg_esteig_ksp_maxit': 10,
'pc_mg_cycle_type': 'v',
'pc_mg_type': 'multiplicative',
'mg_levels': {
'ksp_type': 'gmres',
# 'ksp_chebyshev_esteig': None,
# 'ksp_chebyshev_esteig_noisy': None,
# 'ksp_chebyshev_esteig_steps': 30,
'ksp_max_it': 5,
'pc_type': 'bjacobi',
'sub': ilu_params,
},
'mg_coarse': lu_params,
}

condensed_params = lu_params

scpc_params = {
"ksp_type": 'preonly',
"mat_type": "matfree",
"pc_type": "python",
"pc_python_type": f"{__name__}.HybridisedSCPC",
"hybridscpc_condensed_field": lu_params
}

rtol = 1e-3
Expand All @@ -261,10 +135,8 @@ def applyTranspose(self, pc, x, y):
'monitor': None,
'converged_rate': None,
'rtol': rtol,
# 'view': None
},
}
# params.update(lu_params)
params.update(scpc_params)

# trace component should have zero rhs
Expand All @@ -279,7 +151,7 @@ def applyTranspose(self, pc, x, y):

A = M + K

appctx = {'broken_space': Wub}
appctx = {'cpx': cpx}

wout = fd.Function(W).assign(0)
problem = fd.LinearVariationalProblem(A, L, wout)
Expand Down
86 changes: 52 additions & 34 deletions utils/hybridisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,52 @@ def project(self, src, dst):
"dst": (dst, INC)})


def _break_function_space(V, appctx):
cpx = appctx.get('cpx', None)
mesh = V.mesh()

# find HDiv space
iu = None
for i, Vi in enumerate(V):
if Vi.ufl_element().sobolev_space.name == "HDiv":
iu = i
break
if iu is None:
msg = "Hybridised space must have one HDiv component"
raise ValueError(msg)

# hybridisable space - broken HDiv and Trace
Vs = V.subfunctions

Vu = Vs[iu]

# broken HDiv space - either we are given one or we build one.
# If we are using complex-proxy then we need to build the real
# broken space first and then convert to complex-proxy.
if 'broken_space' in appctx:
Vub = appctx['broken_space']
else:
if cpx is None:
broken_element = fd.BrokenElement(Vu.ufl_element())
Vub = fd.FunctionSpace(mesh, broken_element)
else:
broken_element = fd.BrokenElement(Vu.sub(0).ufl_element())
broken_element = cpx.FiniteElement(broken_element)
Vub = fd.FunctionSpace(mesh, broken_element)

# trace space - possibly complex-valued
Tr = fd.FunctionSpace(mesh, "HDivT", Vu.ufl_element().degree())
if cpx is not None:
Tr = cpx.FunctionSpace(Tr)

# trace space always last component
trsubs = [Vs[i] if (i != iu) else Vub
for i in range(len(Vs))] + [Tr]
Vtr = fd.MixedFunctionSpace(trsubs)

return Vtr, iu


class HybridisedSCPC(fd.PCBase):
_prefix = "hybridscpc"

Expand All @@ -102,40 +148,15 @@ def initialize(self, pc):

V = test.function_space()
mesh = V.mesh()
ncpts = len(V)

# find HDiv space
iu = None
for i, Vi in enumerate(V):
if Vi.ufl_element().sobolev_space.name == "HDiv":
iu = i
break
if iu is None:
msg = "Hybridised space must have one HDiv component"
raise ValueError(msg)
# break the HDiv component of the function space,
# leaving the rest untouched
Vtr, iu = _break_function_space(V, appctx)
self.iu = iu

# hybridisable space - broken HDiv and Trace
Vs = V.subfunctions
ncpts = len(Vs)

Vu = Vs[iu]

if 'broken_space' in appctx:
Vub = appctx['broken_space']
else:
broken_element = fd.BrokenElement(Vu.ufl_element())
Vub = fd.FunctionSpace(mesh, broken_element)

Tr = fd.FunctionSpace(mesh, "HDivT", Vu.ufl_element().degree())

# trace space always last component
trsubs = [Vs[i] if (i != iu) else Vub
for i in range(ncpts)] + [Tr]
Vtr = fd.MixedFunctionSpace(trsubs)
print(Vtr)

# breaks/mends the velocity residual
self.projector = BrokenHDivProjector(Vu, Vub)
self.projector = BrokenHDivProjector(V[iu], Vtr[iu])

# build working buffers
self.x = fd.Cofunction(V.dual())
Expand All @@ -156,6 +177,7 @@ def initialize(self, pc):

# add the trace bit
n = fd.FacetNormal(mesh)

def form_trace(*args):
trls = args[:ncpts+1]
tsts = args[ncpts+1:]
Expand All @@ -169,10 +191,6 @@ def form_trace(*args):
cpx = appctx['cpx']
A += cpx.BilinearForm(Vtr, 1, form_trace)
else:
# A += (
# fd.jump(vtrs[iu], n)*utrs[-1]('+')
# + fd.jump(utrs[iu], n)*vtrs[-1]('+')
# )*fd.dS
A += form_trace(*utrs, *vtrs)

L = self.xtr
Expand Down

0 comments on commit a0b8ac3

Please sign in to comment.