Skip to content

Commit

Permalink
fixed race conditioning in indices_boundary_masker due to duplicate b…
Browse files Browse the repository at this point in the history
…c indices at corners and edges.
  • Loading branch information
hsalehipour committed Oct 18, 2024
1 parent 3843188 commit e16ba24
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 37 deletions.
24 changes: 21 additions & 3 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ def setup_boundary_conditions(self):
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_sphere = HalfwayBounceBackBC(indices=sphere)

self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls]
# Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]
# Note: it is important to add bc_walls before bc_outlet/bc_inlet because
# of the corner nodes. This way the corners are treated as wall and not inlet/outlet.
# TODO: how to ensure about this behind in the src code?

def setup_boundary_masker(self):
indices_boundary_masker = IndicesBoundaryMasker(
Expand All @@ -105,6 +104,8 @@ def run(self, num_steps, post_process_interval=100):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i == 0:
self.check_boundary_mask()
if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)
end_time = time.time()
Expand Down Expand Up @@ -134,6 +135,23 @@ def post_process(self, i):

# save_fields_vtk(fields, timestep=i)
save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i)
return

def check_boundary_mask(self):
# Write the results. We'll use JAX backend for the post-processing
if not isinstance(self.f_0, jnp.ndarray):
bmask = wp.to_jax(self.bc_mask)[0]
else:
bmask = self.bc_mask[0]

# save_fields_vtk(fields, timestep=i)
save_image(bmask[0, :, :], prefix="00_left")
save_image(bmask[self.grid_shape[0] - 1, :, :], prefix="00_right")
save_image(bmask[:, :, self.grid_shape[2] - 1], prefix="00_top")
save_image(bmask[:, :, 0], prefix="00_bottom")
save_image(bmask[:, 0, :], prefix="00_front")
save_image(bmask[:, self.grid_shape[1] - 1, :], prefix="00_back")
save_image(bmask[:, self.grid_shape[1] // 2, :], prefix="00_middle")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def setup_boundary_conditions(self):
lid, walls = self.define_boundary_indices()
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_top, bc_walls]
self.boundary_conditions = [bc_walls, bc_top]

def setup_boundary_masker(self):
indices_boundary_masker = IndicesBoundaryMasker(
Expand Down
4 changes: 3 additions & 1 deletion examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def setup_boundary_conditions(self):
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
bc_car = GradsApproximationBC(mesh_vertices=car)
# bc_car = FullwayBounceBackBC(mesh_vertices=car)
self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car]
self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car]
# Note: it is important to add bc_walls before bc_outlet/bc_inlet because
# of the corner nodes. This way the corners are treated as wall and not inlet/outlet.

def setup_boundary_masker(self):
indices_boundary_masker = IndicesBoundaryMasker(
Expand Down
1 change: 0 additions & 1 deletion xlb/operator/boundary_condition/bc_grads_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
indices=None,
mesh_vertices=None,
):

# TODO: the input velocity must be suitably stored elesewhere when mesh is moving.
self.u = (0, 0, 0)

Expand Down
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from xlb import DefaultConfig
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry


# Enum for implementation step
class ImplementationStep(Enum):
COLLISION = auto()
Expand Down
46 changes: 18 additions & 28 deletions xlb/operator/boundary_masker/indices_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
start_index = (0,) * dim

domain_shape = bc_mask[0].shape
for bc in bclist:
for bc in reversed(bclist):
assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!"
assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!"
id_number = bc.id
Expand Down Expand Up @@ -103,6 +103,11 @@ def _construct_warp(self):
_c = self.velocity_set.c
_q = wp.constant(self.velocity_set.q)

@wp.func
def check_index_bounds(index: wp.vec3i, shape: wp.vec3i):
is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2]
return is_in_bounds

# Construct the warp 2D kernel
@wp.kernel
def kernel2d(
Expand Down Expand Up @@ -173,14 +178,8 @@ def kernel3d(
index[2] = indices[2, ii] - start_index[2]

# Check if index is in bounds
if (
index[0] >= 0
and index[0] < missing_mask.shape[1]
and index[1] >= 0
and index[1] < missing_mask.shape[2]
and index[2] >= 0
and index[2] < missing_mask.shape[3]
):
shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3])
if check_index_bounds(index, shape):
# Stream indices
for l in range(_q):
# Get the index of the streaming direction
Expand All @@ -195,27 +194,12 @@ def kernel3d(

# check if pull index is out of bound
# These directions will have missing information after streaming
if (
pull_index[0] < 0
or pull_index[0] >= missing_mask.shape[1]
or pull_index[1] < 0
or pull_index[1] >= missing_mask.shape[2]
or pull_index[2] < 0
or pull_index[2] >= missing_mask.shape[3]
):
if not check_index_bounds(pull_index, shape):
# Set the missing mask
missing_mask[l, index[0], index[1], index[2]] = True

# handling geometries in the interior of the computational domain
elif (
is_interior[ii]
and push_index[0] >= 0
and push_index[0] < missing_mask.shape[1]
and push_index[1] >= 0
and push_index[1] < missing_mask.shape[2]
and push_index[2] >= 0
and push_index[2] < missing_mask.shape[3]
):
elif check_index_bounds(pull_index, shape) and is_interior[ii]:
# Set the missing mask
missing_mask[l, push_index[0], push_index[1], push_index[2]] = True
bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii]
Expand All @@ -241,8 +225,14 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
# We are done with bc.indices. Remove them from BC objects
bc.__dict__.pop("indices", None)

indices = wp.array2d(index_list, dtype=wp.int32)
id_number = wp.array1d(id_list, dtype=wp.uint8)
# Remove duplicates indices to avoid race conditioning
index_arr, unique_loc = np.unique(index_list, axis=-1, return_index=True)
id_arr = np.array(id_list)[unique_loc]
is_interior = np.array(is_interior)[unique_loc]

# convert to warp arrays
indices = wp.array2d(index_arr, dtype=wp.int32)
id_number = wp.array1d(id_arr, dtype=wp.uint8)
is_interior = wp.array1d(is_interior, dtype=wp.bool)

if start_index is None:
Expand Down
8 changes: 5 additions & 3 deletions xlb/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def downsample_field(field, factor, method="bicubic"):
return jnp.stack(downsampled_components, axis=-1)


def save_image(fld, timestep, prefix=None):
def save_image(fld, timestep=None, prefix=None, **kwargs):
"""
Save an image of a field at a given timestep.
Expand Down Expand Up @@ -74,15 +74,17 @@ def save_image(fld, timestep, prefix=None):
else:
fname = prefix

fname = fname + "_" + str(timestep).zfill(4)
if timestep is not None:
fname = fname + "_" + str(timestep).zfill(4)

if len(fld.shape) > 3:
raise ValueError("The input field should be 2D!")
if len(fld.shape) == 3:
fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2)

plt.clf()
plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower")
kwargs.pop("cmap", None)
plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower", **kwargs)


def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"):
Expand Down

0 comments on commit e16ba24

Please sign in to comment.