Skip to content

Commit

Permalink
Add in additional stubs to mock MPI object and update tests (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
jellis18 authored Nov 4, 2022
1 parent 0848a84 commit eaede19
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
2 changes: 1 addition & 1 deletion PTMCMCSampler/PTMCMCSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
self.stream = [np.random.default_rng(s) for s in child_seeds]
else:
self.stream = None
self.stream = comm.scatter(self.stream, root=0)
self.stream = self.comm.scatter(self.stream, root=0)

self.ndim = ndim
self.logl = _function_wrapper(logl, loglargs, loglkwargs)
Expand Down
11 changes: 11 additions & 0 deletions PTMCMCSampler/nompi4py.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def recv(self, source=1, tag=55):
def Iprobe(self, source=1, tag=55):
pass

def scatter(self, sendobj, **kwargs):
if sendobj is not None:
return sendobj[0]
return None

def bcast(self, obj, **kwargs):
return obj

def gather(self, sendobj, **kwargs):
return [sendobj]


# Global object representing no MPI:
COMM_WORLD = MPIDummy()
11 changes: 11 additions & 0 deletions tests/test_nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import numpy as np
import scipy.linalg as sl
import scipy.optimize as so
from mpi4py import MPI

from PTMCMCSampler import PTMCMCSampler
from PTMCMCSampler import nompi4py as MPIDUMMY


class GaussianLikelihood(object):
Expand Down Expand Up @@ -165,6 +167,9 @@ class TestNuts(TestCase):
def tearDownClass(cls):
shutil.rmtree("chains")

def setUp(self) -> None:
self.comm = MPI.COMM_WORLD

def test_nuts(self):
ndim = 40
glo = GaussianLikelihood(ndim=ndim, pmin=0.0, pmax=10.0)
Expand Down Expand Up @@ -196,6 +201,7 @@ def test_nuts(self):
logl_grad=gl.lnlikefn_grad,
logp_grad=gl.lnpriorfn_grad,
outDir="./chains",
comm=self.comm,
)

sampler.sample(
Expand All @@ -213,3 +219,8 @@ def test_nuts(self):
HMCsteps=100,
HMCstepsize=0.4,
)


class TestNutsNoMPI(TestNuts):
def setUp(self) -> None:
self.comm = MPIDUMMY.COMM_WORLD
22 changes: 21 additions & 1 deletion tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from unittest import TestCase

import numpy as np
from mpi4py import MPI

from PTMCMCSampler import PTMCMCSampler
from PTMCMCSampler import nompi4py as MPIDUMMY


class GaussianLikelihood(object):
Expand Down Expand Up @@ -65,6 +67,9 @@ class TestSimpleSampler(TestCase):
def tearDownClass(cls):
shutil.rmtree("chains")

def setUp(self) -> None:
self.comm = MPI.COMM_WORLD

def test_simple(self):
# ## Setup Gaussian model class
ndim = 20
Expand All @@ -76,10 +81,25 @@ def test_simple(self):
p0 = np.random.uniform(pmin, pmax, ndim)
cov = np.eye(ndim) * 0.1**2

sampler = PTMCMCSampler.PTSampler(ndim, glo.lnlikefn, glo.lnpriorfn, np.copy(cov), outDir="./chains")
sampler = PTMCMCSampler.PTSampler(
ndim,
glo.lnlikefn,
glo.lnpriorfn,
np.copy(cov),
outDir="./chains",
comm=self.comm,
)

# add to jump proposal cycle
ujump = UniformJump(pmin, pmax)
sampler.addProposalToCycle(ujump.jump, 5)

sampler.sample(p0, 10000, burn=500, thin=1, covUpdate=500, SCAMweight=20, AMweight=20, DEweight=20)


class TestSimpleSamplerNoMPI(TestSimpleSampler):
def setUp(self) -> None:
self.comm = MPIDUMMY.COMM_WORLD

def test_simple(self):
return super().test_simple()

0 comments on commit eaede19

Please sign in to comment.