Skip to content

Commit

Permalink
somewhat improved bc handling using structs
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Jul 30, 2024
1 parent 987d136 commit e2616bd
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 79 deletions.
61 changes: 45 additions & 16 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,23 @@ def _construct_warp(self):
self.velocity_set.q, dtype=wp.uint8
) # TODO fix vec bool

# Get the boundary condition ids
_equilibrium_bc = wp.uint8(self.equilibrium_bc.id)
_do_nothing_bc = wp.uint8(self.do_nothing_bc.id)
_halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id)
_fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id)
@wp.struct
class BoundaryConditionIDStruct:
# Note the names are hardcoded here based on various BC operator names with "id_" at the beginning
# One needs to manually add the names of additional BC's as they are added.
# TODO: Anyway to improve this
id_EquilibriumBC: wp.uint8
id_DoNothingBC: wp.uint8
id_HalfwayBounceBackBC: wp.uint8
id_FullwayBounceBackBC: wp.uint8

@wp.kernel
def kernel2d(
f_0: wp.array3d(dtype=Any),
f_1: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=Any),
missing_mask: wp.array3d(dtype=Any),
bc_struct: BoundaryConditionIDStruct,
timestep: int,
):
# Get the global index
Expand All @@ -124,20 +129,20 @@ def kernel2d(
_missing_mask[l] = wp.uint8(0)

# Apply streaming boundary conditions
if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc:
if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Regular streaming
f_post_stream = self.stream.warp_functional(f_0, index)
elif _boundary_id == _equilibrium_bc:
elif _boundary_id == bc_struct.id_EquilibriumBC:
# Equilibrium boundary condition
f_post_stream = self.equilibrium_bc.warp_functional(
f_0, _missing_mask, index
)
elif _boundary_id == _do_nothing_bc:
elif _boundary_id == bc_struct.id_DoNothingBC:
# Do nothing boundary condition
f_post_stream = self.do_nothing_bc.warp_functional(
f_0, _missing_mask, index
)
elif _boundary_id == _halfway_bounce_back_bc:
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(
f_0, _missing_mask, index
Expand All @@ -158,7 +163,7 @@ def kernel2d(
)

# Apply collision type boundary conditions
if _boundary_id == _fullway_bounce_back_bc:
if _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Full way boundary condition
f_post_collision = self.fullway_bounce_back_bc.warp_functional(
f_post_stream,
Expand All @@ -177,6 +182,7 @@ def kernel3d(
f_1: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=Any),
missing_mask: wp.array4d(dtype=Any),
bc_struct: BoundaryConditionIDStruct,
timestep: int,
):
# Get the global index
Expand All @@ -194,20 +200,20 @@ def kernel3d(
_missing_mask[l] = wp.uint8(0)

# Apply streaming boundary conditions
if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc:
if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Regular streaming
f_post_stream = self.stream.warp_functional(f_0, index)
elif _boundary_id == _equilibrium_bc:
elif _boundary_id == bc_struct.id_EquilibriumBC:
# Equilibrium boundary condition
f_post_stream = self.equilibrium_bc.warp_functional(
f_0, _missing_mask, index
)
elif _boundary_id == _do_nothing_bc:
elif _boundary_id == bc_struct.id_DoNothingBC:
# Do nothing boundary condition
f_post_stream = self.do_nothing_bc.warp_functional(
f_0, _missing_mask, index
)
elif _boundary_id == _halfway_bounce_back_bc:
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(
f_0, _missing_mask, index
Expand All @@ -223,7 +229,7 @@ def kernel3d(
f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u)

# Apply collision type boundary conditions
if _boundary_id == _fullway_bounce_back_bc:
if _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Full way boundary condition
f_post_collision = self.fullway_bounce_back_bc.warp_functional(
f_post_stream,
Expand All @@ -238,10 +244,32 @@ def kernel3d(
# Return the correct kernel
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return None, kernel
return BoundaryConditionIDStruct, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):

# Get the boundary condition ids
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
bc_to_id = boundary_condition_registry.bc_to_id

bc_struct = self.warp_functional()
bc_attribute_list = []
for bc in self.boundary_conditions:
# Setting the Struct attributes based on the BC class names
attribute_str = bc.__class__.__name__
setattr(bc_struct, 'id_' + attribute_str, bc_to_id[attribute_str])
bc_attribute_list.append('id_' + attribute_str)

# Unused attributes of the struct are set to inernal (id=0)
ll = vars(bc_struct)
for var in ll:
if var not in bc_attribute_list and not var.startswith('_'):
# set unassigned boundaries to the maximum integer in uint8
attribute_str = bc.__class__.__name__
setattr(bc_struct, var, 255)


# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -250,6 +278,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
f_1,
boundary_mask,
missing_mask,
bc_struct,
timestep,
],
dim=f_0.shape[1:],
Expand Down
86 changes: 23 additions & 63 deletions xlb/operator/stepper/stepper.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
# Base class for all stepper operators

from ast import Raise
from functools import partial
import jax.numpy as jnp
from jax import jit
import warp as wp

from xlb.operator.equilibrium.equilibrium import Equilibrium
from xlb.velocity_set import VelocitySet
from xlb.compute_backend import ComputeBackend
from xlb.operator import Operator
from xlb.operator.precision_caster import PrecisionCaster
from xlb.operator.equilibrium import Equilibrium
from xlb import DefaultConfig


Expand Down Expand Up @@ -59,65 +47,37 @@ def __init__(self, operators, boundary_conditions):
)

# Add boundary conditions
# Warp cannot handle lists of functions currently
# Because of this we manually unpack the boundary conditions
############################################
# Warp cannot handle lists of functions currently
# TODO: Fix this later
############################################
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import (
HalfwayBounceBackBC,
)
from xlb.operator.boundary_condition.bc_fullway_bounce_back import (
FullwayBounceBackBC,
)
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC


# Define a list of tuples with attribute names and their corresponding classes
conditions = [
("equilibrium_bc", EquilibriumBC),
("do_nothing_bc", DoNothingBC),
("halfway_bounce_back_bc", HalfwayBounceBackBC),
("fullway_bounce_back_bc", FullwayBounceBackBC),
]

# this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example.
bc_fallback = boundary_conditions[0]

self.equilibrium_bc = None
self.do_nothing_bc = None
self.halfway_bounce_back_bc = None
self.fullway_bounce_back_bc = None
# Iterate over each boundary condition
for attr_name, bc_class in conditions:
for bc in boundary_conditions:
if isinstance(bc, bc_class):
setattr(self, attr_name, bc)
break
elif not hasattr(self, attr_name):
setattr(self, attr_name, bc_fallback)

for bc in boundary_conditions:
if isinstance(bc, EquilibriumBC):
self.equilibrium_bc = bc
elif isinstance(bc, DoNothingBC):
self.do_nothing_bc = bc
elif isinstance(bc, HalfwayBounceBackBC):
self.halfway_bounce_back_bc = bc
elif isinstance(bc, FullwayBounceBackBC):
self.fullway_bounce_back_bc = bc

if self.equilibrium_bc is None:
# Select the equilibrium operator based on its type
self.equilibrium_bc = EquilibriumBC(
rho=1.0,
u=(0.0, 0.0, 0.0),
equilibrium_operator=next(
(op for op in self.operators if isinstance(op, Equilibrium)), None
),
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.do_nothing_bc is None:
self.do_nothing_bc = DoNothingBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.halfway_bounce_back_bc is None:
self.halfway_bounce_back_bc = HalfwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.fullway_bounce_back_bc is None:
self.fullway_bounce_back_bc = FullwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
############################################

# Initialize operator
Expand Down

0 comments on commit e2616bd

Please sign in to comment.