
Is there a way to know how much memory is required for a task? #423

Open
feiyang-k opened this issue Sep 1, 2023 · 3 comments
Labels
question Further information is requested

Comments

@feiyang-k

Hi,

JAX seems to reserve all the GPU memory at import, so we cannot see from the NVIDIA panels how much memory the ott package actually uses. Right now, if a problem runs into memory issues, the only thing we can do is reduce the problem size until the error disappears. Is there a more direct way to know how much memory is required for a target task?

Thanks!

michalk8 (Collaborator) commented Sep 1, 2023

Hi @feiyang-k, please check https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for how to change the pre-allocation JAX does. In short, you can do:

import os

# Must be set before importing anything from JAX, otherwise the
# pre-allocation has already happened.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import jax.numpy as jnp
...
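
If you also want to measure how much memory a computation actually uses, one option, once pre-allocation is disabled, is to read the per-device memory statistics (a sketch, assuming a GPU backend and a recent jaxlib; memory_stats() can return None on unsupported platforms and the exact keys may vary):

import jax
import jax.numpy as jnp

dev = jax.devices()[0]

x = jnp.ones((4096, 4096))
y = (x @ x).block_until_ready()  # force execution before reading the stats

stats = dev.memory_stats()  # None if the platform reports no statistics
if stats is not None:
    print(stats.get('bytes_in_use'), stats.get('peak_bytes_in_use'))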

michalk8 added the question label Sep 1, 2023
@feiyang-k (Author)

Thanks @michalk8! I tried it, and it works exactly as I hoped!

By the way, I'm using it in a Jupyter notebook, and GPU memory recycling does not seem to be fully working: each time an OT problem is computed, the GPU memory is not released.

More interestingly, if a computation completes successfully, its memory can be reused for the next problem. But if a problem runs into an error, the allocated memory seems "dead": the next OT problem only gets the remaining memory, which can be much smaller. Thus, I need to restart the Jupyter notebook kernel every time I hit any error with ott. Is this a known issue?

Also, it seems I'm never able to interrupt a cell running OT problems; it simply does not respond. I need to restart the Jupyter notebook kernel whenever a task looks like it will never finish in a reasonable time. Is this expected?

Thanks again!

michalk8 (Collaborator) commented Sep 8, 2023

By the way, I'm using it in a Jupyter notebook, and GPU memory recycling does not seem to be fully working: each time an OT problem is computed, the GPU memory is not released.

According to the docs, XLA_PYTHON_CLIENT_PREALLOCATE='false' makes JAX allocate on demand and re-use the memory (rather than release it), while XLA_PYTHON_CLIENT_ALLOCATOR='platform' de-allocates memory that is no longer needed, but is much slower.
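
A minimal sketch of the second option (same caveat as above: the variable must be set before JAX is imported):

import os

# 'platform' allocates on demand and de-allocates buffers once they are
# no longer needed, at a significant speed cost.
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

import jax
import jax.numpy as jnp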

But if a problem runs into an error, the allocated memory seems "dead": the next OT problem only gets the remaining memory, which can be much smaller.

I will go and investigate this behavior.

Also, it seems I'm never able to interrupt a cell running OT problems; it simply does not respond. I need to restart the Jupyter notebook kernel whenever a task looks like it will never finish in a reasonable time. Is this expected?

I'm not 100% sure, but I would say yes: the code runs on the device, and the interrupt can only be handled once execution is handed back to the host (I will check whether this statement is true). Maybe adding a printing callback (see this tutorial) will allow easier interruption of an execution.
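
For illustration, a sketch of such a callback (the loop here is a stand-in, not ott's actual solver; jax.debug.print inserts a host callback into the compiled computation, which may also make interrupts more responsive):

import jax
import jax.numpy as jnp

@jax.jit
def toy_iterative_solver(x):
    # Stand-in for an iterative OT solve running entirely on device.
    def body(i, x):
        # Host callback: prints from inside the jitted loop, handing
        # control back to the host once per iteration.
        jax.debug.print("iteration {i}", i=i)
        return x * 0.99

    return jax.lax.fori_loop(0, 100, body, x)

toy_iterative_solver(jnp.ones((8,)))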
