-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #585 from lnccbrown/578-improve-analytical-likelihood
578 improve analytical likelihood
- Loading branch information
Showing
2 changed files
with
84 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,143 +1,82 @@ | ||
"""Unit testing for LBA likelihood functions.""" | ||
|
||
from pathlib import Path | ||
from itertools import product | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pymc as pm | ||
import pytensor | ||
import pytensor.tensor as pt | ||
import pytest | ||
import arviz as az | ||
from pytensor.compile.nanguardmode import NanGuardMode | ||
|
||
import hssm | ||
|
||
# pylint: disable=C0413 | ||
from hssm.likelihoods.analytical import logp_lba2, logp_lba3 | ||
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox | ||
from hssm.distribution_utils import make_likelihood_callable | ||
|
||
hssm.set_floatX("float32") | ||
|
||
CLOSE_TOLERANCE = 1e-4 | ||
|
||
|
||
def test_lba2_basic(): | ||
size = 1000 | ||
def filter_theta(theta, exclude_keys=("A", "b")):
    """Return a copy of ``theta`` without the keys in ``exclude_keys``.

    Parameters
    ----------
    theta : dict
        Dictionary of model parameters.
    exclude_keys : iterable of str, optional
        Keys to drop from the result. Defaults to ``("A", "b")``. A tuple is
        used rather than a list to avoid the shared-mutable-default pitfall.

    Returns
    -------
    dict
        A new dictionary containing every item of ``theta`` whose key is not
        in ``exclude_keys``. ``theta`` itself is not modified.
    """
    return {k: v for k, v in theta.items() if k not in exclude_keys}
|
||
lba_data_out = hssm.simulate_data( | ||
model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size | ||
) | ||
|
||
# Test if vectorization ok across parameters | ||
out_A_vec = logp_lba2( | ||
lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0 | ||
).eval() | ||
out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval() | ||
assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE) | ||
|
||
out_b_vec = logp_lba2( | ||
lba_data_out.values, | ||
A=np.array([0.2] * size), | ||
b=np.array([0.5] * size), | ||
v0=1.0, | ||
v1=1.0, | ||
).eval() | ||
assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE) | ||
|
||
out_v_vec = logp_lba2( | ||
lba_data_out.values, | ||
A=np.array([0.2] * size), | ||
b=np.array([0.5] * size), | ||
v0=np.array([1.0] * size), | ||
v1=np.array([1.0] * size), | ||
).eval() | ||
assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE) | ||
|
||
# Test A > b leads to error | ||
def assert_parameter_value_error(logp_func, lba_data_out, A, b, theta):
    """Assert that ``logp_func`` rejects the given ``A``/``b`` combination.

    Builds the call keywords from ``theta`` with ``A`` and ``b`` replaced by
    the supplied values, evaluates the log-probability, and asserts that
    ``pm.logprob.utils.ParameterValueError`` is raised.
    """
    kwargs = dict(A=A, b=b)
    kwargs.update(filter_theta(theta, ["A", "b"]))
    with pytest.raises(pm.logprob.utils.ParameterValueError):
        logp_func(lba_data_out.values, **kwargs).eval()
|
||
with pytest.raises(pm.logprob.utils.ParameterValueError): | ||
logp_lba2(lba_data_out.values, A=0.6, b=0.5, v0=1.0, v1=1.0).eval() | ||
|
||
with pytest.raises(pm.logprob.utils.ParameterValueError): | ||
logp_lba2( | ||
lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0 | ||
).eval() | ||
def vectorize_param(theta, param, size):
    """Return ``theta`` with one parameter expanded into a constant vector.

    Parameters
    ----------
    theta : dict
        Dictionary of parameters.
    param : str
        The parameter whose scalar value should be broadcast to a vector.
    size : int
        Length of the resulting vector.

    Returns
    -------
    dict
        A new dictionary in which ``theta[param]`` is replaced by a numpy
        array of length ``size`` filled with its value; all other entries
        are passed through unchanged.

    Examples
    --------
    >>> theta = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0}
    >>> vectorize_param(theta, "A", 3)
    {'A': array([0.2, 0.2, 0.2]), 'b': 0.5, 'v0': 1.0, 'v1': 1.0}
    """
    result = {}
    for key, value in theta.items():
        if key == param:
            result[key] = np.full(size, value)
        else:
            result[key] = value
    return result
|
||
# Test A > b leads to error | ||
with pytest.raises(pm.logprob.utils.ParameterValueError): | ||
logp_lba3( | ||
lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0, v2=1.0 | ||
).eval() | ||
|
||
with pytest.raises(pm.logprob.utils.ParameterValueError): | ||
logp_lba3(lba_data_out.values, b=0.5, A=0.6, v0=1.0, v1=1.0, v2=1.0).eval() | ||
# Baseline parameter sets shared by the parametrized LBA tests below.
theta_lba2 = dict(A=0.2, b=0.5, v0=1.0, v1=1.0)
# lba3 uses the lba2 parameters plus one extra drift parameter, v2.
theta_lba3 = theta_lba2 | {"v2": 1.0}
|
||
with pytest.raises(pm.logprob.utils.ParameterValueError): | ||
logp_lba3( | ||
lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0, v2=1.0 | ||
).eval() | ||
|
||
with pytest.raises(pm.logprob.utils.ParameterValueError): | ||
logp_lba3( | ||
lba_data_out.values, | ||
A=np.array([0.6] * 1000), | ||
b=np.array([0.5] * 1000), | ||
v0=1.0, | ||
v1=1.0, | ||
v2=1.0, | ||
).eval() | ||
@pytest.mark.parametrize(
    "logp_func, model, theta",
    [(logp_lba2, "lba2", theta_lba2), (logp_lba3, "lba3", theta_lba3)],
)
def test_lba(logp_func, model, theta):
    """Check LBA log-likelihoods for vectorization consistency and A > b errors.

    For each parameter, broadcasting its scalar value to a vector must leave
    the computed log-probabilities unchanged; any combination of scalar or
    vector A/b with A > b must raise ``ParameterValueError``.
    """
    size = 1000
    lba_data_out = hssm.simulate_data(model=model, theta=theta, size=size)

    # The all-scalar baseline is loop-invariant: compute it once rather than
    # re-evaluating it on every iteration of the vectorization loop.
    out_base = logp_func(lba_data_out.values, **theta).eval()

    # Vectorizing any single parameter must not change the result.
    for param in theta:
        param_vec = vectorize_param(theta, param, size)
        out_vec = logp_func(lba_data_out.values, **param_vec).eval()
        assert np.allclose(out_vec, out_base, atol=CLOSE_TOLERANCE)

    # Test that A > b leads to an error, for every scalar/vector combination.
    A_values = [np.full(size, 0.6), 0.6]
    b_values = [np.full(size, 0.5), 0.5]
    for A, b in product(A_values, b_values):
        assert_parameter_value_error(logp_func, lba_data_out, A, b, theta)