Skip to content

Commit

Permalink
Regularized (Warp) also completed and verified!
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Aug 16, 2024
1 parent 3d54ffa commit b15bb05
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
Expand Down Expand Up @@ -127,7 +127,7 @@ def post_process(self, i):
# Running the simulation
grid_shape = (512 // 2, 128 // 2, 128 // 2)
velocity_set = xlb.velocity_set.D3Q19()
backend = ComputeBackend.JAX
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32
omega = 1.6

Expand Down
51 changes: 51 additions & 0 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ def __init__(
# The operator to compute the momentum flux
self.momentum_flux = SecondMoment()

# helper function
def compute_qi(self):
# Qi = cc - cs^2*I
dim = self.velocity_set.d
Qi = self.velocity_set.cc
if dim == 3:
diagonal = (0, 3, 5)
offdiagonal = (1, 2, 4)
elif dim == 2:
diagonal = (0, 2)
offdiagonal = (1,)
else:
raise ValueError(f"dim = {dim} not supported")

# multiply off-diagonal elements by 2 because the Q tensor is symmetric
Qi[:, diagonal] += -1.0 / 3.0
Qi[:, offdiagonal] *= 2.0
return Qi

@partial(jit, static_argnums=(0,), inline=True)
def regularize_fpop(self, fpop, feq):
"""
Expand All @@ -82,6 +101,8 @@ def regularize_fpop(self, fpop, feq):
# Qi = cc - cs^2*I
dim = self.velocity_set.d
weights = self.velocity_set.w[(slice(None),) + (None,) * dim]
# TODO: if I use the following I get NaN ! figure out why!
# Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype)
Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype)
if dim == 3:
diagonal = (0, 3, 5)
Expand Down Expand Up @@ -142,10 +163,14 @@ def _construct_warp(self):

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi()))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = wp.float32(rho)
_u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1])
_opp_indices = self.velocity_set.wp_opp_indices
_w = self.velocity_set.wp_w
_c = self.velocity_set.wp_c
_c32 = self.velocity_set.wp_c32
# TODO: this is way less than ideal. we should not be making new types
Expand Down Expand Up @@ -192,6 +217,32 @@ def bounceback_nonequilibrium(
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = self.momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
QiPi1 = _f_vec()
for l in range(_q):
QiPi1[l] = 0.0
for t in range(nt):
QiPi1[l] += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = 9.0 / 2.0 * _w[l] * QiPi1[l]
fpop[l] = feq[l] + fpop1
return fpop

@wp.func
def functional3d_velocity(
f_pre: Any,
Expand Down

0 comments on commit b15bb05

Please sign in to comment.