Skip to content

Commit

Permalink
Merge pull request #585 from lnccbrown/578-improve-analytical-likelihood
Browse files Browse the repository at this point in the history
578 improve analytical likelihood
  • Loading branch information
cpaniaguam authored Sep 27, 2024
2 parents 4bd6c77 + c56d654 commit fa12018
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 121 deletions.
44 changes: 34 additions & 10 deletions src/hssm/likelihoods/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@

LOGP_LB = pm.floatX(-66.1)

# Constants shared by the expansion-term helpers below, computed once at
# import time. The pt.* ones are PyTensor graph constants, reused so each
# logp graph does not re-derive log(2π) etc.
π = np.pi
τ = 2 * π  # 2π
sqrt_τ = pt.sqrt(τ)  # sqrt(2π)
log_π = pt.log(π)
log_τ = pt.log(τ)  # log(2π)
log_4 = pt.log(4)  # log(4) = 2*log(2)


def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the element-wise maximum of two equally-shaped tensors."""
    stacked = pt.stack([a, b])
    return stacked.max(axis=0)


def k_small(rt: np.ndarray, err: float) -> np.ndarray:
    """Determine number of terms needed for small-t expansion.

    Parameters
    ----------
    rt
        A 1D numpy array of flipped RTs. (0, inf).
    err
        Error bound.

    Returns
    -------
    np.ndarray
        A 1D at array of k_small.
    """
    sqrt_rt = pt.sqrt(rt)
    log_rt = pt.log(rt)
    log_err = pt.log(err)

    # 2 * rt * log(2 * sqrt(2*pi*rt)) rewritten in log space:
    # 2*rt*(log 2 + 0.5*log(2*pi) + 0.5*log rt) = rt*(log 4 + log 2pi + log rt)
    rt_log_2_sqrt_τ_rt_times_2 = rt * (log_4 + log_τ + log_rt)

    # BUG FIX: the original formula is
    #   ks = 2 + sqrt(-2 * rt * log(2 * sqrt(2*pi*rt) * err))
    # so err enters the radicand additively as 2*rt*log(err), not as a
    # multiplicative factor (`-err * ...`) as the refactor had it.
    ks = 2 + pt.sqrt(-(rt_log_2_sqrt_τ_rt_times_2 + 2 * rt * log_err))
    ks = _max(ks, sqrt_rt + 1)

    # When 2*sqrt(2*pi*rt)*err >= 1 the log above is nonnegative and the
    # bound degenerates; fall back to the minimum of 2 terms.
    condition = 2 * sqrt_τ * sqrt_rt * err < 1
    ks = pt.switch(condition, ks, 2)

    return ks

def k_large(rt: np.ndarray, err: float) -> np.ndarray:
    """Determine number of terms needed for large-t expansion.

    Parameters
    ----------
    rt
        A 1D numpy array of flipped RTs. (0, inf).
    err
        Error bound.

    Returns
    -------
    np.ndarray
        A 1D at array of k_large.
    """
    log_rt = pt.log(rt)
    sqrt_rt = pt.sqrt(rt)
    log_err = pt.log(err)

    π_rt_err = π * rt * err
    π_sqrt_rt = π * sqrt_rt

    # kl = sqrt(-2 * log(pi*rt*err) / (pi^2 * rt)), expanded in log space;
    # the divisor sqrt(pi^2 * rt) = pi * sqrt(rt).
    kl = pt.sqrt(-2 * (log_π + log_err + log_rt)) / π_sqrt_rt
    # BUG FIX: the lower bound is 1/(pi*sqrt(rt)) = 1/π_sqrt_rt, not
    # 1/sqrt(pi*sqrt(rt)) as the refactor briefly introduced.
    kl = _max(kl, 1.0 / π_sqrt_rt)
    # When pi*rt*err >= 1 the log is nonnegative; fall back to the bound.
    kl = pt.switch(π_rt_err < 1, kl, 1.0 / π_sqrt_rt)

    return kl

Expand Down Expand Up @@ -141,7 +165,7 @@ def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
c = pt.max(r, axis=0)
p = pt.exp(c) * pt.sum(y * pt.exp(r - c), axis=0)
# Normalize p
p = p / pt.sqrt(2 * np.pi * pt.power(tt, 3))
p = p / pt.sqrt(2 * π * pt.power(tt, 3))

return p

Expand All @@ -167,9 +191,9 @@ def ftt01w_slow(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
The approximated function f(tt|0, 1, w).
"""
k = get_ks(k_terms, fast=False)
y = k * pt.sin(k * np.pi * w)
r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2
p = pt.sum(y * pt.exp(r), axis=0) * np.pi
y = k * pt.sin(k * π * w)
r = -pt.power(k, 2) * pt.power(π, 2) * tt / 2
p = pt.sum(y * pt.exp(r), axis=0) * π

