Skip to content

Commit

Permalink
#2611 simplify HaloDepth implementation and finish tests
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Oct 31, 2024
1 parent 431e0c0 commit f4226d9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
23 changes: 7 additions & 16 deletions src/psyclone/dynamo0p3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4399,8 +4399,9 @@ def set_by_value(self, max_depth, var_depth, annexed_only, max_depth_m1):
:param bool max_depth: True if the field accesses all of the
halo and False otherwise
:param str var_depth: A variable name specifying the halo
:param var_depth: PSyIR expression specifying the halo
access depth, if one exists, and None if not
:type var_depth: :py:class:`psyclone.psyir.nodes.Node`
:param bool annexed_only: True if only the halo's annexed dofs
are accessed and False otherwise
:param bool max_depth_m1: True if the field accesses all of
Expand All @@ -4423,28 +4424,18 @@ def set_by_value(self, max_depth, var_depth, annexed_only, max_depth_m1):
# tree.
fake_assign = Assignment.create(
Reference(DataSymbol("tmp", INTEGER_TYPE)), var_depth.detach())
sched = self._parent.ancestor(Schedule)
sched = self._parent.ancestor(Schedule, include_self=True)
sched.addchild(fake_assign)

sym_maths.expand(fake_assign.rhs)
self._var_depth = fake_assign.rhs.detach()
fake_assign.detach()

def __str__(self):
table = self._parent.scope.symbol_table
depth_str = ""
if self.max_depth:
max_depth = table.lookup_with_tag("max_halo_depth_mesh")
depth_str += max_depth.name
elif self.max_depth_m1:
max_depth = table.lookup_with_tag("max_halo_depth_mesh")
depth_str += f"{max_depth.name}-1"
else:
if self.var_depth:
depth_str += FortranWriter()(self.var_depth)
else:
depth_str = "0"
return depth_str
psyir = self.psyir_expression()
if psyir:
return FortranWriter()(psyir)
return "0"

def psyir_expression(self):
'''
Expand Down
37 changes: 34 additions & 3 deletions src/psyclone/tests/domain/lfric/halo_depth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
'''

from psyclone.dynamo0p3 import HaloDepth
from psyclone.psyir import symbols, nodes
from psyclone.tests.utilities import get_invoke


Expand All @@ -47,8 +48,8 @@ def test_halo_depth_ctor():
Basic test that we can construct a HaloDepth object.
'''
psy, invoke = get_invoke("14.4.2_halo_vector_xory.f90",
"lfric", idx=0)
_, invoke = get_invoke("14.4.2_halo_vector_xory.f90",
"lfric", idx=0)
hdepth = HaloDepth(invoke.schedule)
assert hdepth.max_depth is False
assert hdepth.max_depth_m1 is False
Expand All @@ -61,6 +62,36 @@ def test_halo_depth_ctor():

def test_halo_depth_set_by_value():
'''
Test for the set_by_value() method of HaloDepth.
Test for the set_by_value() method of HaloDepth. Also indirectly tests
the psyir_expression() method by checking the result of str().
'''
_, invoke = get_invoke("14.4.2_halo_vector_xory.f90",
"lfric", idx=0)
hdepth = HaloDepth(invoke.schedule)
# Halo is accessed to max depth.
hdepth.set_by_value(True, None, False, False)
assert hdepth.var_depth is None
assert str(hdepth) == "max_halo_depth_mesh"
# Halo is accessed to max-depth minus 1.
hdepth.set_by_value(False, None, False, True)
assert str(hdepth) == "max_halo_depth_mesh - 1"
# Annexed dofs only.
hdepth.set_by_value(False, None, True, False)
assert str(hdepth) == "0"
# PSyIR expression.
my_depth = symbols.DataSymbol("my_depth", symbols.INTEGER_TYPE)
invoke.schedule.symbol_table.add(my_depth)
exprn = nodes.BinaryOperation.create(
nodes.BinaryOperation.Operator.MUL,
nodes.Literal("2", symbols.INTEGER_TYPE),
nodes.Reference(my_depth))
hdepth.set_by_value(False, exprn, False, False)
assert str(hdepth) == "2 * my_depth"
# Check that the PSyIR expression is simplified where possible.
exprn2 = nodes.BinaryOperation.create(
nodes.BinaryOperation.Operator.MUL,
nodes.Literal("2", symbols.INTEGER_TYPE),
nodes.Literal("2", symbols.INTEGER_TYPE))
hdepth.set_by_value(False, exprn2, False, False)
assert str(hdepth) == "4"

0 comments on commit f4226d9

Please sign in to comment.