Skip to content

Commit

Permalink
Extend FormSplitter to handle restrictions (#310)
Browse files Browse the repository at this point in the history
* Add handling of case where a split returns a form with no arguments (they have all been reduced to zeros).

* Add test case

* Add test and some error handling

* Apply restriction in formsplitter

* Fix docstring
  • Loading branch information
jorgensd authored Oct 2, 2024
1 parent 3b85665 commit 8384202
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
21 changes: 21 additions & 0 deletions test/test_extract_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,24 @@ def lhs(u, p, v, q):
for j in range(2):
J_ij_ext = ufl.extract_blocks(J, i, j)
assert J_sub[2 * i + j].signature() == J_ij_ext.signature()


def test_postive_restricted_extract_none():
cell = ufl.triangle
d = cell.topological_dimension()
domain = ufl.Mesh(FiniteElement("Lagrange", cell, 1, (d,), ufl.identity_pullback, ufl.H1))
el_u = FiniteElement("Lagrange", cell, 2, (d,), ufl.identity_pullback, ufl.H1)
el_p = FiniteElement("Lagrange", cell, 1, (), ufl.identity_pullback, ufl.H1)
V = ufl.FunctionSpace(domain, el_u)
Q = ufl.FunctionSpace(domain, el_p)
W = ufl.MixedFunctionSpace(V, Q)
u, p = ufl.TrialFunctions(W)
v, q = ufl.TestFunctions(W)
a = (
ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx
+ ufl.div(u) * q * ufl.dx
+ ufl.div(v) * p * ufl.dx
)
a += ufl.inner(u("+"), v("+")) * ufl.dS
a_blocks = ufl.extract_blocks(a)
assert a_blocks[1][1] is None
13 changes: 12 additions & 1 deletion ufl/algorithms/formsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing import Optional

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags
from ufl.argument import Argument
from ufl.classes import FixedIndex, ListTensor
from ufl.constantvalue import Zero
Expand Down Expand Up @@ -81,6 +81,15 @@ def multi_index(self, obj):
"""Apply to multi_index."""
return obj

def restricted(self, o):
"""Apply to a restricted function."""
# If we hit a restriction first apply form splitter to argument, then check for zero
op_split = map_expr_dag(self, o.ufl_operands[0])
if isinstance(op_split, Zero):
return op_split
else:
return op_split(o._side)

expr = MultiFunction.reuse_if_untouched


Expand Down Expand Up @@ -139,6 +148,8 @@ def extract_blocks(form, i: Optional[int] = None, j: Optional[None] = None):
if f.empty():
form_i.append(None)
else:
if (num_args := len(f.arguments())) != 2:
raise RuntimeError(f"Expected 2 arguments, got {num_args}")
form_i.append(f)
forms.append(tuple(form_i))
else:
Expand Down

0 comments on commit 8384202

Please sign in to comment.