diff --git a/enterprise_extensions/empirical_distr.py b/enterprise_extensions/empirical_distr.py index a3c76827..d8d14030 100644 --- a/enterprise_extensions/empirical_distr.py +++ b/enterprise_extensions/empirical_distr.py @@ -227,6 +227,7 @@ def make_empirical_distributions(pta, paramlist, params, chain, if len(pl) == 1: idx = pta.param_names.index(pl[0]) + prior_min = pta.params[idx].prior._defaults['pmin'] prior_max = pta.params[idx].prior._defaults['pmax'] diff --git a/enterprise_extensions/sampler.py b/enterprise_extensions/sampler.py index b73f0a12..b9a4d3d9 100644 --- a/enterprise_extensions/sampler.py +++ b/enterprise_extensions/sampler.py @@ -16,7 +16,7 @@ EmpiricalDistribution2DKDE) -def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outdir='chains'): +def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outdir='./chains'): new_emp_dists = [] modified = False # check if anything was changed for emp_dist in emp_dists: @@ -24,13 +24,27 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd # check if we need to extend the distribution prior_ok=True for ii, (param, nbins) in enumerate(zip(emp_dist.param_names, emp_dist._Nbins)): - if param not in pta.param_names: # skip if one of the parameters isn't in our PTA object - continue + param_names = [par.name for par in pta.params] + if param not in param_names: # skip if one of the parameters isn't in our PTA object + short_par = '_'.join(param.split('_')[:-1]) # make sure we aren't skipping priors with size!=None + if short_par in param_names: + param = short_par + else: + continue # check 2 conditions on both params to make sure that they cover their priors # skip if emp dist already covers the prior - param_idx = pta.param_names.index(param) - prior_min = pta.params[param_idx].prior._defaults['pmin'] - prior_max = pta.params[param_idx].prior._defaults['pmax'] + param_idx = param_names.index(param) + if pta.params[param_idx].type not in ['uniform', 'normal']: + msg = '{} cannot be covered automatically by the empirical distribution\n'.format(pta.params[param_idx].prior) + msg += 'Please check that your prior is covered by the empirical distribution.\n' + print(msg) + continue + elif pta.params[param_idx].type == 'uniform': + prior_min = pta.params[param_idx].prior._defaults['pmin'] + prior_max = pta.params[param_idx].prior._defaults['pmax'] + elif pta.params[param_idx].type == 'normal': + prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma'] + prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma'] # no need to extend if histogram edges are already prior min/max if isinstance(emp_dist, EmpiricalDistribution2D): @@ -53,9 +67,13 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd maxvals = [] idxs_to_remove = [] for ii, (param, nbins) in enumerate(zip(emp_dist.param_names, emp_dist._Nbins)): - param_idx = pta.param_names.index(param) - prior_min = pta.params[param_idx].prior._defaults['pmin'] - prior_max = pta.params[param_idx].prior._defaults['pmax'] + param_idx = param_names.index(param) + if pta.params[param_idx].type == 'uniform': + prior_min = pta.params[param_idx].prior._defaults['pmin'] + prior_max = pta.params[param_idx].prior._defaults['pmax'] + elif pta.params[param_idx].type == 'normal': + prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma'] + prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma'] # drop samples that are outside the prior range (in case prior is smaller than samples) if isinstance(emp_dist, EmpiricalDistribution2D): samples[(samples[:, ii] < prior_min) | (samples[:, ii] > prior_max), ii] = -np.inf @@ -74,11 +92,27 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd new_emp_dists.append(new_emp) elif isinstance(emp_dist, EmpiricalDistribution1D) or isinstance(emp_dist, EmpiricalDistribution1DKDE): - if emp_dist.param_name not in pta.param_names: + param_names = [par.name for par in pta.params] + if emp_dist.param_name not in param_names: # skip if one of the parameters isn't in our PTA object + short_par = '_'.join(emp_dist.param_name.split('_')[:-1]) # make sure we aren't skipping priors with size!=None + if short_par in param_names: + param = short_par + else: + continue + else: + param = emp_dist.param_name + param_idx = param_names.index(param) + if pta.params[param_idx].type not in ['uniform', 'normal']: + msg = 'This prior cannot be covered automatically by the empirical distribution\n' + msg += 'Please check that your prior is covered by the empirical distribution.\n' + print(msg) continue - param_idx = pta.param_names.index(emp_dist.param_name) - prior_min = pta.params[param_idx].prior._defaults['pmin'] - prior_max = pta.params[param_idx].prior._defaults['pmax'] + if pta.params[param_idx].type == 'uniform': + prior_min = pta.params[param_idx].prior._defaults['pmin'] + prior_max = pta.params[param_idx].prior._defaults['pmax'] + elif pta.params[param_idx].type == 'uniform': + prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma'] + prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma'] # check 2 conditions on param to make sure that it covers the prior # skip if emp dist already covers the prior if isinstance(emp_dist, EmpiricalDistribution1D): @@ -96,7 +130,6 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd new_bins = [] idxs_to_remove = [] # drop samples that are outside the prior range (in case prior is smaller than samples) - if isinstance(emp_dist, EmpiricalDistribution1D): samples[(samples < prior_min) | (samples > prior_max)] = -np.inf elif isinstance(emp_dist, EmpiricalDistribution1DKDE): @@ -111,20 +144,20 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd minval=prior_min, maxval=prior_max, bandwidth=emp_dist.bandwidth) new_emp_dists.append(new_emp) - else: print('Unable to extend class of unknown type to the edges of the priors.') new_emp_dists.append(emp_dist) continue - if save_ext_dists and modified: # if user wants to save them, and they have been modified... - pickle.dump(new_emp_dists, outdir + 'new_emp_dists.pkl') + if save_ext_dists and modified: # if user wants to save them, and they have been modified... + with open(outdir + '/new_emp_dists.pkl', 'wb') as f: + pickle.dump(new_emp_dists, f) return new_emp_dists class JumpProposal(object): - def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, save_ext_dists=False, outdir='chains'): + def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, save_ext_dists=False, outdir='./chains'): """Set up some custom jump proposals""" self.params = pta.params self.pnames = pta.param_names