Skip to content

Commit

Permalink
Merge pull request #188 from firedrakeproject/split_ensemble_update
Browse files Browse the repository at this point in the history
update `EnsembleConnector` with new `pyop2.internal_comm` implementation
  • Loading branch information
JHopeCollins authored May 7, 2024
2 parents 9f43f62 + 0b6bfc6 commit 9425562
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions asQ/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from firedrake import COMM_WORLD, Ensemble
from pyop2.mpi import internal_comm, decref
from pyop2.mpi import internal_comm

__all__ = ['create_ensemble', 'split_ensemble', 'EnsembleConnector']

Expand All @@ -23,7 +23,7 @@ def create_ensemble(time_partition, comm=COMM_WORLD):
return Ensemble(comm, nspatial_domains)


def split_ensemble(ensemble, split_size):
def split_ensemble(ensemble, split_size, **kwargs):
"""
Split an Ensemble into multiple smaller Ensembles which share the same
spatial communicators `ensemble.comm`.
Expand All @@ -45,11 +45,11 @@ def split_ensemble(ensemble, split_size):
split_comm = ensemble.global_comm.Split(color=split_rank,
key=ensemble.global_comm.rank)

return EnsembleConnector(split_comm, ensemble.comm, split_size)
return EnsembleConnector(split_comm, ensemble.comm, split_size, **kwargs)


class EnsembleConnector(Ensemble):
def __init__(self, global_comm, local_comm, nmembers):
def __init__(self, global_comm, local_comm, nmembers, **kwargs):
"""
An Ensemble created from provided spatial communicators (ensemble.comm).
Expand All @@ -61,22 +61,16 @@ def __init__(self, global_comm, local_comm, nmembers):
msg = "The global ensemble must have the same number of ranks as the sum of the local comms"
raise ValueError(msg)

ensemble_name = kwargs.get("ensemble_name", "Ensemble")
self.global_comm = global_comm
self._global_comm = internal_comm(self.global_comm)
self._comm = internal_comm(self.global_comm, self)

self.comm = local_comm
self._comm = internal_comm(self.comm)
self.comm.name = f"{ensemble_name} spatial comm"
self._spatial_comm = internal_comm(self.comm, self)

self.ensemble_comm = self.global_comm.Split(color=self.comm.rank,
key=global_comm.rank)
self.ensemble_comm.name = f"{ensemble_name} ensemble comm"

self._ensemble_comm = internal_comm(self.ensemble_comm)

def __del__(self):
if hasattr(self, "ensemble_comm"):
self.ensemble_comm.Free()
del self.ensemble_comm
for comm_name in ["_global_comm", "_comm", "_ensemble_comm"]:
if hasattr(self, comm_name):
comm = getattr(self, comm_name)
decref(comm)
self._ensemble_comm = internal_comm(self.ensemble_comm, self)

0 comments on commit 9425562

Please sign in to comment.