
Added abstraction layer for boundary condition and the capability to add profiles to boundary conditions #86

Open
mehdiataei wants to merge 1 commit into main from profile_inlet


mehdiataei (Contributor) commented Oct 31, 2024

Contributing Guidelines

Description

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

How Has This Been Tested?

  • All pytest tests pass

Linting and Code Formatting

Make sure the code follows the project's linting and formatting standards. This project uses Ruff for linting.

To run Ruff, execute the following command from the root of the repository:

ruff check .
  • Ruff passes

mehdiataei marked this pull request as ready for review October 31, 2024 16:56
mehdiataei changed the title from "Profile inlet" to "Added abstraction layer for boundary condition and the capability to add profiles to boundary conditions" Oct 31, 2024
mehdiataei force-pushed the profile_inlet branch 5 times, most recently from fb3fdb1 to f0f81c3, November 4, 2024 01:39

hsalehipour (Collaborator) left a comment

Thanks for the PR. This is very helpful. I have made some comments which I think will improve this work.

@@ -69,7 +72,7 @@ def define_boundary_indices(self):

     def setup_boundary_conditions(self):
         inlet, outlet, walls, sphere = self.define_boundary_indices()
-        bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)

Collaborator

Can we have a 1-to-1 association between a bc_profile function and a BC object? Right now it is not straightforward, for example, to have two BCs with different profiles. If we pass the functional to the BC constructor, we would need another operator (like the boundary masker) that goes over all BCs and initializes them, similar to what initialize_bc_aux_data does right now. This could be done as the first step of the stepper (when the iteration is zero). This way we can also remove initialize_bc_aux_data from the examples.
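A minimal sketch of the 1-to-1 association and the initializer described above, written in plain Python rather than JAX/Warp so it runs anywhere. `SketchBC`, `InitializeBCAuxData`, and the dict-based aux data are hypothetical stand-ins, not XLB's API; the point is only that each BC owns its own profile, so two BCs can carry different profiles, and one operator visits every BC once.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Index = Tuple[int, int]
Profile = Callable[[int, int], float]


@dataclass
class SketchBC:
    """Hypothetical BC: owns its boundary indices and, optionally, its own profile."""
    name: str
    indices: List[Index]
    profile: Optional[Profile] = None
    aux_data: Dict[Index, float] = field(default_factory=dict)


class InitializeBCAuxData:
    """Hypothetical standalone operator (in the spirit of the boundary masker):
    visits every BC once and fills its aux data from that BC's own profile."""

    def __call__(self, bcs: Sequence[SketchBC]) -> None:
        for bc in bcs:
            if bc.profile is None:
                continue  # no profile attached: nothing to precompute
            for i, j in bc.indices:
                bc.aux_data[(i, j)] = bc.profile(i, j)


# Two BCs with different profiles, plus one without a profile.
ny = 5
inlet = SketchBC(
    "inlet",
    [(0, j) for j in range(ny)],
    # parabolic profile: 0 at the walls, 0.04 at the centerline
    profile=lambda i, j: 0.04 * 4 * j * (ny - 1 - j) / (ny - 1) ** 2,
)
lid = SketchBC("lid", [(i, ny - 1) for i in range(3)], profile=lambda i, j: 0.02)
walls = SketchBC("walls", [(i, 0) for i in range(3)])

InitializeBCAuxData()([inlet, lid, walls])
print(inlet.aux_data[(0, 2)], lid.aux_data[(1, 4)], walls.aux_data)
```

Because the profile travels with the BC object, a second inlet with a different profile is just another BC instance; the initializer never needs to know which profile belongs to which boundary.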

Also, with these changes it is somewhat difficult to assign constant values to a BC. Is there a simpler way to pass constant values to BCs without having to define a bc_profile with JAX and Warp implementations?

     def run(self, num_steps, post_process_interval=100):
         start_time = time.time()
         self.initialize_bc_aux_data()

Collaborator

I think with some minor changes you could make this a standalone operator (like boundary_masker).

We can either call that operator inside the stepper (if iteration == 0) or leave it to the user to call it in the example.
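A sketch of the first option, where the stepper triggers the initializer on its first call. `SketchStepper` and `init_bc_aux_data` are hypothetical names, and the collision/streaming body is elided; only the iteration-0 guard is shown.

```python
from typing import Callable, List, Sequence


class SketchStepper:
    """Hypothetical stepper that runs a one-off BC aux-data initialization on iteration 0."""

    def __init__(self, bcs: Sequence[object], init_bc_aux_data: Callable[[Sequence[object]], None]):
        self.bcs = bcs
        self.init_bc_aux_data = init_bc_aux_data

    def __call__(self, f: List[float], iteration: int) -> List[float]:
        if iteration == 0:
            # First step only: let the standalone operator fill every BC's aux data.
            self.init_bc_aux_data(self.bcs)
        # ... collision, streaming, and BC application would follow here ...
        return f


calls = []
stepper = SketchStepper(bcs=["inlet", "outlet"], init_bc_aux_data=lambda bcs: calls.append(list(bcs)))
f = [0.0] * 9
for it in range(3):
    f = stepper(f, it)
print(len(calls))  # 1
```

One caveat for this option: if the step is compiled with `jax.jit` and `iteration` is a traced value, a Python `if` cannot branch on it, so the check would have to live outside the compiled function. The second option (the user calls the operator once before the loop) avoids that issue entirely.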

@@ -95,7 +95,6 @@ def define_boundary_indices(self):
     def setup_boundary_conditions(self):
         inlet, outlet, walls, car = self.define_boundary_indices()
         bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)
         # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)

Collaborator

It would be really nice if we could define BCs as before (simply with constant inputs) as well as with a function passed to the BC constructor, e.g. u = (0.04, 0, 0) or u = bc_profile(). I think this could be done in the operator I am proposing for initializing BC aux data.
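A minimal sketch of a constructor that accepts either form, assuming the constant case is normalized into a trivial profile at construction time so the rest of the BC only ever sees a callable. `as_profile` and `SketchVelocityBC` are hypothetical; in XLB a profile would additionally need JAX and Warp versions, which this plain-Python sketch does not address.

```python
from typing import Callable, Iterable, Sequence, Tuple, Union

Vec = Tuple[float, ...]
Profile = Callable[[int, int], Vec]
VelocitySpec = Union[Sequence[float], Profile]


def as_profile(u: VelocitySpec) -> Profile:
    """Normalize a BC input: a profile is returned unchanged, and a constant vector
    becomes a profile that returns the same vector at every node."""
    if callable(u):
        return u
    const = tuple(float(c) for c in u)
    return lambda i, j: const


class SketchVelocityBC:
    """Hypothetical BC whose constructor accepts u=(0.04, 0.0) or u=some_profile."""

    def __init__(self, u: VelocitySpec, indices: Iterable[Tuple[int, int]]):
        self.profile = as_profile(u)
        self.indices = list(indices)

    def prescribed_values(self):
        # What a one-off aux-data initializer would evaluate up front.
        return {(i, j): self.profile(i, j) for i, j in self.indices}


idx = [(0, j) for j in range(3)]
bc_const = SketchVelocityBC(u=(0.04, 0.0), indices=idx)                    # constant, as before
bc_shear = SketchVelocityBC(u=lambda i, j: (0.01 * j, 0.0), indices=idx)  # spatial profile
print(bc_const.prescribed_values()[(0, 1)], bc_shear.prescribed_values()[(0, 2)])
```

Normalizing at the constructor keeps a single code path downstream: the initializer evaluates `profile(i, j)` regardless of how the user specified the BC.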

@@ -222,6 +216,15 @@ def functional_velocity(
# Find normal vector
normals = get_normal_vectors(missing_mask)

# Find the value of u from the missing directions

Collaborator

Isn't there a performance hit here? This BC is not dynamic like the extrapolation outflow BC, whose information changes in time. Here the same velocity or pressure value is recomputed at every iteration.

Contributor Author

We're not recomputing here; we're simply reading the stored values. I didn't notice any performance difference before and after the PR.

One improvement would be to store the missing directions and avoid the for loop, for a slight performance gain. I'll try to do that.

Contributor Author

(This only works for linear/planar BCs such as Zou-He.)
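The idea of storing the missing directions can be sketched for a planar boundary: all nodes on the plane share one normal, so the split of lattice directions into missing, tangential, and outgoing sets can be computed once at construction, and the per-node work no longer loops over every direction with a mask test. This is a minimal plain-Python sketch, assuming D2Q9 with the ordering shown and an axis-aligned inward normal; `classify_directions` and `PlanarZouHeDensity` are hypothetical names, not XLB's Zou-He implementation. The density line is the standard Zou-He relation: when every c·n is -1, 0, or 1, combining mass with normal momentum gives rho (1 - u·n) = sum of tangential f + 2 × sum of outgoing f.

```python
from typing import List, Sequence, Tuple

# D2Q9 lattice velocities in a common ordering (assumed here for illustration).
C: List[Tuple[int, int]] = [
    (0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1),
]


def classify_directions(inward_normal: Tuple[int, int]):
    """Split directions by the sign of c_q . n for a planar boundary with inward normal n."""
    nx, ny = inward_normal
    dots = [cx * nx + cy * ny for cx, cy in C]
    missing = tuple(q for q, d in enumerate(dots) if d > 0)     # streamed in from outside: unknown
    tangential = tuple(q for q, d in enumerate(dots) if d == 0)
    outgoing = tuple(q for q, d in enumerate(dots) if d < 0)    # leaving through the boundary: known
    return missing, tangential, outgoing


class PlanarZouHeDensity:
    """Hypothetical helper: direction sets are computed once per planar BC,
    since every node on the plane shares the same normal."""

    def __init__(self, inward_normal: Tuple[int, int]):
        self.normal = inward_normal
        self.missing, self.tangential, self.outgoing = classify_directions(inward_normal)

    def rho(self, f: Sequence[float], u: Tuple[float, float]) -> float:
        """Zou-He density at a velocity boundary (axis-aligned normal):
        rho = (sum_tangential f + 2 * sum_outgoing f) / (1 - u . n)."""
        u_n = u[0] * self.normal[0] + u[1] * self.normal[1]
        known = sum(f[q] for q in self.tangential) + 2.0 * sum(f[q] for q in self.outgoing)
        return known / (1.0 - u_n)


left_inlet = PlanarZouHeDensity((1, 0))
print(left_inlet.missing)                     # (1, 5, 8)
print(left_inlet.rho([1.0] * 9, (0.0, 0.0)))  # 9.0
```

As noted in the comment above, this relies on the boundary being planar; on a curved boundary the normal, and therefore the direction split, varies from node to node.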

@@ -158,32 +181,29 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask):

     @Operator.register_backend(ComputeBackend.JAX)
     @partial(jit, static_argnums=(0))
-    def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+    def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):

Collaborator

Please make sure this naming is consistent across the JAX implementations of all BCs.

Comment on lines +156 to +163
+        # Get the density and velocity from the f_1
+        if self.bc_type == "velocity":
+            vel = self.prescribed_values
+        elif self.bc_type == "pressure":
+            rho = self.prescribed_values
+            vel = self.calculate_vel(f_1, rho, missing_mask)
+        else:
+            raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.")

Collaborator

This is redundant.

@@ -95,7 +95,6 @@ def define_boundary_indices(self):
     def setup_boundary_conditions(self):
         inlet, outlet, walls, car = self.define_boundary_indices()
         bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)

Collaborator

What if the velocity here is also not constant and has a profile? All these boundary conditions should be set up the same way. It would be nice if we could pass the profile to the BC constructor.

@@ -69,7 +72,7 @@ def define_boundary_indices(self):

     def setup_boundary_conditions(self):
         inlet, outlet, walls, sphere = self.define_boundary_indices()
-        bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
+        bc_left = RegularizedBC("velocity", indices=inlet)
         # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)

Collaborator

EquilibriumBC should also be able to accept non-constant values. Can you add that as well?
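One way an equilibrium BC could accept non-constant values is to evaluate rho and u per boundary node and build the equilibrium distribution at each node. This is a minimal sketch, assuming D2Q9 and the standard second-order equilibrium f_q = w_q rho (1 + 3 c·u + 4.5 (c·u)^2 - 1.5 u·u); `equilibrium_bc_values` and the profile arguments are hypothetical, not XLB's EquilibriumBC signature.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# D2Q9 velocities and weights (standard values; ordering assumed for illustration).
C: List[Tuple[int, int]] = [
    (0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1),
]
W: List[float] = [4 / 9] + [1 / 9] * 4 + [1 / 36] * 4


def equilibrium(rho: float, u: Tuple[float, float]) -> List[float]:
    """Second-order D2Q9 equilibrium: f_q = w_q rho (1 + 3 c.u + 4.5 (c.u)^2 - 1.5 u.u)."""
    ux, uy = u
    usq = ux * ux + uy * uy
    feq = []
    for (cx, cy), w in zip(C, W):
        cu = cx * ux + cy * uy
        feq.append(w * rho * (1.0 + 3.0 * cu + 4.5 * cu * cu - 1.5 * usq))
    return feq


def equilibrium_bc_values(
    indices: Sequence[Tuple[int, int]],
    rho_profile: Callable[[int, int], float],
    u_profile: Callable[[int, int], Tuple[float, float]],
) -> Dict[Tuple[int, int], List[float]]:
    """Hypothetical helper: evaluate rho and u at each boundary node, then f_eq per node."""
    return {(i, j): equilibrium(rho_profile(i, j), u_profile(i, j)) for i, j in indices}


# Lid whose tangential velocity ramps linearly along x instead of being constant.
lid = [(i, 4) for i in range(5)]
values = equilibrium_bc_values(lid, rho_profile=lambda i, j: 1.0, u_profile=lambda i, j: (0.005 * i, 0.0))
print(sum(values[(2, 4)]))
```

A quick sanity check on any such implementation is that each node's populations still sum to rho and their first moment equals rho·u, which the second-order equilibrium guarantees for D2Q9.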

@@ -49,7 +50,7 @@ def define_boundary_indices(self):

     def setup_boundary_conditions(self):
         lid, walls = self.define_boundary_indices()
-        bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
+        bc_top = EquilibriumBC(rho=1.0, u=(self.prescribed_vel, 0.0), indices=lid)

Collaborator

Same comment for EquilibriumBC in all examples.

         # find and store the normal vector using indices
         self._get_normal_vec(indices)

         # Unpack the two warp functionals needed for this BC!
         if self.compute_backend == ComputeBackend.WARP:
-            self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional
+            self.warp_functional, self.update_bc_auxilary_data = self.warp_functional

Collaborator

I like this renaming.

mehdiataei (Contributor Author) commented:

> Thanks for the PR. This is very helpful. I have made some comments which I think will improve this work.

Thanks for the comments. I'll take a look.

Development

Successfully merging this pull request may close these issues.

Non-uniform boundary velocity profile
2 participants