From 0b6bfc6d11d92bde0343a255fd21ceac8b8704ba Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 3 May 2024 17:02:09 +0100 Subject: [PATCH] update asQ.ensemble with new internal_comm --- asQ/ensemble.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/asQ/ensemble.py b/asQ/ensemble.py index 740ddd9a..cd774a63 100644 --- a/asQ/ensemble.py +++ b/asQ/ensemble.py @@ -1,5 +1,5 @@ from firedrake import COMM_WORLD, Ensemble -from pyop2.mpi import internal_comm, decref +from pyop2.mpi import internal_comm __all__ = ['create_ensemble', 'split_ensemble', 'EnsembleConnector'] @@ -23,7 +23,7 @@ def create_ensemble(time_partition, comm=COMM_WORLD): return Ensemble(comm, nspatial_domains) -def split_ensemble(ensemble, split_size): +def split_ensemble(ensemble, split_size, **kwargs): """ Split an Ensemble into multiple smaller Ensembles which share the same spatial communicators `ensemble.comm`. @@ -45,11 +45,11 @@ def split_ensemble(ensemble, split_size): split_comm = ensemble.global_comm.Split(color=split_rank, key=ensemble.global_comm.rank) - return EnsembleConnector(split_comm, ensemble.comm, split_size) + return EnsembleConnector(split_comm, ensemble.comm, split_size, **kwargs) class EnsembleConnector(Ensemble): - def __init__(self, global_comm, local_comm, nmembers): + def __init__(self, global_comm, local_comm, nmembers, **kwargs): """ An Ensemble created from provided spatial communicators (ensemble.comm). @@ -61,22 +61,16 @@ def __init__(self, global_comm, local_comm, nmembers): msg = "The global ensemble must have the same number of ranks as the sum of the local comms" raise ValueError(msg) + ensemble_name = kwargs.get("ensemble_name", "Ensemble") self.global_comm = global_comm - self._global_comm = internal_comm(self.global_comm) + self._comm = internal_comm(self.global_comm, self) self.comm = local_comm - self._comm = internal_comm(self.comm) + self.comm.name = f"{ensemble_name} spatial comm" + self._spatial_comm = internal_comm(self.comm, self) self.ensemble_comm = self.global_comm.Split(color=self.comm.rank, key=global_comm.rank) + self.ensemble_comm.name = f"{ensemble_name} ensemble comm" - self._ensemble_comm = internal_comm(self.ensemble_comm) - - def __del__(self): - if hasattr(self, "ensemble_comm"): - self.ensemble_comm.Free() - del self.ensemble_comm - for comm_name in ["_global_comm", "_comm", "_ensemble_comm"]: - if hasattr(self, comm_name): - comm = getattr(self, comm_name) - decref(comm) + self._ensemble_comm = internal_comm(self.ensemble_comm, self)