From e2616bd876b4ac181ed5200401bf3939170e5c56 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 30 Jul 2024 00:04:49 -0400 Subject: [PATCH] somewhat improved bc handling using structs --- xlb/operator/stepper/nse_stepper.py | 61 ++++++++++++++------ xlb/operator/stepper/stepper.py | 86 ++++++++--------------------- 2 files changed, 68 insertions(+), 79 deletions(-) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 11b5615..29ad088 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -95,11 +95,15 @@ def _construct_warp(self): self.velocity_set.q, dtype=wp.uint8 ) # TODO fix vec bool - # Get the boundary condition ids - _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) - _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) - _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) - _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) + @wp.struct + class BoundaryConditionIDStruct: + # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning + # One needs to manually add the names of additional BC's as they are added. + # TODO: Anyway to improve this + id_EquilibriumBC: wp.uint8 + id_DoNothingBC: wp.uint8 + id_HalfwayBounceBackBC: wp.uint8 + id_FullwayBounceBackBC: wp.uint8 @wp.kernel def kernel2d( @@ -107,6 +111,7 @@ def kernel2d( f_1: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -124,20 +129,20 @@ def kernel2d( _missing_mask[l] = wp.uint8(0) # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC: # Regular streaming f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + elif _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional( f_0, _missing_mask, index @@ -158,7 +163,7 @@ def kernel2d( ) # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -177,6 +182,7 @@ def kernel3d( f_1: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -194,20 +200,20 @@ def kernel3d( _missing_mask[l] = wp.uint8(0) # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC: # Regular streaming f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + elif _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional( f_0, _missing_mask, index @@ -223,7 +229,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -238,10 +244,32 @@ def kernel3d( # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel + return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + + # Get the boundary condition ids + from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + bc_to_id = boundary_condition_registry.bc_to_id + + bc_struct = self.warp_functional() + bc_attribute_list = [] + for bc in self.boundary_conditions: + # Setting the Struct attributes based on the BC class names + attribute_str = bc.__class__.__name__ + setattr(bc_struct, 'id_' + attribute_str, bc_to_id[attribute_str]) + bc_attribute_list.append('id_' + attribute_str) + + # Unused attributes of the struct are set to inernal (id=0) + ll = vars(bc_struct) + for var in ll: + if var not in bc_attribute_list and not var.startswith('_'): + # set unassigned boundaries to the maximum integer in uint8 + attribute_str = bc.__class__.__name__ + setattr(bc_struct, var, 255) + + # Launch the warp kernel wp.launch( self.warp_kernel, @@ -250,6 +278,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_1, boundary_mask, missing_mask, + bc_struct, timestep, ], dim=f_0.shape[1:], diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c11b39b..fca088e 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,17 +1,5 @@ # Base class for all stepper operators - -from ast import Raise -from functools import partial -import jax.numpy as jnp -from jax import jit -import warp as wp - -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.precision_caster import PrecisionCaster -from xlb.operator.equilibrium import Equilibrium from xlb import DefaultConfig @@ -59,65 +47,37 @@ def __init__(self, operators, boundary_conditions): ) # Add boundary conditions - # Warp cannot handle lists of functions currently - # Because of this we manually unpack the boundary conditions ############################################ + # Warp cannot handle lists of functions currently # TODO: Fix this later ############################################ from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import ( - HalfwayBounceBackBC, - ) - from xlb.operator.boundary_condition.bc_fullway_bounce_back import ( - FullwayBounceBackBC, - ) + from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + + + # Define a list of tuples with attribute names and their corresponding classes + conditions = [ + ("equilibrium_bc", EquilibriumBC), + ("do_nothing_bc", DoNothingBC), + ("halfway_bounce_back_bc", HalfwayBounceBackBC), + ("fullway_bounce_back_bc", FullwayBounceBackBC), + ] + + # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. + bc_fallback = boundary_conditions[0] - self.equilibrium_bc = None - self.do_nothing_bc = None - self.halfway_bounce_back_bc = None - self.fullway_bounce_back_bc = None + # Iterate over each boundary condition + for attr_name, bc_class in conditions: + for bc in boundary_conditions: + if isinstance(bc, bc_class): + setattr(self, attr_name, bc) + break + elif not hasattr(self, attr_name): + setattr(self, attr_name, bc_fallback) - for bc in boundary_conditions: - if isinstance(bc, EquilibriumBC): - self.equilibrium_bc = bc - elif isinstance(bc, DoNothingBC): - self.do_nothing_bc = bc - elif isinstance(bc, HalfwayBounceBackBC): - self.halfway_bounce_back_bc = bc - elif isinstance(bc, FullwayBounceBackBC): - self.fullway_bounce_back_bc = bc - if self.equilibrium_bc is None: - # Select the equilibrium operator based on its type - self.equilibrium_bc = EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=next( - (op for op in self.operators if isinstance(op, Equilibrium)), None - ), - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.do_nothing_bc is None: - self.do_nothing_bc = DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.halfway_bounce_back_bc is None: - self.halfway_bounce_back_bc = HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.fullway_bounce_back_bc is None: - self.fullway_bounce_back_bc = FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) ############################################ # Initialize operator