Skip to content

Commit

Permalink
Merge pull request #57 from hsalehipour/major-refactoring
Browse files Browse the repository at this point in the history
Missing mask in JAX and Warp
  • Loading branch information
mehdiataei authored Aug 8, 2024
2 parents 329bd4c + 27e5d66 commit 0539f61
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 199 deletions.
6 changes: 3 additions & 3 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from xlb.helper import create_nse_fields, initialize_eq
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC
from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import warp as wp
Expand Down Expand Up @@ -48,7 +48,7 @@ def define_boundary_indices(self):
def setup_boundary_conditions(self):
lid, walls = self.define_boundary_indices()
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_top, bc_walls]

def setup_boundary_masks(self):
Expand Down Expand Up @@ -99,7 +99,7 @@ def post_process(self, i):
# Running the simulation
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32
omega = 1.6
Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def run(self, num_steps, print_interval, post_process_interval=100):
elapsed_time = time.time() - start_time
print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s")

if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)
if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)

def post_process(self, i):
# Write the results. We'll use JAX backend for the post-processing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape):
cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0
) # Arbitrary value so that we can check if the values are changed outside the boundary

f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask, f)
f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask)

f = f.numpy()
f_post = f_post.numpy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xlb.compute_backend import ComputeBackend
from xlb.grid import grid_factory
from xlb import DefaultConfig
from xlb.operator.boundary_masker import IndicesBoundaryMasker


def init_xlb_env(velocity_set):
Expand Down Expand Up @@ -35,7 +36,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):

boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8)

indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker()
indices_boundary_masker = IndicesBoundaryMasker()

# Make indices for boundary conditions (sphere)
sphere_radius = grid_shape[0] // 4
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0
) # Arbitrary value so that we can check if the values are changed outside the boundary

f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask, f_pre)
f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask)

f = f_pre.numpy()
f_post = f_post.numpy()
Expand Down
59 changes: 25 additions & 34 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,43 +58,34 @@ def _construct_warp(self):
# Construct the funcional to get streamed indices

@wp.func
def functional2d(
f: wp.array3d(dtype=Any),
def functional(
f_pre: Any,
f_post: Any,
missing_mask: Any,
index: Any,
):
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
return _f

@wp.func
def functional3d(
f: wp.array4d(dtype=Any),
missing_mask: Any,
index: Any,
):
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
return _f
return f_pre

@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.uint8),
f: wp.array3d(dtype=Any),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -103,15 +94,13 @@ def kernel2d(

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
_f = functional3d(f_pre, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = _f[l]

# Construct the warp kernel
@wp.kernel
Expand All @@ -120,16 +109,21 @@ def kernel3d(
f_post: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
f: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -138,27 +132,24 @@ def kernel3d(

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
_f = functional3d(f_pre, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1], index[2]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = _f[l]

functional = functional3d if self.velocity_set.d == 3 else functional2d
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f):
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, boundary_mask, missing_mask, f],
inputs=[f_pre, f_post, boundary_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f
return f_post
52 changes: 24 additions & 28 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def _construct_warp(self):

# Construct the funcional to get streamed indices
@wp.func
def functional2d(
f: wp.array3d(dtype=Any),
def functional(
f_pre: Any,
f_post: Any,
missing_mask: Any,
index: Any,
):
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f
Expand All @@ -93,16 +93,21 @@ def kernel2d(
f_post: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
f: wp.array3d(dtype=Any),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -111,24 +116,13 @@ def kernel2d(

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
_f = functional2d(f_post, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = _f[l]

@wp.func
def functional3d(
f: wp.array4d(dtype=Any),
missing_mask: Any,
index: Any,
):
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f
f_post[l, index[0], index[1]] = _f[l]

# Construct the warp kernel
@wp.kernel
Expand All @@ -137,16 +131,21 @@ def kernel3d(
f_post: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
f: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -155,27 +154,24 @@ def kernel3d(

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
_f = functional3d(f_post, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1], index[2]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = _f[l]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
functional = functional3d if self.velocity_set.d == 3 else functional2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f):
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, boundary_mask, missing_mask, f],
inputs=[f_pre, f_post, boundary_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f
return f_post
14 changes: 7 additions & 7 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def kernel2d(
f_post: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
f: wp.array3d(dtype=Any),
): # Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)
Expand All @@ -88,6 +87,7 @@ def kernel2d(
_f_post = _f_vec()
_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

Expand All @@ -105,7 +105,7 @@ def kernel2d(

# Write the result to the output
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = _f[l]

# Construct the warp kernel
@wp.kernel
Expand All @@ -114,7 +114,6 @@ def kernel3d(
f_post: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
f: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -128,6 +127,7 @@ def kernel3d(
_f_post = _f_vec()
_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

Expand All @@ -145,18 +145,18 @@ def kernel3d(

# Write the result to the output
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = _f[l]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f):
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, boundary_mask, missing_mask, f],
inputs=[f_pre, f_post, boundary_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f
return f_post
Loading

0 comments on commit 0539f61

Please sign in to comment.