return p

Expand Down
161 changes: 50 additions & 111 deletions tests/test_likelihoods_lba.py
Original file line number Diff line number Diff line change
@@ -1,143 +1,82 @@
"""Unit testing for LBA likelihood functions."""

from pathlib import Path
from itertools import product

import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
import pytest
import arviz as az
from pytensor.compile.nanguardmode import NanGuardMode

import hssm

# pylint: disable=C0413
from hssm.likelihoods.analytical import logp_lba2, logp_lba3
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
from hssm.distribution_utils import make_likelihood_callable

hssm.set_floatX("float32")

CLOSE_TOLERANCE = 1e-4


def test_lba2_basic():
size = 1000
def filter_theta(theta, exclude_keys=("A", "b")):
    """Filter out specific keys from the theta dictionary.

    Parameters:
        theta (dict): Dictionary of parameters.
        exclude_keys: Keys to drop. Defaults to ("A", "b").
            A tuple default replaces the original mutable-list default,
            avoiding the shared-mutable-default pitfall.

    Returns:
        dict: A new dictionary without the excluded keys.
    """
    return {k: v for k, v in theta.items() if k not in exclude_keys}

lba_data_out = hssm.simulate_data(
model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size
)

# Test if vectorization ok across parameters
out_A_vec = logp_lba2(
lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0
).eval()
out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval()
assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)

out_b_vec = logp_lba2(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=1.0,
v1=1.0,
).eval()
assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)

out_v_vec = logp_lba2(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=np.array([1.0] * size),
v1=np.array([1.0] * size),
).eval()
assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)

# Test A > b leads to error
def assert_parameter_value_error(logp_func, lba_data_out, A, b, theta):
    """Helper function to assert ParameterValueError for given parameters."""
    # Pass through every parameter except A and b, which are overridden.
    other_params = filter_theta(theta, ["A", "b"])
    with pytest.raises(pm.logprob.utils.ParameterValueError):
        logp_func(lba_data_out.values, A=A, b=b, **other_params).eval()

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba2(lba_data_out.values, A=0.6, b=0.5, v0=1.0, v1=1.0).eval()

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba2(
lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0
).eval()
def vectorize_param(theta, param, size):
    """
    Vectorize a specific parameter in the theta dictionary.

    Parameters:
        theta (dict): Dictionary of parameters.
        param (str): The parameter to vectorize.
        size (int): The size of the vector.

    Returns:
        dict: A new dictionary with the specified parameter vectorized.

    Examples:
        >>> theta = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0}
        >>> vectorize_param(theta, "A", 3)
        {'A': array([0.2, 0.2, 0.2]), 'b': 0.5, 'v0': 1.0, 'v1': 1.0}

        >>> vectorize_param(theta, "v0", 2)
        {'A': 0.2, 'b': 0.5, 'v0': array([1., 1.]), 'v1': 1.0}
    """
    # Only the requested key is expanded to a length-`size` array; all
    # other entries are passed through unchanged.
    return {k: (np.full(size, v) if k == param else v) for k, v in theta.items()}

# Parameter sets shared by the parametrized LBA tests below; lba3 extends
# lba2 with a third drift rate.
theta_lba2 = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0}
theta_lba3 = {**theta_lba2, "v2": 1.0}
@pytest.mark.parametrize(
    "logp_func, model, theta",
    [(logp_lba2, "lba2", theta_lba2), (logp_lba3, "lba3", theta_lba3)],
)
def test_lba(logp_func, model, theta):
    """Check vectorization consistency and A > b validation for LBA logps."""
    size = 1000
    lba_data_out = hssm.simulate_data(model=model, theta=theta, size=size)

    # The scalar baseline is loop-invariant: compute the 1000-row logp once
    # instead of once per parameter.
    out_base = logp_func(lba_data_out.values, **theta).eval()

    # Vectorizing any single parameter must not change the result.
    for param in theta:
        param_vec = vectorize_param(theta, param, size)
        out_vec = logp_func(lba_data_out.values, **param_vec).eval()
        assert np.allclose(out_vec, out_base, atol=CLOSE_TOLERANCE)

    # A > b must raise, for every scalar/vector combination of A and b.
    A_values = [np.full(size, 0.6), 0.6]
    b_values = [np.full(size, 0.5), 0.5]

    for A, b in product(A_values, b_values):
        assert_parameter_value_error(logp_func, lba_data_out, A, b, theta)

0 comments on commit fa12018

Please sign in to comment.