diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py deleted file mode 100644 index 1684266..0000000 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ /dev/null @@ -1,203 +0,0 @@ -# Simple flow past sphere example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import numpy as np - -from xlb.compute_backend import ComputeBackend - -import warp as wp - -import xlb - -xlb.init( - default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=xlb.velocity_set.D2Q9, -) - - -from xlb.operator import Operator - - -class UniformInitializer(Operator): - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - ): - # Get the global index - i, j, k = wp.tid() - - # Set the velocity - u[0, i, j, k] = vel - u[1, i, j, k] = 0.0 - u[2, i, j, k] = 0.0 - - # Set the density - rho[0, i, j, k] = 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - ], - dim=rho.shape[1:], - ) - return rho, u - - -if __name__ == "__main__": - # Set parameters - compute_backend = xlb.ComputeBackend.WARP - precision_policy = xlb.PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D3Q19() - - # Make feilds - nr = 256 - vel = 0.05 - shape = (nr, nr, nr) - grid = xlb.grid.grid_factory(shape=shape) - rho = grid.create_field(cardinality=1) - u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) - f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - bc_mask = grid.create_field(cardinality=1, dtype=wp.uint8) - missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) - - # Make operators - initializer = UniformInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - collision = xlb.operator.collision.BGK( - omega=1.95, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=1.0, - u=(vel, 0.0, 0.0), - equilibrium_operator=equilibrium, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream, - equilibrium_bc=equilibrium_bc, - do_nothing_bc=do_nothing_bc, - half_way_bc=half_way_bc, - ) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - - # Make indices for boundary conditions (sphere) - sphere_radius = 32 - x = np.arange(nr) - y = np.arange(nr) - z = np.arange(nr) - X, Y, Z = np.meshgrid(x, y, z) - indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) - indices = np.array(indices).T - indices = wp.from_numpy(indices, dtype=wp.int32) - - # Set boundary conditions on the indices - bc_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set inlet bc - lower_bound = (0, 0, 0) - upper_bound = (0, nr, nr) - direction = (1, 0, 0) - bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set outlet bc - lower_bound = (nr - 1, 0, 0) - upper_bound = (nr - 1, nr, nr) - direction = (-1, 0, 0) - bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set initial conditions - rho, u = initializer(rho, u, vel) - f0 = equilibrium(rho, u, f0) - - # Time stepping - plot_freq = 512 - save_dir = "flow_past_sphere" - os.makedirs(save_dir, exist_ok=True) - # compute_mlup = False # Plotting results - compute_mlup = True - num_steps = 1024 * 8 - start = time.time() - for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, bc_mask, missing_mask, _) - f1, f0 = f0, f1 - if (_ % plot_freq == 0) and (not compute_mlup): - rho, u = macroscopic(f0, rho, u) - - # Plot the velocity field and boundary id side by side - plt.subplot(1, 2, 1) - plt.imshow(u[0, :, nr // 2, :].numpy()) - plt.colorbar() - plt.subplot(1, 2, 2) - plt.imshow(bc_mask[0, :, nr // 2, :].numpy()) - plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") - plt.close() - - wp.synchronize() - end = time.time() - - # Print MLUPS - print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py deleted file mode 100644 index 846ba30..0000000 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ /dev/null @@ -1,181 +0,0 @@ -# Simple Taylor green example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import jax.numpy as jnp -import warp as wp - -wp.init() - -import xlb -from xlb.operator import Operator - - -class TaylorGreenInitializer(Operator): - """ - Initialize the Taylor-Green vortex. - """ - - @Operator.register_backend(xlb.ComputeBackend.JAX) - # @partial(jit, static_argnums=(0)) - def jax_implementation(self, vel, nr): - # Make meshgrid - x = jnp.linspace(0, 2 * jnp.pi, nr) - y = jnp.linspace(0, 2 * jnp.pi, nr) - z = jnp.linspace(0, 2 * jnp.pi, nr) - X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij") - - # Compute u - u = jnp.stack( - [ - vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), - -vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), - jnp.zeros_like(X), - ], - axis=0, - ) - - # Compute rho - rho = 3.0 * vel * vel * (1.0 / 16.0) * (jnp.cos(2.0 * X) + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0))) + 1.0 - rho = jnp.expand_dims(rho, axis=0) - - return rho, u - - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - nr: int, - ): - # Get the global index - i, j, k = wp.tid() - - # Get real pos - x = 2.0 * wp.pi * wp.float(i) / wp.float(nr) - y = 2.0 * wp.pi * wp.float(j) / wp.float(nr) - z = 2.0 * wp.pi * wp.float(k) / wp.float(nr) - - # Compute u - u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1, i, j, k] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) - u[2, i, j, k] = 0.0 - - # Compute rho - rho[0, i, j, k] = 3.0 * vel * vel * (1.0 / 16.0) * (wp.cos(2.0 * x) + (wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0))) + 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel, nr): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - nr, - ], - dim=rho.shape[1:], - ) - return rho, u - - -def run_taylor_green(backend, compute_mlup=True): - # Set the compute backend - if backend == "warp": - compute_backend = xlb.ComputeBackend.WARP - elif backend == "jax": - compute_backend = xlb.ComputeBackend.JAX - - # Set the precision policy - precision_policy = xlb.PrecisionPolicy.FP32FP32 - - # Set the velocity set - velocity_set = xlb.velocity_set.D3Q19() - - # Make grid - nr = 128 - shape = (nr, nr, nr) - if backend == "jax": - grid = xlb.grid.JaxGrid(shape=shape) - elif backend == "warp": - grid = xlb.grid.WarpGrid(shape=shape) - - # Make feilds - rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) - u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) - f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - bc_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) - - # Make operators - initializer = TaylorGreenInitializer(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - collision = xlb.operator.collision.BGK(omega=1.9, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend - ) - macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - stream = xlb.operator.stream.Stream(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream - ) - - # Parrallelize the stepper TODO: Add this functionality - # stepper = grid.parallelize_operator(stepper) - - # Set initial conditions - if backend == "warp": - rho, u = initializer(rho, u, 0.1, nr) - f0 = equilibrium(rho, u, f0) - elif backend == "jax": - rho, u = initializer(0.1, nr) - f0 = equilibrium(rho, u) - - # Time stepping - plot_freq = 32 - save_dir = "taylor_green" - os.makedirs(save_dir, exist_ok=True) - num_steps = 8192 - start = time.time() - - for _ in tqdm(range(num_steps)): - # Time step - if backend == "warp": - f1 = stepper(f0, f1, bc_mask, missing_mask, _) - f1, f0 = f0, f1 - elif backend == "jax": - f0 = stepper(f0, bc_mask, missing_mask, _) - - # Plot if needed - if (_ % plot_freq == 0) and (not compute_mlup): - if backend == "warp": - rho, u = macroscopic(f0, rho, u) - local_u = u.numpy() - elif backend == "jax": - rho, local_u = macroscopic(f0) - - plt.imshow(local_u[0, :, nr // 2, :]) - plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") - plt.close() - wp.synchronize() - end = time.time() - - # Print MLUPS - print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") - - -if __name__ == "__main__": - # Run Taylor-Green vortex on different backends - backends = ["warp", "jax"] - # backends = ["jax"] - for backend in backends: - run_taylor_green(backend, compute_mlup=True) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 55ce9ed..4b7ac90 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -27,8 +27,6 @@ class DoNothingBC(BoundaryCondition): boundary nodes. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 8c33d29..4dd4b9e 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -29,8 +29,6 @@ class EquilibriumBC(BoundaryCondition): Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, rho: float, diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 8b5f139..53645c6 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -38,8 +38,6 @@ class ExtrapolationOutflowBC(BoundaryCondition): doi:10.1016/j.camwa.2015.05.001. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 29f83c1..8569e84 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -27,8 +27,6 @@ class FullwayBounceBackBC(BoundaryCondition): Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 870635e..3d60879 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -39,8 +39,6 @@ class GradsApproximationBC(BoundaryCondition): """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, @@ -49,6 +47,7 @@ def __init__( indices=None, mesh_vertices=None, ): + # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. self.u = (0, 0, 0) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 6e787c2..e8df6b7 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -29,8 +29,6 @@ class HalfwayBounceBackBC(BoundaryCondition): TODO: Implement moving boundary conditions for this """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index bb4b5f0..065a0b0 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -61,8 +61,7 @@ def __init__( indices, mesh_vertices, ) - - # The operator to compute the momentum flux + # Overwrite the boundary condition registry id with the bc_type in the name self.momentum_flux = MomentumFlux() @partial(jit, static_argnums=(0,), inline=True) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 66b6377..4be2cf2 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -48,7 +48,6 @@ def __init__( # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." - self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__ + "_" + bc_type) self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() self.prescribed_value = prescribed_value diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index be920bf..6d72fc0 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -13,7 +13,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from xlb import DefaultConfig - +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry # Enum for implementation step class ImplementationStep(Enum): @@ -35,6 +35,7 @@ def __init__( indices=None, mesh_vertices=None, ): + self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + str(hash(self))) velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy compute_backend = compute_backend or DefaultConfig.default_backend diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 5b1e092..6238fc5 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -23,6 +23,7 @@ def register_boundary_condition(self, boundary_condition): self.next_id += 1 self.id_to_bc[_id] = boundary_condition self.bc_to_id[boundary_condition] = _id + print(f"registered bc {boundary_condition} with id {_id}") return _id diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 62790a6..99431eb 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -14,7 +14,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep -from xlb.operator.boundary_condition import DoNothingBC as DummyBC +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.collision import ForcedCollision @@ -91,92 +91,53 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): return f_0, f_1 def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update + # Set local constants _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) _opp_indices = self.velocity_set.opp_indices - @wp.struct - class BoundaryConditionIDStruct: - # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning - # One needs to manually add the names of additional BC's as they are added. - # TODO: Any way to improve this? - id_EquilibriumBC: wp.uint8 - id_DoNothingBC: wp.uint8 - id_HalfwayBounceBackBC: wp.uint8 - id_FullwayBounceBackBC: wp.uint8 - id_ZouHeBC_velocity: wp.uint8 - id_ZouHeBC_pressure: wp.uint8 - id_RegularizedBC_velocity: wp.uint8 - id_RegularizedBC_pressure: wp.uint8 - id_ExtrapolationOutflowBC: wp.uint8 - id_GradsApproximationBC: wp.uint8 + # Read the list of bc_to_id created upon instantiation + bc_to_id = boundary_condition_registry.bc_to_id + id_to_bc = boundary_condition_registry.id_to_bc - @wp.func - def apply_post_streaming_bc( - index: Any, - timestep: Any, - _boundary_id: Any, - bc_struct: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # Apply post-streaming type boundary conditions - # NOTE: 'f_pre' is included here as an input to the BC functionals for consistency with the BC API, - # particularly when compared to post-collision boundary conditions (see below). - - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ZouHeBC_velocity: - # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ZouHeBC_pressure: - # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_RegularizedBC_velocity: - # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_RegularizedBC_pressure: - # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_GradsApproximationBC: - # Reformulated Grads boundary condition - f_post = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post + # Gather IDs of ExtrapolationOutflowBC boundary conditions + extrapolation_outflow_bc_ids = [] + for bc_name, bc_id in bc_to_id.items(): + if bc_name.startswith("ExtrapolationOutflowBC"): + extrapolation_outflow_bc_ids.append(bc_id) + # Group active boundary conditions + active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) @wp.func - def apply_post_collision_bc( + def apply_bc( index: Any, timestep: Any, _boundary_id: Any, - bc_struct: Any, missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, f_post: Any, + is_post_streaming: bool, ): - # Apply post-collision type boundary conditions or special boundary preparations - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # Storing post-streaming data in directions that leave the domain - f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post + f_result = f_post + + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if is_post_streaming: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + else: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].prepare_bc_auxilary_data)( + index, timestep, missing_mask, f_0, f_1, f_pre, f_post + ) + return f_result @wp.func def get_thread_data_2d( @@ -186,17 +147,17 @@ def get_thread_data_2d( index: Any, ): # Read thread data for populations and missing mask - f0_thread = _f_vec() - f1_thread = _f_vec() + _f0_thread = _f_vec() + _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): - f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) - f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) + _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) + _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f0_thread, f1_thread, _missing_mask + return _f0_thread, _f1_thread, _missing_mask @wp.func def get_thread_data_3d( @@ -206,19 +167,19 @@ def get_thread_data_3d( index: Any, ): # Read thread data for populations - f0_thread = _f_vec() - f1_thread = _f_vec() + _f0_thread = _f_vec() + _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) - f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) + _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) + _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f0_thread, f1_thread, _missing_mask + return _f0_thread, _f1_thread, _missing_mask @wp.kernel def kernel2d( @@ -226,27 +187,23 @@ def kernel2d( f_1: wp.array3d(dtype=Any), bc_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: Any, timestep: int, ): - # Get the global index i, j = wp.tid() - index = wp.vec2i(i, j) # TODO warp should fix this + index = wp.vec2i(i, j) - # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1]] if _boundary_id == wp.uint8(255): return - # Apply streaming (pull method) + # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming type boundary conditions - f0_thread, f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) - _f_post_collision = f0_thread - _f_post_stream = apply_post_streaming_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream - ) + _f0_thread, _f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) + _f_post_collision = _f0_thread + + # Apply post-streaming boundary conditions + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -257,119 +214,63 @@ def kernel2d( # Apply collision _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision - ) + # Apply post-collision boundary conditions + _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) - # Set the output + # Store the result in f_1 for l in range(self.velocity_set.q): f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) - # Construct the kernel @wp.kernel def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), - bc_struct: Any, timestep: int, ): - # Get the global index i, j, k = wp.tid() - index = wp.vec3i(i, j, k) # TODO warp should fix this + index = wp.vec3i(i, j, k) - # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1], index[2]] if _boundary_id == wp.uint8(255): return - # Apply streaming (pull method) + # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming type boundary conditions - f0_thread, f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) - _f_post_collision = f0_thread - _f_post_stream = apply_post_streaming_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream - ) + _f0_thread, _f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) + _f_post_collision = _f0_thread - # Compute rho and u - _rho, _u = self.macroscopic.warp_functional(_f_post_stream) + # Apply post-streaming boundary conditions + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) - # Compute equilibrium + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) - - # Apply collision _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision - ) + # Apply post-collision boundary conditions + _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) - # Set the output + # Store the result in f_1 for l in range(self.velocity_set.q): - # TODO 1: fix the perf drop due to l324-l236 even in cases where this BC is not used. - # TODO 2: is there better way to move these lines to a function inside BC class like "restore_bc_data" - # if _boundary_id == bc_struct.id_GradsApproximationBC: - # if _missing_mask[l] == wp.uint8(1): - # f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) + # TODO: Improve this later + if wp.static("GradsApproximationBC" in active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): + if _missing_mask[l] == wp.uint8(1): + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return BoundaryConditionIDStruct, kernel + return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): - # Get the boundary condition ids - from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - # Read the list of bc_to_id created upon instantiation - bc_to_id = boundary_condition_registry.bc_to_id - id_to_bc = boundary_condition_registry.id_to_bc - bc_struct = self.warp_functional() - active_bc_list = [] - for bc in self.boundary_conditions: - # Setting the Struct attributes and active BC classes based on the BC class names - bc_name = id_to_bc[bc.id] - setattr(self, bc_name, bc) - setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) - active_bc_list.append("id_" + bc_name) - - # Check if boundary_conditions is an empty list (e.g. all periodic and no BC) - # TODO: There is a huge issue here with perf. when boundary_conditions list - # is empty and is initialized with a dummy BC. If it is not empty, no perf - # loss ocurrs. The following code at least prevents syntax error for periodic examples. - if self.boundary_conditions: - bc_dummy = self.boundary_conditions[0] - else: - bc_dummy = DummyBC() - - # Setting the Struct attributes for inactive BC classes - for var in vars(bc_struct): - if var not in active_bc_list and not var.startswith("_"): - # set unassigned boundaries to the maximum integer in uint8 - setattr(bc_struct, var, 255) - - # Assing a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not - # produce error when a particular BC is not used in an example. - setattr(self, var.replace("id_", ""), bc_dummy) - - # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[ - f_0, - f_1, - bc_mask, - missing_mask, - bc_struct, - timestep, - ], + inputs=[f_0, f_1, bc_mask, missing_mask, timestep], dim=f_0.shape[1:], ) return f_0, f_1