Merge pull request #51 from mehdiataei/major-refactoring
Removed the need to pass sharding_flags in for distributed workloads
mehdiataei authored Jul 30, 2024
2 parents fc87b9b + 6f7397e commit 987d136
Showing 2 changed files with 43 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
@@ -15,7 +15,7 @@ def setup_stepper(self, omega):
             omega, boundary_conditions=self.boundary_conditions
         )
         distributed_stepper = distribute(
-            stepper, self.grid, self.velocity_set, sharding_flags=(True, True, True, True, False)
+            stepper, self.grid, self.velocity_set,
         )
         self.stepper = distributed_stepper
         return
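For context, a minimal before/after sketch of the call-site change (argument values copied from the example above; illustrative only, not part of the diff):

    # Before this change: the caller had to state which runtime arguments of
    # the stepper are sharded across devices.
    distributed_stepper = distribute(
        stepper, self.grid, self.velocity_set,
        sharding_flags=(True, True, True, True, False),
    )

    # After this change: distribute() infers sharding per argument from its
    # shape (see build_specs in the diff below), so the flags tuple is gone.
    distributed_stepper = distribute(stepper, self.grid, self.velocity_set)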
70 changes: 42 additions & 28 deletions xlb/distribute/distribute.py
@@ -14,7 +14,6 @@ def distribute(
     operator: Operator,
     grid,
     velocity_set,
-    sharding_flags: Tuple[bool, ...],
     num_results=1,
     ops="permute",
 ) -> Operator:
@@ -27,47 +26,62 @@ def _sharded_operator(*args):
             rightPerm = [(i, (i + 1) % grid.nDevices) for i in range(grid.nDevices)]
             leftPerm = [((i + 1) % grid.nDevices, i) for i in range(grid.nDevices)]
 
-            right_comm = lax.ppermute(
+            left_comm, right_comm = (
                 result[velocity_set.right_indices, :1, ...],
+                result[velocity_set.left_indices, -1:, ...],
+            )
+
+            left_comm = lax.ppermute(
+                left_comm,
                 perm=rightPerm,
                 axis_name="x",
             )
-            left_comm = lax.ppermute(
-                result[velocity_set.left_indices, -1:, ...],
+
+            right_comm = lax.ppermute(
+                right_comm,
                 perm=leftPerm,
                 axis_name="x",
             )
-            result = result.at[velocity_set.right_indices, :1, ...].set(
-                right_comm
-            )
-            result = result.at[velocity_set.left_indices, -1:, ...].set(
-                left_comm
-            )
+
+            result = result.at[velocity_set.right_indices, :1, ...].set(left_comm)
+            result = result.at[velocity_set.left_indices, -1:, ...].set(right_comm)
 
             return result
         else:
             raise NotImplementedError(f"Operation {ops} not implemented")
 
-    in_specs = tuple(
-        P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P()
-        for flag in sharding_flags
-    )
-    out_specs = tuple(
-        P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results)
-    )
+    # Build sharding_flags and in_specs based on args
+    def build_specs(grid, *args):
+        sharding_flags = []
+        in_specs = []
+        for arg in args:
+            if arg.shape[1:] == grid.shape:
+                sharding_flags.append(True)
+            else:
+                sharding_flags.append(False)
+
+        in_specs = tuple(
+            P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P()
+            for flag in sharding_flags
+        )
+        out_specs = tuple(
+            P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results)
+        )
+        return tuple(sharding_flags), in_specs, out_specs
 
-    if len(out_specs) == 1:
-        out_specs = out_specs[0]
+    def _wrapped_operator(*args):
+        sharding_flags, in_specs, out_specs = build_specs(grid, *args)
+
+        if len(out_specs) == 1:
+            out_specs = out_specs[0]
 
-    distributed_operator = shard_map(
-        _sharded_operator,
-        mesh=grid.global_mesh,
-        in_specs=in_specs,
-        out_specs=out_specs,
-        check_rep=False,
-    )
-    distributed_operator = jit(distributed_operator)
+        distributed_operator = shard_map(
+            _sharded_operator,
+            mesh=grid.global_mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=False,
+        )
+        return distributed_operator(*args)
 
-    return distributed_operator
+    return jit(_wrapped_operator)
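To make the new inference rule concrete, here is a small self-contained sketch (independent of xlb; ToyGrid and the array shapes are invented for illustration): an argument is treated as grid data, and sharded along the "x" mesh axis, when its trailing dimensions equal grid.shape; anything else is replicated via an empty PartitionSpec.

    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P

    class ToyGrid:
        # stand-in for the xlb grid: only shape and dim are used here
        shape = (64, 64)
        dim = 2

    grid = ToyGrid()
    f = jnp.zeros((9, 64, 64))  # e.g. a D2Q9 population field -> sharded
    omega = jnp.asarray(0.6)    # a scalar parameter -> replicated

    def infer_spec(arg):
        # same test as build_specs: trailing dims equal to the grid shape
        if arg.shape[1:] == grid.shape:
            return P(*((None, "x") + (grid.dim - 1) * (None,)))
        return P()

    print(infer_spec(f))      # PartitionSpec(None, 'x', None)
    print(infer_spec(omega))  # PartitionSpec()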
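And, for readers new to the communication pattern inside _sharded_operator, a runnable toy (again independent of xlb; the mesh construction, array sizes, and function names are assumptions, not the library's API): under shard_map, lax.ppermute moves each device's boundary slice to its neighbour on a ring over the "x" mesh axis, which is the role of the left_comm/right_comm exchange above.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import lax
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    n_dev = jax.device_count()
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

    def halo_right(f):
        # f is the per-device shard; send its last column to the right
        # neighbour and write it into that neighbour's first column.
        right_perm = [(i, (i + 1) % n_dev) for i in range(n_dev)]
        incoming = lax.ppermute(f[:, -1:], perm=right_perm, axis_name="x")
        return f.at[:, :1].set(incoming)

    f = jnp.arange(3 * 8 * n_dev, dtype=jnp.float32).reshape(3, 8 * n_dev)
    halo = shard_map(
        halo_right,
        mesh=mesh,
        in_specs=P(None, "x"),
        out_specs=P(None, "x"),
        check_rep=False,
    )
    print(jax.jit(halo)(f).shape)  # (3, 8 * n_dev)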
