diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 72c72a2..e71da52 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -15,7 +15,7 @@ def setup_stepper(self, omega): omega, boundary_conditions=self.boundary_conditions ) distributed_stepper = distribute( - stepper, self.grid, self.velocity_set, sharding_flags=(True, True, True, True, False) + stepper, self.grid, self.velocity_set, ) self.stepper = distributed_stepper return diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index bb72072..ad07dd0 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -14,7 +14,6 @@ def distribute( operator: Operator, grid, velocity_set, - sharding_flags: Tuple[bool, ...], num_results=1, ops="permute", ) -> Operator: @@ -27,47 +26,62 @@ def _sharded_operator(*args): rightPerm = [(i, (i + 1) % grid.nDevices) for i in range(grid.nDevices)] leftPerm = [((i + 1) % grid.nDevices, i) for i in range(grid.nDevices)] - right_comm = lax.ppermute( + left_comm, right_comm = ( result[velocity_set.right_indices, :1, ...], + result[velocity_set.left_indices, -1:, ...], + ) + + left_comm = lax.ppermute( + left_comm, perm=rightPerm, axis_name="x", ) - left_comm = lax.ppermute( - result[velocity_set.left_indices, -1:, ...], + right_comm = lax.ppermute( + right_comm, perm=leftPerm, axis_name="x", ) - result = result.at[velocity_set.right_indices, :1, ...].set( - right_comm - ) - result = result.at[velocity_set.left_indices, -1:, ...].set( - left_comm - ) + result = result.at[velocity_set.right_indices, :1, ...].set(left_comm) + result = result.at[velocity_set.left_indices, -1:, ...].set(right_comm) return result else: raise NotImplementedError(f"Operation {ops} not implemented") - in_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() - for flag in sharding_flags - ) - out_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) - ) + # Build sharding_flags and in_specs based on args + def build_specs(grid, *args): + sharding_flags = [] + in_specs = [] + for arg in args: + if arg.shape[1:] == grid.shape: + sharding_flags.append(True) + else: + sharding_flags.append(False) + + in_specs = tuple( + P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() + for flag in sharding_flags + ) + out_specs = tuple( + P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) + ) + return tuple(sharding_flags), in_specs, out_specs + + def _wrapped_operator(*args): + sharding_flags, in_specs, out_specs = build_specs(grid, *args) - if len(out_specs) == 1: - out_specs = out_specs[0] + if len(out_specs) == 1: + out_specs = out_specs[0] - distributed_operator = shard_map( - _sharded_operator, - mesh=grid.global_mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - ) - distributed_operator = jit(distributed_operator) + distributed_operator = shard_map( + _sharded_operator, + mesh=grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + return distributed_operator(*args) - return distributed_operator + return jit(_wrapped_operator)