-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix nan gradients in analytical likelihood #468
Changes from 3 commits
6248a4d
5e15bbe
9d16604
dcf7709
7d0bd89
d61f247
4774a55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -9,7 +9,6 @@ | |||||
|
||||||
import numpy as np | ||||||
import pymc as pm | ||||||
import pytensor | ||||||
import pytensor.tensor as pt | ||||||
from numpy import inf | ||||||
from pymc.distributions.dist_math import check_parameters | ||||||
|
@@ -25,7 +24,7 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray: | |||||
Parameters | ||||||
---------- | ||||||
rt | ||||||
A 1D numpy array of flipped R.... T.....s. (0, inf). | ||||||
A 1D numpy array of flipped R.... pt.....s. (0, inf). | ||||||
err | ||||||
Error bound. | ||||||
|
||||||
|
@@ -34,9 +33,11 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray: | |||||
np.ndarray | ||||||
A 1D at array of k_small. | ||||||
""" | ||||||
ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err)) | ||||||
ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0) | ||||||
ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2) | ||||||
_a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1 | ||||||
_b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err)) | ||||||
_c = pt.sqrt(rt) + 1 | ||||||
Comment on lines
+36
to
+38
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fundamental operation is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For numerical stability, it's better to group the constant factor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! Feel free to change this |
||||||
_d = pt.max(pt.stack([_b, _c]), axis=0) | ||||||
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
ks = _a * _d + (1 - _a) * 2 | ||||||
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because _a is boolean, I think it's better to treat it as such and use
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please see comment below |
||||||
|
||||||
return ks | ||||||
|
||||||
|
@@ -56,9 +57,11 @@ def k_large(rt: np.ndarray, err: float) -> np.ndarray: | |||||
np.ndarray | ||||||
A 1D at array of k_large. | ||||||
""" | ||||||
kl = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt)) | ||||||
kl = pt.max(pt.stack([kl, 1.0 / (np.pi * pt.sqrt(rt))]), axis=0) | ||||||
kl = pt.switch(np.pi * rt * err < 1, kl, 1.0 / (np.pi * pt.sqrt(rt))) | ||||||
_a = np.pi * rt * err < 1 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
_b = 1.0 / (np.pi * pt.sqrt(rt)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like the |
||||||
_c = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt)) | ||||||
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
_d = pt.max(pt.stack([_b, _c]), axis=0) | ||||||
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
kl = _a * _d + (1 - _a) * _b | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please see comment below |
||||||
|
||||||
return kl | ||||||
|
||||||
|
@@ -81,34 +84,7 @@ def compare_k(rt: np.ndarray, err: float) -> np.ndarray: | |||||
ks = k_small(rt, err) | ||||||
kl = k_large(rt, err) | ||||||
|
||||||
return ks < kl | ||||||
|
||||||
|
||||||
def get_ks(k_terms: int, fast: bool) -> np.ndarray: | ||||||
"""Return an array of ks. | ||||||
|
||||||
Returns an array of ks given the number of terms needed to approximate the sum of | ||||||
the infinite series. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
k_terms | ||||||
number of terms needed | ||||||
fast | ||||||
whether the function is used in the fast or slow expansion. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
np.ndarray | ||||||
An array of ks. | ||||||
""" | ||||||
ks = ( | ||||||
pt.arange(-pt.floor((k_terms - 1) / 2), pt.ceil((k_terms - 1) / 2) + 1) | ||||||
if fast | ||||||
else pt.arange(1, k_terms + 1).reshape((-1, 1)) | ||||||
) | ||||||
|
||||||
return ks.astype(pytensor.config.floatX) | ||||||
return pt.lt(ks, kl) | ||||||
|
||||||
|
||||||
def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray: | ||||||
|
@@ -133,7 +109,10 @@ def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray: | |||||
""" | ||||||
# Slightly changed the original code to mimic the paper and | ||||||
# ensure correctness | ||||||
k = get_ks(k_terms, fast=True) | ||||||
k = pt.arange( | ||||||
-pt.floor((k_terms - 1) / 2.0), | ||||||
pt.ceil((k_terms - 1) / 2.0) + 1.0, | ||||||
) | ||||||
|
||||||
# A log-sum-exp trick is used here | ||||||
y = w + 2 * k.reshape((-1, 1)) | ||||||
|
@@ -166,7 +145,7 @@ def ftt01w_slow(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray: | |||||
np.ndarray | ||||||
The approximated function f(tt|0, 1, w). | ||||||
""" | ||||||
k = get_ks(k_terms, fast=False) | ||||||
k = pt.arange(1, k_terms + 1).reshape((-1, 1)) | ||||||
y = k * pt.sin(k * np.pi * w) | ||||||
r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2 | ||||||
p = pt.sum(y * pt.exp(r), axis=0) * np.pi | ||||||
|
@@ -208,7 +187,7 @@ def ftt01w( | |||||
p_fast = ftt01w_fast(tt, w, k_terms) | ||||||
p_slow = ftt01w_slow(tt, w, k_terms) | ||||||
|
||||||
p = pt.switch(lambda_rt, p_fast, p_slow) | ||||||
p = lambda_rt * p_fast + (1.0 - lambda_rt) * p_slow | ||||||
|
||||||
return p | ||||||
|
||||||
|
@@ -220,7 +199,7 @@ def logp_ddm( | |||||
z: float, | ||||||
t: float, | ||||||
err: float = 1e-15, | ||||||
k_terms: int = 20, | ||||||
k_terms: int = 7, | ||||||
epsilon: float = 1e-15, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know what was used for testing / is used as actual value for inference, but I guess it is this default? The epsilon for the If we are reusing the same epsilon in multiple places, we should probably separate it out. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was playing around. It seems that changing |
||||||
) -> np.ndarray: | ||||||
"""Compute analytical likelihood for the DDM model with `sv`. | ||||||
|
@@ -262,15 +241,17 @@ def logp_ddm( | |||||
z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response | ||||||
rt = rt - t | ||||||
|
||||||
p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) | ||||||
negative_rt = rt <= epsilon | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok reflecting on this a bit, the logic that we want should probably look something like:
This should actually cut the gradient for problematic rts. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We were doing this. I think the problem is that the gradient is computed anyway and the over/underflow was still happening
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
logp = pt.where( | ||||||
rt <= epsilon, | ||||||
LOGP_LB, | ||||||
tt = negative_rt * epsilon + (1 - negative_rt) * rt | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This actually is done on purpose. |
||||||
|
||||||
p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. quick note, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just playing around here. Not actually changing |
||||||
|
||||||
logp = negative_rt * LOGP_LB + (1 - negative_rt) * ( | ||||||
pt.log(p) | ||||||
- v_flipped * a * z_flipped | ||||||
- (v_flipped**2 * rt / 2.0) | ||||||
- 2.0 * pt.log(a), | ||||||
- (v_flipped**2 * tt / 2.0) | ||||||
- 2.0 * pt.log(pt.maximum(epsilon, a)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reflecting on this a bit, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. on the other hand, apart from initialization (which 1. our strategies should already avoid, 2. we generally can impact) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this did help a bit, for some reason... |
||||||
) | ||||||
|
||||||
checked_logp = check_parameters(logp, a >= 0, msg="a >= 0") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the spirit of above, this check could be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above |
||||||
|
@@ -333,7 +314,8 @@ def logp_ddm_sdv( | |||||
z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response | ||||||
rt = rt - t | ||||||
|
||||||
p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) | ||||||
tt = pt.switch(rt <= epsilon, epsilon, rt) | ||||||
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) | ||||||
|
||||||
logp = pt.switch( | ||||||
rt <= epsilon, | ||||||
|
@@ -342,11 +324,11 @@ def logp_ddm_sdv( | |||||
+ ( | ||||||
(a * z_flipped * sv) ** 2 | ||||||
- 2 * a * v_flipped * z_flipped | ||||||
- (v_flipped**2) * rt | ||||||
- (v_flipped**2) * tt | ||||||
) | ||||||
/ (2 * (sv**2) * rt + 2) | ||||||
- 0.5 * pt.log(sv**2 * rt + 1) | ||||||
- 2 * pt.log(a), | ||||||
/ (2 * (sv**2) * tt + 2) | ||||||
- 0.5 * pt.log(sv**2 * tt + 1) | ||||||
- 2 * pt.log(pt.maximum(epsilon, a)), | ||||||
Comment on lines
325
to
+332
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Evaluate separately providing a meaningful name. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are probably not going to keep this one. I just tried this to see if we keep the log positive we can get somewhere. It helps a bit it seems, but the culprit is not this one |
||||||
) | ||||||
|
||||||
checked_logp = check_parameters(logp, a >= 0, msg="a >= 0") | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would a better name for this boolean array be, maybe
mask
orsieve
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should
pt.lt
be used here as done elsewhere in this PR?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's actually equivalent but I was just playing around