Skip to content

Commit

Permalink
Add a CommTag class with a stable hash
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Smith <[email protected]>
  • Loading branch information
matthiasdiener and majosm committed Sep 13, 2023
1 parent 90b9b80 commit d756bf6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
6 changes: 3 additions & 3 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa

from grudge.dof_desc import as_dofdesc, DOFDesc, DISCR_TAG_BASE, DISCR_TAG_QUAD
from grudge.trace_pair import TracePair
from grudge.trace_pair import TracePair, CommTag
from grudge.discretization import DiscretizationCollection
from grudge.shortcuts import make_visualizer, compiled_lsrk45_step

Expand Down Expand Up @@ -95,7 +95,7 @@ def wave_flux(actx, dcoll, c, w_tpair):
return op.project(dcoll, dd, dd.with_dtag("all_faces"), c*flux_weak)


class _WaveStateTag:
class _WaveStateTag(CommTag):
pass


Expand Down Expand Up @@ -144,7 +144,7 @@ def interp_to_surf_quad(utpair):
) + sum(
wave_flux(actx, dcoll, c=c, w_tpair=interp_to_surf_quad(tpair))
for tpair in op.interior_trace_pairs(dcoll, w,
comm_tag=_WaveStateTag)
comm_tag=_WaveStateTag())
)
)
)
Expand Down
34 changes: 25 additions & 9 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
---------------------------------------
.. autofunction:: interior_trace_pairs
.. autoclass:: CommTag
.. autofunction:: local_interior_trace_pair
.. autofunction:: cross_rank_trace_pairs
"""
Expand Down Expand Up @@ -70,7 +71,7 @@

from numbers import Number

from pytools import memoize_on_first_arg
from pytools import memoize_on_first_arg, memoize_method

from grudge.discretization import DiscretizationCollection
from grudge.projection import project
Expand Down Expand Up @@ -318,8 +319,19 @@ def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair:
return local_interior_trace_pair(dcoll, vec)


class CommTag:
"""A communication tag with a hash value that is stable across
runs, even without setting ``PYTHONHASHSEED``."""
@memoize_method
def __hash__(self):
return hash(tuple(str(type(self)).encode("ascii")))

def __eq__(self, other):
return isinstance(other, type(self))


def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
comm_tag: Hashable = None, tag: Hashable = None,
comm_tag: CommTag = None, tag: Hashable = None,
volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Return a :class:`list` of :class:`TracePair` objects
defined on the interior faces of *dcoll* and any faces connected to a
Expand All @@ -331,10 +343,10 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
:arg vec: a :class:`~meshmode.dof_array.DOFArray` or an
:class:`~arraycontext.ArrayContainer` of them.
:arg comm_tag: a hashable object used to match sent and received data
:arg comm_tag: a :class:`CommTag` used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitary, potentially composite objects.
tags to arbitrary, potentially composite objects.
:returns: a :class:`list` of :class:`TracePair` objects.
"""

Expand Down Expand Up @@ -379,7 +391,7 @@ def connected_ranks(
dcoll._volume_discrs[volume_dd.domain_tag.tag].mesh)


def _sym_tag_to_num_tag(comm_tag: Optional[Hashable]) -> Optional[int]:
def _sym_tag_to_num_tag(comm_tag: Optional[CommTag]) -> Optional[int]:
if comm_tag is None:
return comm_tag

Expand Down Expand Up @@ -498,10 +510,14 @@ class _RankBoundaryCommunicationLazy:
def __init__(self,
dcoll: DiscretizationCollection,
array_container: ArrayOrContainer,
remote_rank: int, comm_tag: Hashable,
remote_rank: int, comm_tag: Optional[CommTag],
volume_dd=DD_VOLUME_ALL):
if comm_tag is None:
raise ValueError("lazy communication requires 'tag' to be supplied")
raise ValueError("lazy communication requires 'comm_tag' to be supplied")

if not isinstance(comm_tag, CommTag):
from warnings import warn
warn(f"comm_tag {comm_tag} should be an instance of CommTag")

bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank))

Expand Down Expand Up @@ -544,7 +560,7 @@ def finish(self):
def cross_rank_trace_pairs(
dcoll: DiscretizationCollection, ary: ArrayOrContainer,
tag: Hashable = None,
*, comm_tag: Hashable = None,
*, comm_tag: CommTag = None,
volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Get a :class:`list` of *ary* trace pairs for each partition boundary.
Expand All @@ -570,7 +586,7 @@ def cross_rank_trace_pairs(
:arg comm_tag: a hashable object used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitary, potentially composite objects.
tags to arbitrary, potentially composite objects.
:returns: a :class:`list` of :class:`TracePair` objects.
"""
Expand Down

0 comments on commit d756bf6

Please sign in to comment.