Parallel Chains Problem After Updating #398
-
Hi all,

```python
infer_DDM = DDM.sample(
    sampler="nuts_numpyro",  # type of sampler to use
    cores=2,  # how many cores to use
    chains=chains,  # how many chains to run
    draws=draws,  # number of draws from the Markov chain
    tune=int(draws * 0.2),  # number of burn-in (tuning) samples
    idata_kwargs=dict(log_likelihood=False),
)
```

I'm sure I could run parallel chains in past versions, so I'm wondering how to solve this problem and why it did not happen in past versions.
-
Update:
It can run parallel chains again! However, the estimation time became much longer than expected.
-
Hi @JoeSu112, first, are you using a GPU for inference? If you have switched from CPU to GPU as the backend for JAX, then parallel inference might not be possible. The error messages you provided seem to be checking CUDA versions, which made me think you might have installed the GPU version of JAX. If the chains cannot run in parallel, that could explain why estimation takes much longer. It would help if you pasted all the packages in your virtual environment along with their versions so we can start from there. Thanks!
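One way to collect that information is a short script such as the following. It is a minimal sketch, not part of HSSM: the package list is an assumption about what is relevant to this issue, and it relies only on the standard library's `importlib.metadata.version` plus `jax.default_backend()`, which reports the active JAX backend (for example `"cpu"` or `"gpu"`). Running `pip list` or `conda list` in the environment would also work.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages assumed to be relevant here (not an exhaustive list).
PACKAGES = ["hssm", "pymc", "pytensor", "numpy", "jax", "jaxlib", "numpyro"]


def collect_versions(packages=PACKAGES):
    """Return {package: installed version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


def jax_backend():
    """Return JAX's default backend name, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:10s} {ver or 'not installed'}")
    print("JAX backend:", jax_backend())
```

If the backend prints `gpu`, that would be consistent with the GPU build of JAX being installed.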
-
@JoeSu112 Hi, I am also trying to use the Slurm scheduler to run multiple chains in parallel:

Best,
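Under Slurm, one way to keep the `cores` argument in line with what the scheduler actually allocated is to read the `SLURM_CPUS_PER_TASK` environment variable, which Slurm sets when the job requests `--cpus-per-task`. The helper `cores_for_chains` below is hypothetical (it is not part of HSSM or PyMC) and assumes one CPU per chain is the goal.

```python
import os


def cores_for_chains(chains, default=1):
    """Hypothetical helper: cap parallel chains at the CPUs Slurm allocated.

    Falls back to `default` when not running inside a Slurm job
    (i.e. SLURM_CPUS_PER_TASK is unset or empty).
    """
    allocated = os.environ.get("SLURM_CPUS_PER_TASK")
    n_cpus = int(allocated) if allocated else default
    # Never request more cores than chains, and always at least one.
    return max(1, min(chains, n_cpus))


if __name__ == "__main__":
    print("cores to use for 4 chains:", cores_for_chains(4))
```

It could then be passed as, for example, `DDM.sample(sampler="nuts_numpyro", chains=chains, cores=cores_for_chains(chains), ...)`.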
Yeah, this is the culprit. The problem is that newer NumPy versions (>=1.26.4) stopped providing PyTensor with information about the BLAS libraries, so newer versions of HSSM, which depend on newer versions of PyMC (and PyTensor), do not know where to find the BLAS library.
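A quick way to check whether PyTensor found a BLAS library in a given environment is to inspect its `blas__ldflags` config value. This is a minimal sketch, assuming the installed PyTensor exposes that option as `pytensor.config.blas__ldflags`; an empty value usually means PyTensor fell back to slower non-BLAS code paths.

```python
def pytensor_blas_ldflags():
    """Return PyTensor's BLAS linker flags, or None if PyTensor is not installed.

    An empty string suggests PyTensor could not locate a BLAS library.
    """
    try:
        import pytensor
    except ImportError:
        return None
    return pytensor.config.blas__ldflags


if __name__ == "__main__":
    flags = pytensor_blas_ldflags()
    if flags is None:
        print("PyTensor is not installed in this environment.")
    elif not flags.strip():
        print("PyTensor found no BLAS library (empty blas__ldflags).")
    else:
        print("PyTensor BLAS flags:", flags)
```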
There is an easy fix: install PyMC in a conda environment before installing HSSM. We have a tutorial here. Basically, you need to create a conda environment, install JAX (the GPU version, with a CUDA version consistent with the CUDA installed on the system) and PyMC with conda, and then install HSSM via pip. We are making lots of progress in making HSSM available through conda as well, so hopefully soon …