high RAM usage #265
-
Hi all, I have run into some problems with memory usage when trying to fit a regression model. My setup is as follows:
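(The original code block did not survive in this copy of the thread. Below is a minimal sketch of a comparable setup, inferred from the model configuration quoted later in the discussion; the data, subj_idx, and coh names come from that configuration, and everything else is an assumption, not the reporter's exact script.)

import hssm

data = ...  # the reporter's trial-level DataFrame (not preserved in the thread)

# Hypothetical reconstruction of the regression model
model = hssm.HSSM(
    data=data,  # expects rt/response columns plus subj_idx and coh
    model="ornstein",
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + (1|subj_idx) + coh",
            "link": "identity",
        }
    ],
    hierarchical=True,
    p_outlier=0.01,
)
samples = model.sample(chains=4, draws=1500)  # 4 chains, as discussed below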
When I run this, I get the following error:
So apparently I am out of memory. When I check my memory usage, though, I am only at 59% of the RAM I've requested (~100GB out of 180GB), which surprised me. Does anyone have an idea what is happening here? Thanks!
-
Hi Eduard, it seems that you are running 4 chains in parallel? I am not sure how Jax vectorization works under the hood, but since there is quite a bit of matrix multiplication involved, with some pretty large matrices, it's possible that you could run out of memory. Can you try reducing the core count and see if that helps?
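(A minimal sketch of what that could look like, assuming the script samples through HSSM's model.sample(), which forwards these arguments to PyMC's sampler; the original sampling call is not shown in the thread:)

# Keep 4 chains but run them sequentially on a single core
# instead of 4 in parallel
samples = model.sample(chains=4, cores=1, draws=1500)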
-
Hi @digicosmos86, sorry for the tardy reply. I tried reducing the core count, running 4 chains (1500 samples) on one core, but the problem remains very similar, if not the same: it crashes when computing the log likelihood after the sampling (see below). Is that something I should raise with the Jax people? Or what else can I do to run multiple chains of a model?
-
Thanks for letting us know, Eduard! It does seem like Jax is allocating a lot of memory. We have a few leads and will definitely look into this. Please stay tuned :)
-
BTW, to help us debug this issue, can you try adding this extra argument (model_config={"backend": "pytensor"}) to your model configuration:

model = hssm.HSSM(
    data=data,
    z=0.5,
    model="ornstein",
    loglik="ornstein.onnx",  # path to the ONNX likelihood network
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + (1|subj_idx) + coh",
            "link": "identity",
        }
    ],
    hierarchical=True,
    p_outlier=0.01,
    model_config={"backend": "pytensor"},  # switch from the default Jax backend
)
-
Yes, exactly. It occurs while computing the log likelihoods; that is at least the last log message before the crash. At the moment the script runs on the CPU. I could try a GPU, if that would help?
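(A side note on the GPU route, not from the original exchange: by default JAX preallocates a large fraction of GPU memory, roughly 75%, at startup, which can itself look like an out-of-memory problem. The standard JAX environment variables below control that behavior and must be set before JAX is imported:)

import os

# Disable JAX's default GPU memory preallocation; must run before importing jax
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, cap preallocation at a fraction of GPU memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # imported only after the flags above are set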
From the error message, and since it happens after sampling, it's very likely that the error is due to how PyMC handles the post-sampling calculation of likelihoods. They have recently changed this process to improve its stability (they removed the use of chunks, for example, and added a "scan" vectorization), which might help with memory usage.
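(A minimal sketch of one possible workaround: skip the automatic post-sampling likelihood computation, then run it separately with PyMC's compute_log_likelihood, where the memory spike is easier to isolate. Whether model.sample forwards idata_kwargs this way, and whether the underlying PyMC model is exposed as model.pymc_model, are assumptions about HSSM's API, not confirmed in this thread:)

import pymc as pm

# Sample without computing pointwise log likelihoods up front
idata = model.sample(idata_kwargs={"log_likelihood": False})

# Compute them afterwards, inside the PyMC model context
with model.pymc_model:  # assumed attribute exposing the underlying PyMC model
    pm.compute_log_likelihood(idata)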
Would you mind trying the latest version of hssm, which has the latest PyMC (5.9.0) as a dependency? You can install it in the same environment with:
pip install git+https://github.com/lnccbrown/HSSM.git@280-pin-numpy-version
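(To confirm that the environment picked up the newer PyMC after installing:)

import pymc

print(pymc.__version__)  # should print 5.9.0 or later with that branch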