Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cover more priors with empirical distribution extension function #180

Closed
wants to merge 9 commits into from
Closed
1 change: 1 addition & 0 deletions enterprise_extensions/empirical_distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def make_empirical_distributions(pta, paramlist, params, chain,

if len(pl) == 1:
idx = pta.param_names.index(pl[0])

prior_min = pta.params[idx].prior._defaults['pmin']
prior_max = pta.params[idx].prior._defaults['pmax']

Expand Down
69 changes: 51 additions & 18 deletions enterprise_extensions/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,35 @@
EmpiricalDistribution2DKDE)


def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outdir='chains'):
def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outdir='./chains'):
new_emp_dists = []
modified = False # check if anything was changed
for emp_dist in emp_dists:
if isinstance(emp_dist, EmpiricalDistribution2D) or isinstance(emp_dist, EmpiricalDistribution2DKDE):
# check if we need to extend the distribution
prior_ok=True
for ii, (param, nbins) in enumerate(zip(emp_dist.param_names, emp_dist._Nbins)):
if param not in pta.param_names: # skip if one of the parameters isn't in our PTA object
continue
param_names = [par.name for par in pta.params]
if param not in param_names: # skip if one of the parameters isn't in our PTA object
short_par = '_'.join(param.split('_')[:-1]) # make sure we aren't skipping priors with size!=None
if short_par in param_names:
param = short_par
else:
continue
# check 2 conditions on both params to make sure that they cover their priors
# skip if emp dist already covers the prior
param_idx = pta.param_names.index(param)
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
param_idx = param_names.index(param)
if pta.params[param_idx].type not in ['uniform', 'normal']:
msg = '{} cannot be covered automatically by the empirical distribution\n'.format(pta.params[param_idx].prior)
msg += 'Please check that your prior is covered by the empirical distribution.\n'
print(msg)
continue
elif pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
elif pta.params[param_idx].type == 'normal':
prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma']
prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma']

# no need to extend if histogram edges are already prior min/max
if isinstance(emp_dist, EmpiricalDistribution2D):
Expand All @@ -53,9 +67,13 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
maxvals = []
idxs_to_remove = []
for ii, (param, nbins) in enumerate(zip(emp_dist.param_names, emp_dist._Nbins)):
param_idx = pta.param_names.index(param)
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
param_idx = param_names.index(param)
if pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
elif pta.params[param_idx].type == 'normal':
prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma']
prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma']
# drop samples that are outside the prior range (in case prior is smaller than samples)
if isinstance(emp_dist, EmpiricalDistribution2D):
samples[(samples[:, ii] < prior_min) | (samples[:, ii] > prior_max), ii] = -np.inf
Expand All @@ -74,11 +92,27 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
new_emp_dists.append(new_emp)

elif isinstance(emp_dist, EmpiricalDistribution1D) or isinstance(emp_dist, EmpiricalDistribution1DKDE):
if emp_dist.param_name not in pta.param_names:
param_names = [par.name for par in pta.params]
if emp_dist.param_name not in param_names: # skip if one of the parameters isn't in our PTA object
short_par = '_'.join(emp_dist.param_name.split('_')[:-1]) # make sure we aren't skipping priors with size!=None
if short_par in param_names:
param = short_par
else:
continue
else:
param = emp_dist.param_name
param_idx = param_names.index(param)
if pta.params[param_idx].type not in ['uniform', 'normal']:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am testing this with a run that has an array-sampled parameter and I'm get an error at this line. I believe it's the same issue that the full param_names is longer than the list params. In my case there is a single-sampled parameter in the list after the array-sampled parameter. I think that would cause this problem. This is why in #179 I made a separate list of param names from the parameter.Parameter.name attributes.

The Trace of my error is below, if useful.

Traceback (most recent call last):
  File "/mmfs1/home/hazboun/miniconda3/envs/pta/lib/python3.9/site-packages/pta_sim/scripts//model2a_advnoise_sw.py", line 196, in <module>
    Sampler = sampler.setup_sampler(pta_crn, outdir=args.outdir, resume=True,
  File "/mmfs1/home/hazboun/miniconda3/envs/pta/lib/python3.9/site-packages/enterprise_extensions/sampler.py", line 1107, in setup_sampler
    jp = JumpProposal(pta, empirical_distr=empirical_distr, save_ext_dists=save_ext_dists, outdir=outdir)
  File "/mmfs1/home/hazboun/miniconda3/envs/pta/lib/python3.9/site-packages/enterprise_extensions/sampler.py", line 252, in __init__
    self.empirical_distr = extend_emp_dists(pta, self.empirical_distr, npoints=100_000,
  File "/mmfs1/home/hazboun/miniconda3/envs/pta/lib/python3.9/site-packages/enterprise_extensions/sampler.py", line 97, in extend_emp_dists
    if pta.params[param_idx].type not in ['uniform', 'normal']:
IndexError: list index out of range

msg = 'This prior cannot be covered automatically by the empirical distribution\n'
msg += 'Please check that your prior is covered by the empirical distribution.\n'
print(msg)
continue
param_idx = pta.param_names.index(emp_dist.param_name)
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
if pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
elif pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma']
prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma']
# check 2 conditions on param to make sure that it covers the prior
# skip if emp dist already covers the prior
if isinstance(emp_dist, EmpiricalDistribution1D):
Expand All @@ -96,7 +130,6 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
new_bins = []
idxs_to_remove = []
# drop samples that are outside the prior range (in case prior is smaller than samples)

if isinstance(emp_dist, EmpiricalDistribution1D):
samples[(samples < prior_min) | (samples > prior_max)] = -np.inf
elif isinstance(emp_dist, EmpiricalDistribution1DKDE):
Expand All @@ -111,20 +144,20 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
minval=prior_min, maxval=prior_max,
bandwidth=emp_dist.bandwidth)
new_emp_dists.append(new_emp)

else:
print('Unable to extend class of unknown type to the edges of the priors.')
new_emp_dists.append(emp_dist)
continue

if save_ext_dists and modified: # if user wants to save them, and they have been modified...
pickle.dump(new_emp_dists, outdir + 'new_emp_dists.pkl')
if save_ext_dists and modified: # if user wants to save them, and they have been modified...
with open(outdir + '/new_emp_dists.pkl', 'wb') as f:
pickle.dump(new_emp_dists, f)
return new_emp_dists


class JumpProposal(object):

def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, save_ext_dists=False, outdir='chains'):
def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, save_ext_dists=False, outdir='./chains'):
"""Set up some custom jump proposals"""
self.params = pta.params
self.pnames = pta.param_names
Expand Down