Neighbors for >128K particles #101

Open
RaulPPelaez opened this issue May 5, 2023 · 1 comment

Comments

@RaulPPelaez
Contributor

The index here overflows with 128K particles (a thread is launched per possible pair, so N*(N-1)/2 threads in total):

const int32_t index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= num_all_pairs) return;
int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2);

I know it sounds like an absurd regime for such a brute-force approach, but it turns out that this kernel remains a competitive option in some situations, depending for instance on the average number of neighbors per particle.

I believe there should be at least a TORCH_CHECK guarding against this.
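
A minimal sketch of such a guard on the host side before the launch (the name num_all_pairs is taken from the snippet above; everything else is an assumption):

// Hypothetical guard, not the repository's code; assumes num_all_pairs is
// computed in 64 bits on the host (needs <limits>).
TORCH_CHECK(num_all_pairs <= std::numeric_limits<int32_t>::max(),
            "Too many particle pairs for the brute force neighbor kernel");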

Switching to int64_t virtually eliminates this problem, but there is a big performance penalty, and the maximum number of blocks a CUDA kernel launch allows becomes a problem soon after.

On the other hand, the way the row index is computed produces incorrect results due to floating-point error at ~100k particles, even when switching to int64 and double.
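
One way around the rounding problem, assuming the pair enumeration implied by the formula quoted above (row r owns pair indices in [r*(r-1)/2, r*(r+1)/2)), is to use the sqrt only as an initial guess and then correct it with exact integer arithmetic. A sketch:

// Sketch, not the repository's code.
__device__ int64_t row_from_pair_index(int64_t index) {
    // Floating point gives a guess that can be off by a unit or two at large N...
    int64_t row = (int64_t)floor((sqrt(8.0 * (double)index + 1.0) + 1.0) / 2.0);
    // ...and exact integer arithmetic pins it down.
    while (row * (row - 1) / 2 > index) row--;
    while (row * (row + 1) / 2 <= index) row++;
    return row;
}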

I suspect this kernel will be defeated by this other approach at some point:
https://developer.nvidia.com/gpugems/gpugems3/part-v-physics-simulation/chapter-31-fast-n-body-simulation-cuda

That approach does not take too much time to implement and does not suffer from the aforementioned issues, since it launches only O(N) threads. Maybe it would be worth having both options and deciding on the algorithm at runtime based on some heuristic?
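
For reference, a rough sketch of that tiled kernel, assuming float3 positions, a squared cutoff, and a hypothetical store_pair() helper standing in for whatever writes the output pair list:

// Sketch of the GPU Gems style tiled traversal: each thread owns one particle
// and streams the others through shared memory, so no per-pair index is needed.
// blockDim.x is assumed equal to BLOCK_SIZE.
template <int BLOCK_SIZE>
__global__ void tiledNeighbors(const float3* __restrict__ positions,
                               int num_particles, float cutoff2) {
    __shared__ float3 tile[BLOCK_SIZE];
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const float3 pi = (i < num_particles) ? positions[i] : make_float3(0.f, 0.f, 0.f);
    for (int tileStart = 0; tileStart < num_particles; tileStart += BLOCK_SIZE) {
        // Cooperatively load the next tile of positions into shared memory.
        const int j = tileStart + threadIdx.x;
        if (j < num_particles) tile[threadIdx.x] = positions[j];
        __syncthreads();
        if (i < num_particles) {
            for (int k = 0; k < BLOCK_SIZE && tileStart + k < num_particles; k++) {
                const int jj = tileStart + k;
                if (jj <= i) continue;  // visit each unordered pair exactly once
                const float dx = pi.x - tile[k].x;
                const float dy = pi.y - tile[k].y;
                const float dz = pi.z - tile[k].z;
                const float r2 = dx * dx + dy * dy + dz * dz;
                if (r2 < cutoff2) store_pair(i, jj, r2);  // hypothetical helper
            }
        }
        __syncthreads();
    }
}

Each thread does O(N) work, but the global memory traffic is amortized across the block through the shared-memory tile, which is what makes this approach fast in the linked chapter.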

Additionally, this line:

const Tensor indices = arange(0, num_pairs, positions.options().dtype(kInt32));

tries to allocate 8 GB of memory when asked for 128k particles.
This can be fixed by using something like fancy iterators; however, I wonder if there is a reason the CPU implementation is written using torch operations exclusively.
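
As one example of what such fancy iterators could look like on a CUDA build (a sketch only; ProcessPair is a hypothetical functor standing in for the actual per-pair work):

#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

// Hypothetical functor, not the repository's code.
struct ProcessPair {
    __device__ void operator()(int64_t pair_index) const {
        // decode (row, column) from pair_index and do the distance test here
    }
};

// Iterates over every pair index without ever materializing an index array.
void forAllPairs(int64_t num_all_pairs) {
    thrust::for_each(thrust::device,
                     thrust::make_counting_iterator<int64_t>(0),
                     thrust::make_counting_iterator<int64_t>(num_all_pairs),
                     ProcessPair{});
}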

@peastman
Member

peastman commented May 5, 2023

This is one of many reasons you really don't want to use a thread per interaction. In torchmd/torchmd-net#61 (comment) I recommended using a smaller number of threads and having each thread loop over multiple interactions, but that didn't get implemented.
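
For reference, that suggestion amounts to a standard grid-stride loop; a minimal sketch (names and parameters are assumptions):

// Launch a bounded grid and let each thread walk over many pair indices,
// so neither the pair count nor the grid size has to fit in 32 bits.
__global__ void forEachPairStrided(int64_t num_all_pairs /*, ... */) {
    const int64_t stride = (int64_t)blockDim.x * gridDim.x;
    for (int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
         index < num_all_pairs; index += stride) {
        // decode (row, column) from index and process one pair per iteration
    }
}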
