Skip to content

Commit

Permalink
Merge pull request #75 from mehdiataei/major-refactoring
Browse files Browse the repository at this point in the history
Marged 2D and 3D kernels in Warp
  • Loading branch information
hsalehipour authored Oct 21, 2024
2 parents cee77b9 + ed5f643 commit 55e2921
Show file tree
Hide file tree
Showing 30 changed files with 178 additions and 1,180 deletions.
3 changes: 2 additions & 1 deletion examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def run(self, num_steps, post_process_interval=100):
def post_process(self, i):
# Write the results. We'll use JAX backend for the post-processing
if not isinstance(self.f_0, jnp.ndarray):
f_0 = wp.to_jax(self.f_0)
# If the backend is warp, we need to drop the last dimension added by warp for 2D simulations
f_0 = wp.to_jax(self.f_0)[..., 0]
else:
f_0 = self.f_0

Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def define_boundary_indices(self):
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

# Load the mesh
stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl"
# Load the mesh (replace with your own mesh)
stl_filename = "../stl-files/DrivAer-Notchback.stl"
mesh = trimesh.load_mesh(stl_filename, process=False)
mesh_vertices = mesh.vertices

Expand Down
25 changes: 10 additions & 15 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
jax==0.4.20
jaxlib==0.4.20
matplotlib==3.8.0
numpy==1.26.1
pyvista==0.43.4
Rtree==1.0.1
trimesh==4.4.1
orbax-checkpoint==0.4.1
termcolor==2.3.0
PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git
tqdm==4.66.2
warp-lang==1.0.2
numpy-stl==3.1.1
pydantic==2.7.0
ruff==0.5.6
jax[cuda]
matplotlib
numpy
pyvista
Rtree
trimesh
warp-lang
numpy-stl
pydantic
ruff
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape):
f = f.numpy()
f_post = f_post.numpy()

assert f.shape == (velocity_set.q,) + grid_shape
assert f.shape == (velocity_set.q,) + grid_shape if dim == 3 else (velocity_set.q, grid_shape[0], grid_shape[1], 1)

# Assert that the values are correct in the indices of the sphere
weights = velocity_set.w
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
bc_mask, missing_mask = indices_boundary_masker([fullway_bc], bc_mask, missing_mask, start_index=None)

# Generate a random field with the same shape
random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32)
if dim == 2:
random_field = np.random.rand(velocity_set.q, grid_shape[0], grid_shape[1], 1).astype(np.float32)
else:
random_field = np.random.rand(velocity_set.q, grid_shape[0], grid_shape[1], grid_shape[2]).astype(np.float32)
# Add the random field to f_pre
f_pre = wp.array(random_field)

Expand All @@ -71,7 +74,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
f = f_pre.numpy()
f_post = f_post.numpy()

assert f.shape == (velocity_set.q,) + grid_shape
assert f.shape == (velocity_set.q,) + grid_shape if dim == 3 else (velocity_set.q, grid_shape[0], grid_shape[1], 1)

for i in range(velocity_set.q):
np.allclose(
Expand Down
7 changes: 4 additions & 3 deletions tests/grids/test_grid_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def test_warp_grid_create_field(grid_size):
init_xlb_env(xlb.velocity_set.D3Q19)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9, dtype=Precision.FP32)

assert f.shape == (9,) + grid_shape, "Field shape is incorrect"
if len(grid_shape) == 2:
assert f.shape == (9,) + grid_shape + (1,), "Field shape is incorrect got {}".format(f.shape)
else:
assert f.shape == (9,) + grid_shape, "Field shape is incorrect got {}".format(f.shape)
assert isinstance(f, wp.array), "Field should be a Warp ndarray"


Expand All @@ -37,7 +39,6 @@ def test_warp_grid_create_field_fill_value():
assert isinstance(f, wp.array), "Field should be a Warp ndarray"

f = f.numpy()
assert f.shape == (9,) + grid_shape, "Field shape is incorrect"
assert np.allclose(f, fill_value), "Field not properly initialized with fill_value"


Expand Down
7 changes: 5 additions & 2 deletions tests/kernels/stream/test_stream_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape):
expected = jnp.stack(expected, axis=0)

if dim == 2:
f_initial_warp = wp.array(f_initial)
f_initial_warp = wp.array(f_initial[..., np.newaxis])

elif dim == 3:
f_initial_warp = wp.array(f_initial)
Expand All @@ -71,7 +71,10 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape):
f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q)
f_streamed = stream_op(f_initial_warp, f_streamed)

assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected"
if len(grid_shape) == 2:
assert jnp.allclose(f_streamed.numpy()[..., 0], np.array(expected)), "Streaming did not occur as expected"
else:
assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected"


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion xlb/grid/warp_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def create_field(
fill_value=None,
):
dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype
shape = (cardinality,) + (self.shape)

# Check if shape is 2D, and if so, append a singleton dimension to the shape
shape = (cardinality,) + (self.shape if len(self.shape) != 2 else self.shape + (1,))

if fill_value is None:
f = wp.zeros(shape, dtype=dtype)
Expand Down
55 changes: 2 additions & 53 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,58 +64,7 @@ def functional(
):
return f_pre

@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
bc_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.uint8),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
def kernel3d(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
kernel = self._construct_kernel(functional)

return functional, kernel

Expand All @@ -127,4 +76,4 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f_post
return f_post
57 changes: 3 additions & 54 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,59 +88,8 @@ def functional(
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f

# Construct the warp kernel
@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
bc_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
def kernel3d(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
# Use the parent class's kernel and pass the functional
kernel = self._construct_kernel(functional)

return functional, kernel

Expand All @@ -152,4 +101,4 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f_post
return f_post
Loading

0 comments on commit 55e2921

Please sign in to comment.