Skip to content

Commit

Permalink
Extremely simplified BC implementation!
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Oct 8, 2024
1 parent d73a6d6 commit 59cbe8e
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 74 deletions.
2 changes: 0 additions & 2 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ def __init__(
mesh_vertices,
)
# Overwrite the boundary condition registry id with the bc_type in the name
self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type)
# The operator to compute the momentum flux
self.momentum_flux = MomentumFlux()

@partial(jit, static_argnums=(0,), inline=True)
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ def __init__(
mesh_vertices,
)

# Overwrite the boundary condition registry id with the bc_type in the name
self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type)

# Set the prescribed value for pressure or velocity
dim = self.velocity_set.d
if self.compute_backend == ComputeBackend.JAX:
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
indices=None,
mesh_vertices=None,
):
self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__)
self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + str(hash(self)))
velocity_set = velocity_set or DefaultConfig.velocity_set
precision_policy = precision_policy or DefaultConfig.default_precision_policy
compute_backend = compute_backend or DefaultConfig.default_backend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def register_boundary_condition(self, boundary_condition):
self.next_id += 1
self.id_to_bc[_id] = boundary_condition
self.bc_to_id[boundary_condition] = _id
print(f"registered bc {boundary_condition} with id {_id}")
return _id


Expand Down
94 changes: 26 additions & 68 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,65 +100,16 @@ def _construct_warp(self):
bc_to_id = boundary_condition_registry.bc_to_id
id_to_bc = boundary_condition_registry.id_to_bc

# Gather IDs of ExtrapolationOutflowBC boundary conditions
extrapolation_outflow_bc_ids = []
for bc_name, bc_id in bc_to_id.items():
if bc_name.startswith("ExtrapolationOutflowBC"):
extrapolation_outflow_bc_ids.append(bc_id)
# Group active boundary conditions
active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions)

for bc in self.boundary_conditions:
bc_name = id_to_bc[bc.id]
setattr(self, bc_name, bc)

@wp.func
def apply_post_streaming_bc(
index: Any,
timestep: Any,
_boundary_id: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
):
f_result = f_post

if wp.static("EquilibriumBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["EquilibriumBC"]):
f_result = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("DoNothingBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["DoNothingBC"]):
f_result = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("HalfwayBounceBackBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["HalfwayBounceBackBC"]):
f_result = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("ZouHeBC_pressure" in active_bcs):
if _boundary_id == wp.static(bc_to_id["ZouHeBC_pressure"]):
f_result = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("ZouHeBC_velocity" in active_bcs):
if _boundary_id == wp.static(bc_to_id["ZouHeBC_velocity"]):
f_result = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("RegularizedBC_pressure" in active_bcs):
if _boundary_id == wp.static(bc_to_id["RegularizedBC_pressure"]):
f_result = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("RegularizedBC_velocity" in active_bcs):
if _boundary_id == wp.static(bc_to_id["RegularizedBC_velocity"]):
f_result = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("ExtrapolationOutflowBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]):
f_result = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("GradsApproximationBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["GradsApproximationBC"]):
f_result = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

return f_result

@wp.func
def apply_post_collision_bc(
def apply_bc(
index: Any,
timestep: Any,
_boundary_id: Any,
Expand All @@ -167,17 +118,23 @@ def apply_post_collision_bc(
f_1: Any,
f_pre: Any,
f_post: Any,
is_post_streaming: bool,
):
f_result = f_post

if wp.static("FullwayBounceBackBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["FullwayBounceBackBC"]):
f_result = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

if wp.static("ExtrapolationOutflowBC" in active_bcs):
if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]):
f_result = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)

# Unroll the loop over boundary conditions
for i in range(wp.static(len(self.boundary_conditions))):
if is_post_streaming:
if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)
else:
if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)
if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
f_result = wp.static(self.boundary_conditions[i].prepare_bc_auxilary_data)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)
return f_result

@wp.func
Expand Down Expand Up @@ -244,7 +201,7 @@ def kernel2d(
_f_post_collision = _f0_thread

# Apply post-streaming boundary conditions
_f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream)
_f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True)

# Compute rho and u
_rho, _u = self.macroscopic.warp_functional(_f_post_stream)
Expand All @@ -256,7 +213,7 @@ def kernel2d(
_f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u)

# Apply post-collision boundary conditions
_f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision)
_f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False)

# Store the result in f_1
for l in range(self.velocity_set.q):
Expand Down Expand Up @@ -284,17 +241,18 @@ def kernel3d(
_f_post_collision = _f0_thread

# Apply post-streaming boundary conditions
_f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream)
_f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True)

_rho, _u = self.macroscopic.warp_functional(_f_post_stream)
_feq = self.equilibrium.warp_functional(_rho, _u)
_f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u)

# Apply post-collision boundary conditions
_f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision)
_f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False)

# Store the result in f_1
for l in range(self.velocity_set.q):
# TODO: Improve this later
if wp.static("GradsApproximationBC" in active_bcs):
if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]):
if _missing_mask[l] == wp.uint8(1):
Expand Down

0 comments on commit 59cbe8e

Please sign in to comment.