diff --git a/irksome/__init__.py b/irksome/__init__.py index ab473b5..0fba006 100644 --- a/irksome/__init__.py +++ b/irksome/__init__.py @@ -12,6 +12,7 @@ from .getForm import getForm # noqa: F401 from .imex import RadauIIAIMEXMethod # noqa: F401 from .pc import RanaBase, RanaDU, RanaLD # noqa: F401 +from .pc import IRKAuxiliaryOperatorPC # noqa: F401 from .stage import StageValueTimeStepper # noqa: F401 from .stepper import TimeStepper # noqa: F401 from .tools import MeshConstant # noqa: F401 diff --git a/irksome/pc.py b/irksome/pc.py index f423560..d100616 100644 --- a/irksome/pc.py +++ b/irksome/pc.py @@ -1,10 +1,12 @@ import abc import copy + +import numpy from firedrake import AuxiliaryOperatorPC, derivative +from ufl import replace + from irksome import getForm from irksome.stage import getFormStage -import numpy -from ufl import replace # Oddly, we can't turn pivoting off in scipy? @@ -97,3 +99,43 @@ class RanaDU(RanaBase): def getAtilde(self, A): L, D, U = ldu(A) return D @ U + + +class IRKAuxiliaryOperatorPC(AuxiliaryOperatorPC): + @abc.abstractmethod + def getNewForm(self, pc, u0, test): + pass + + def form(self, pc, test, trial): + """Implements the interface for AuxiliaryOperatorPC.""" + appctx = self.get_appctx(pc) + butcher_tableau = appctx["butcher_tableau"] + oldF = appctx["F"] + t = appctx["t"] + dt = appctx["dt"] + u0 = appctx["u0"] + bcs = appctx["bcs"] + stage_type = appctx.get("stage_type", None) + bc_type = appctx.get("bc_type", None) + splitting = appctx.get("splitting", None) + nullspace = appctx.get("nullspace", None) + v0 = oldF.arguments()[0] + + F, bcs = self.getNewForm(pc, u0, v0) + # which getForm do I need to get? + + if stage_type in ("deriv", None): + Fnew, w, bcnew, bignsp, _ = \ + getForm(F, butcher_tableau, t, dt, u0, bcs, + bc_type, splitting, nullspace) + elif stage_type == "value": + Fnew, _, w, bcnew, _, bignsp = \ + getFormStage(F, butcher_tableau, u0, t, dt, bcs, + splitting, nullspace) + # Now we get the Jacobian for the modified system, + # which becomes the auxiliary operator! + test_old = Fnew.arguments()[0] + a = replace(derivative(Fnew, w, du=trial), + {test_old: test}) + + return a, bcnew diff --git a/tests/test_pc.py b/tests/test_pc.py index b4ef297..92a7081 100644 --- a/tests/test_pc.py +++ b/tests/test_pc.py @@ -3,24 +3,34 @@ from firedrake import (DirichletBC, Function, FunctionSpace, SpatialCoordinate, TestFunction, UnitSquareMesh, diff, div, dx, errornorm, grad, inner) -from irksome import Dt, MeshConstant, LobattoIIIC, RadauIIA, TimeStepper +from irksome import (Dt, IRKAuxiliaryOperatorPC, LobattoIIIC, MeshConstant, + RadauIIA, TimeStepper) from ufl.algorithms.ad import expand_derivatives # Tests that various PCs are actually getting the right answer. -def Fubc(V, uexact, rhs): +def Fubc(V, t, uexact): u = Function(V) u.interpolate(uexact) v = TestFunction(V) - F = inner(Dt(u), v)*dx + inner(grad(u), grad(v))*dx - inner(rhs, v)*dx + rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) - uexact * (1-uexact) + F = inner(Dt(u), v)*dx + inner(grad(u), grad(v))*dx - inner(u*(1-u), v)*dx - inner(rhs, v)*dx bc = DirichletBC(V, uexact, "on_boundary") return (F, u, bc) -def heat(butcher_tableau): +class myPC(IRKAuxiliaryOperatorPC): + def getNewForm(self, pc, u0, test): + appctx = self.get_appctx(pc) + bcs = appctx["bcs"] + F = inner(Dt(u0), test) * dx + inner(grad(u0), grad(test)) * dx + return F, bcs + + +def rd(butcher_tableau): N = 4 msh = UnitSquareMesh(N, N) @@ -34,17 +44,14 @@ def heat(butcher_tableau): x, y = SpatialCoordinate(msh) uexact = t*(x+y) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) sols = [] luparams = {"mat_type": "aij", - "snes_type": "ksponly", "ksp_type": "preonly", "pc_type": "lu"} ranaLD = {"mat_type": "aij", - "snes_type": "ksponly", "ksp_type": "gmres", "ksp_monitor": None, "pc_type": "python", @@ -60,7 +67,6 @@ def heat(butcher_tableau): ranaLD["fieldsplit_%s" % (s,)] = per_field ranaDU = {"mat_type": "aij", - "snes_type": "ksponly", "ksp_type": "gmres", "ksp_monitor": None, "pc_type": "python", @@ -73,10 +79,17 @@ def heat(butcher_tableau): for s in range(butcher_tableau.num_stages): ranaDU["fieldsplit_%s" % (s,)] = per_field - params = [luparams, ranaLD, ranaDU] + mypc_params = {"mat_type": "aij", + "ksp_type": "gmres", + "pc_type": "python", + "pc_python_type": "test_pc.myPC", + "aux": { + "pc_type": "lu"}} + + params = [luparams, ranaLD, ranaDU, mypc_params] for solver_parameters in params: - F, u, bc = Fubc(V, uexact, rhs) + F, u, bc = Fubc(V, t, uexact) stepper = TimeStepper(F, butcher_tableau, t, dt, u, bcs=bc, solver_parameters=solver_parameters) @@ -90,4 +103,4 @@ def heat(butcher_tableau): @pytest.mark.parametrize('butcher_tableau', (LobattoIIIC(3), RadauIIA(2))) def test_pc_acc(butcher_tableau): - assert heat(butcher_tableau) < 1.e-6 + assert rd(butcher_tableau) < 1.e-6