Fix nan gradients in analytical likelihood #468

Closed
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
@@ -14,4 +14,9 @@ repos:
    rev: v1.10.0 # Use the sha / tag you want to point at
    hooks:
      - id: mypy
-       args: [--no-strict-optional, --ignore-missing-imports]
+       args:
+         [
+           --no-strict-optional,
+           --ignore-missing-imports,
+           --config-file=pyproject.toml,
+         ]
1 change: 1 addition & 0 deletions pyproject.toml
@@ -159,6 +159,7 @@ convention = "numpy"

[tool.mypy]
ignore_missing_imports = true
+exclude = 'tests/*'

[build-system]
requires = ["poetry-core"]
85 changes: 34 additions & 51 deletions src/hssm/likelihoods/analytical.py
@@ -9,7 +9,6 @@

import numpy as np
import pymc as pm
-import pytensor
import pytensor.tensor as pt
from numpy import inf
from pymc.distributions.dist_math import check_parameters
@@ -25,7 +24,7 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray:
    Parameters
    ----------
    rt
        A 1D numpy array of flipped RTs. (0, inf).
    err
        Error bound.

@@ -34,9 +33,11 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray:
    np.ndarray
        A 1D at array of k_small.
    """
-    ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err))
-    ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0)
-    ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2)
+    _a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1
Collaborator: What would a better name for this boolean array be, maybe mask or sieve?

Collaborator: Should pt.lt be used here as done elsewhere in this PR?

Author: It's actually equivalent, but I was just playing around.

+    _b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err))
+    _c = pt.sqrt(rt) + 1

Comment on lines +36 to +38:
Collaborator: The fundamental operation is pt.sqrt(rt). It's better to do this first and reuse the result to avoid computing it again.

Collaborator: For numerical stability, it's better to group the constant factor C = 2 * pt.sqrt(2 * np.pi) * err and compare each member of sqrt_rt = pt.sqrt(rt) against 1/C.

Author: Sure! Feel free to change this.

+    _d = pt.maximum(_b, _c)
+    ks = _a * _d + (1 - _a) * 2

digicosmos86 marked this conversation as resolved.
Collaborator: Because _a is boolean, I think it's better to treat it as such and use pt.switch.

Suggested change:
-    ks = _a * _d + (1 - _a) * 2
+    ks = pt.switch(mask, _d, 2)  # having renamed `_a` to `mask`, for example

Author: Please see comment below.


    return ks
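Taken together, the review suggestions above (reusing pt.sqrt(rt), grouping the constant factor, and selecting with pt.switch on a named mask) would yield something like the following sketch. This is illustrative only, not the PR's code; the author notes elsewhere in this review that pt.switch caused odd errors, so the blend form in the diff is deliberate.

    import numpy as np
    import pytensor.tensor as pt

    def k_small_sketch(rt, err):
        # Hypothetical refactor of k_small, combining the review suggestions.
        sqrt_rt = pt.sqrt(rt)              # compute once, reuse below
        c = 2 * np.sqrt(2 * np.pi) * err   # grouped constant factor
        mask = pt.lt(sqrt_rt, 1.0 / c)     # same condition as c * sqrt_rt < 1
        ks = 2 + pt.sqrt(-2 * rt * pt.log(c * sqrt_rt))
        ks = pt.maximum(ks, sqrt_rt + 1)
        return pt.switch(mask, ks, 2)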

@@ -56,9 +57,12 @@ def k_large(rt: np.ndarray, err: float) -> np.ndarray:
    np.ndarray
        A 1D at array of k_large.
    """
-    kl = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt))
-    kl = pt.max(pt.stack([kl, 1.0 / (np.pi * pt.sqrt(rt))]), axis=0)
-    kl = pt.switch(np.pi * rt * err < 1, kl, 1.0 / (np.pi * pt.sqrt(rt)))
+    _a = np.pi * rt * err < 1
Collaborator:

Suggested change:
-    _a = np.pi * rt * err < 1
+    _a = rt < 1 / (np.pi * err)
+    _b = 1.0 / (np.pi * pt.sqrt(rt))

Collaborator: It seems like the k_large/k_small functions are called in succession. It might be better to merge them into a single function for improved computational efficiency, which would allow reusing sqrt_rt here:

    _log = pt.log(np.pi * err) + pt.log(rt)  # requires all members to be negative for the operation below to work
    _c = pt.sqrt(-2 * _log) / (np.pi * sqrt_rt)  # reusing sqrt_rt
+    _d = pt.maximum(_b, _c)
+    kl = _a * _b + (1 - _a) * _b
Collaborator: _c and _d are not used. Should _d be used in the second term instead of _b? Otherwise kl will be _b.

Suggested change:
-    kl = _a * _b + (1 - _a) * _b
+    kl = pt.switch(mask, _b, _d)

Author: Please see comment below.


    return kl
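Along the same lines, the merged-helper idea from the comment above could look like this sketch (hypothetical, not the PR's implementation; sqrt_rt would be shared with the k_small computation if the two functions were fused):

    import numpy as np
    import pytensor.tensor as pt

    def k_large_sketch(rt, err):
        # Hypothetical refactor of k_large, reusing sqrt(rt) and a split log.
        sqrt_rt = pt.sqrt(rt)
        log_term = pt.log(np.pi * err) + pt.log(rt)  # == log(pi * rt * err)
        fallback = 1.0 / (np.pi * sqrt_rt)
        kl = pt.maximum(pt.sqrt(-2 * log_term) / (np.pi * sqrt_rt), fallback)
        # Equivalent to pi * rt * err < 1, rewritten as suggested above.
        return pt.switch(pt.lt(rt, 1.0 / (np.pi * err)), kl, fallback)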

@@ -81,34 +85,7 @@ def compare_k(rt: np.ndarray, err: float) -> np.ndarray:
    ks = k_small(rt, err)
    kl = k_large(rt, err)

-    return ks < kl
-
-
-def get_ks(k_terms: int, fast: bool) -> np.ndarray:
-    """Return an array of ks.
-
-    Returns an array of ks given the number of terms needed to approximate the sum of
-    the infinite series.
-
-    Parameters
-    ----------
-    k_terms
-        number of terms needed
-    fast
-        whether the function is used in the fast of slow expansion.
-
-    Returns
-    -------
-    np.ndarray
-        An array of ks.
-    """
-    ks = (
-        pt.arange(-pt.floor((k_terms - 1) / 2), pt.ceil((k_terms - 1) / 2) + 1)
-        if fast
-        else pt.arange(1, k_terms + 1).reshape((-1, 1))
-    )
-
-    return ks.astype(pytensor.config.floatX)
+    return pt.lt(ks, kl)


def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
@@ -133,7 +110,10 @@ def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
"""
# Slightly changed the original code to mimic the paper and
# ensure correctness
-    k = get_ks(k_terms, fast=True)
+    k = pt.arange(
+        -pt.floor((k_terms - 1) / 2.0),
+        pt.ceil((k_terms - 1) / 2.0) + 1.0,
+    )

    # A log-sum-exp trick is used here
    y = w + 2 * k.reshape((-1, 1))
@@ -166,7 +146,7 @@ def ftt01w_slow(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
    np.ndarray
        The approximated function f(tt|0, 1, w).
    """
-    k = get_ks(k_terms, fast=False)
+    k = pt.arange(1, k_terms + 1).reshape((-1, 1))
    y = k * pt.sin(k * np.pi * w)
    r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2
    p = pt.sum(y * pt.exp(r), axis=0) * np.pi
@@ -208,7 +188,7 @@ def ftt01w(
    p_fast = ftt01w_fast(tt, w, k_terms)
    p_slow = ftt01w_slow(tt, w, k_terms)

-    p = pt.switch(lambda_rt, p_fast, p_slow)
+    p = lambda_rt * p_fast + (1.0 - lambda_rt) * p_slow

    return p
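As an aside, the log-sum-exp trick referenced in ftt01w_fast rewrites log sum_i exp(x_i) as m + log sum_i exp(x_i - m) with m = max x_i, so that the largest exponentiated term is exp(0) = 1 and nothing overflows. A standalone NumPy illustration, not code from this PR:

    import numpy as np

    def logsumexp(x):
        # Shift by the max so the largest term exponentiates to exactly 1.
        m = np.max(x)
        return m + np.log(np.sum(np.exp(x - m)))

    x = np.array([1000.0, 1001.0, 1002.0])
    print(np.log(np.sum(np.exp(x))))  # inf: exp(1000) overflows float64
    print(logsumexp(x))               # ~1002.4076, computed stably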

@@ -220,7 +200,7 @@ def logp_ddm(
    z: float,
    t: float,
    err: float = 1e-15,
-    k_terms: int = 20,
+    k_terms: int = 7,
    epsilon: float = 1e-15,
Collaborator: I don't know what was used for testing / is used as the actual value for inference, but I guess it is this default? The epsilon for the rt part should rather be on the order of 1e-3, or even 1e-2. If we are reusing the same epsilon in multiple places, we should probably separate it out.

Author: Was playing around. It seems that changing k_terms to 7 did not improve speed or computation.

) -> np.ndarray:
"""Compute analytical likelihood for the DDM model with `sv`.
@@ -262,15 +242,17 @@ def logp_ddm(
    z_flipped = pt.switch(flip, 1 - z, z)  # transform z if x is upper-bound response
    rt = rt - t

-    p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
+    negative_rt = pt.lt(rt, epsilon)

-    logp = pt.where(
-        rt <= epsilon,
-        LOGP_LB,
+    tt = negative_rt * epsilon + (1 - negative_rt) * rt
Collaborator:

Suggested change:
-    tt = negative_rt * epsilon + (1 - negative_rt) * rt
+    tt = pt.switch(negative_rt, epsilon, rt)

Author: This is actually done on purpose. pt.switch can cause some weird errors.

+    p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
Collaborator: Quick note: it seems like we are only passing k_terms here, not actually computing it. I think we had agreed to do that way back on another iteration of trying to fix issues with this likelihood, and I think it's fine, but in this case we should make the default a bit higher than 7.

Author: Just playing around here. Not actually changing this.


+    logp = negative_rt * LOGP_LB + (1 - negative_rt) * (
         pt.log(p)
         - v_flipped * a * z_flipped
-        - (v_flipped**2 * rt / 2.0)
-        - 2.0 * pt.log(a),
+        - (v_flipped**2 * tt / 2.0)
+        - 2.0 * pt.log(pt.maximum(epsilon, a))
Collaborator: Reflecting on this a bit, I think this maximum business is actually corrupting the gradients, so we should just restrict a > epsilon a priori (via the prior, essentially?).

Collaborator: On the other hand, apart from initialization (which (1) our strategies should already avoid, and (2) we can generally impact), a should basically never come close to 0, so this should basically never be the culprit...

Author: But this did help a bit, for some reason...

    )

    checked_logp = check_parameters(logp, a >= 0, msg="a >= 0")
Collaborator: In the spirit of the above, this check could be a > 0, but honestly we shouldn't really ever get there.

Author: Same as above.
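A sketch of the "restrict a > epsilon via prior" idea in plain PyMC terms. The prior family and bounds here are hypothetical, chosen only to illustrate keeping a bounded away from zero; they are not taken from this PR:

    import pymc as pm

    with pm.Model() as model:
        # A truncated prior keeps the boundary-separation parameter strictly
        # positive, so the likelihood never evaluates pt.log(a) near zero and
        # the pt.maximum(epsilon, a) guard becomes unnecessary.
        a = pm.TruncatedNormal("a", mu=1.5, sigma=0.75, lower=0.05)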

@@ -333,7 +315,8 @@ def logp_ddm_sdv(
    z_flipped = pt.switch(flip, 1 - z, z)  # transform z if x is upper-bound response
    rt = rt - t

-    p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
+    tt = pt.switch(rt <= epsilon, epsilon, rt)

digicosmos86 marked this conversation as resolved.

+    p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))

    logp = pt.switch(
        rt <= epsilon,
@@ -342,11 +325,11 @@ def logp_ddm_sdv(
        + (
            (a * z_flipped * sv) ** 2
            - 2 * a * v_flipped * z_flipped
-            - (v_flipped**2) * rt
+            - (v_flipped**2) * tt
        )
-        / (2 * (sv**2) * rt + 2)
-        - 0.5 * pt.log(sv**2 * rt + 1)
-        - 2 * pt.log(a),
+        / (2 * (sv**2) * tt + 2)
+        - 0.5 * pt.log(sv**2 * tt + 1)
+        - 2 * pt.log(pt.maximum(epsilon, a)),
Comment on lines 325 to +332:

Collaborator: Evaluate separately, providing a meaningful name.

Author: We are probably not going to keep this one. I just tried it to see whether keeping the log positive gets us somewhere. It seems to help a bit, but the culprit is not this one.
    )

    checked_logp = check_parameters(logp, a >= 0, msg="a >= 0")
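Following the earlier reviewer note that the same epsilon is reused for two different jobs (flooring small RTs, where roughly 1e-3 is appropriate, and guarding log(a), where a much smaller value suffices), a hypothetical split could look like this. Names and values are illustrative, not from this PR:

    import pytensor.tensor as pt

    EPS_RT = 1e-3   # floor for rt - t, per the review suggestion
    EPS_A = 1e-15   # guard for log(a)

    def floor_rt(rt):
        # Blend form rather than pt.switch, matching the PR author's preference.
        is_small = pt.lt(rt, EPS_RT)
        return is_small * EPS_RT + (1 - is_small) * rt

    def safe_log_a(a):
        return pt.log(pt.maximum(EPS_A, a))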
72 changes: 42 additions & 30 deletions tests/test_likelihoods.py
@@ -4,52 +4,28 @@
old implementation of WFPT from (https://github.com/hddm-devs/hddm)
"""

import math
from pathlib import Path
from itertools import product

import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
import pytest
from numpy.random import rand

from pytensor.compile.nanguardmode import NanGuardMode

import hssm

# pylint: disable=C0413
-from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv
+from hssm.likelihoods.analytical import logp_ddm, logp_ddm_sdv
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
from hssm.distribution_utils import make_likelihood_callable

hssm.set_floatX("float32")


-def test_kterm(data_ddm):
-    """This function defines a range of kterms and tests results to
-    makes sure they are not equal to infinity or unknown values.
-    """
-    for k_term in range(7, 12):
-        v = (rand() - 0.5) * 1.5
-        sv = 0
-        a = (1.5 + rand()) / 2
-        z = 0.5 * rand()
-        t = rand() * 0.5
-        err = 1e-7
-        logp = logp_ddm_sdv(data_ddm, v, a, z, t, sv, err, k_terms=k_term)
-        logp = sum(logp.eval())
-        assert not math.isinf(logp)
-        assert not math.isnan(logp)
-
-
-def test_compare_k(data_ddm):
-    """This function tests output of decision function."""
-    err = 1e-7
-    data = data_ddm["rt"] * data_ddm["response"]
-    lambda_rt = compare_k(np.abs(data.values), err)
-    assert all(not v for v in lambda_rt.eval())
-    assert data_ddm.shape[0] == lambda_rt.eval().shape[0]


# def test_logp(data_fixture):
# """
# This function compares new and old implementation of logp calculation
@@ -128,13 +104,49 @@ def test_bbox(data_ddm):
)


-cav_data = hssm.load_data("cavanagh_theta")
+cav_data: pd.DataFrame = hssm.load_data("cavanagh_theta")
+cav_data_numpy = cav_data[["rt", "response"]].values
+param_matrix = product(
+    (0.0, 0.01, 0.05, 0.5), ("analytical", "approx_differentiable", "blackbox")
+)


+def test_analytical_gradient():
+    v = pt.dvector()
+    a = pt.dvector()
+    z = pt.dvector()
+    t = pt.dvector()
+    sv = pt.dvector()
+    size = cav_data_numpy.shape[0]
+    logp = logp_ddm(cav_data_numpy, v, a, z, t).sum()
+    grad = pt.grad(logp, wrt=[v, a, z, t])
+    grad_func = pytensor.function(
+        [v, a, z, t],
+        grad,
+        mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=False),
+    )
+    v_test = np.random.normal(size=size)
+    a_test = np.random.uniform(0.0001, 2, size=size)
+    z_test = np.random.uniform(0.1, 1.0, size=size)
+    t_test = np.random.uniform(0, 2, size=size)
+    sv_test = np.random.uniform(0.001, 1.0, size=size)
+    grad = np.array(grad_func(v_test, a_test, z_test, t_test))
+
+    assert np.all(np.isfinite(grad), axis=None), "Gradient contains non-finite values."
+
+    grad_func_sdv = pytensor.function(
+        [v, a, z, t, sv],
+        pt.grad(logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t]),
+        mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=False),
+    )
+
+    grad_sdv = np.array(grad_func_sdv(v_test, a_test, z_test, t_test, sv_test))
+
+    assert np.all(
+        np.isfinite(grad_sdv), axis=None
+    ), "Gradient contains non-finite values."


+@pytest.mark.parametrize("p_outlier, loglik_kind", param_matrix)
+def test_lapse_distribution_cav(p_outlier, loglik_kind):
+    true_values = (0.5, 1.5, 0.5, 0.5)