
Fix nan gradients in analytical likelihood #468

Closed

Conversation

digicosmos86 (Collaborator)

No description provided.

@digicosmos86 linked an issue on Jun 20, 2024 that may be closed by this pull request.
@AlexanderFengler (Collaborator) left a comment:

Looks good; this is mostly about iterating conceptually, not about code quality.

src/hssm/likelihoods/analytical.py (two resolved review threads, collapsed)
LOGP_LB,
tt = negative_rt * epsilon + (1 - negative_rt) * rt

p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
Collaborator:

Quick note: it seems like we are only passing k_terms here, not actually computing it.
I think we had agreed to do that way back, on another iteration of trying to fix issues with this likelihood, and I think it's fine, but in this case we should make the default a bit higher than 7.

Collaborator (Author):

Just playing around here; not actually changing anything.

- (v_flipped**2 * rt / 2.0)
- 2.0 * pt.log(a),
- (v_flipped**2 * tt / 2.0)
- 2.0 * pt.log(pt.maximum(epsilon, a))
Collaborator:

Reflecting on this a bit, I think this maximum business is actually corrupting the gradients, so we should just restrict a > epsilon a priori (via the prior, essentially?).
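A minimal sketch of that idea, i.e. keeping a strictly above a small epsilon through the prior rather than clamping inside the likelihood. This uses plain PyMC with illustrative values and is not HSSM's actual configuration API:

```python
import pymc as pm

A_EPSILON = 1e-3  # illustrative lower bound, not a value from the PR

with pm.Model():
    # Boundary separation `a` stays strictly above epsilon by construction,
    # so the likelihood no longer needs pt.maximum(epsilon, a) and its gradient.
    a = pm.TruncatedNormal("a", mu=1.5, sigma=0.5, lower=A_EPSILON)
```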

Collaborator:

On the other hand, apart from initialization (which (1) our strategies should already avoid and (2) we can generally influence), a should basically never come close to 0, so this should basically never be the culprit...

Collaborator (Author):

But this did help a bit, for some reason...

- (v_flipped**2 * rt / 2.0)
- 2.0 * pt.log(a),
- (v_flipped**2 * tt / 2.0)
- 2.0 * pt.log(pt.maximum(epsilon, a))
)

checked_logp = check_parameters(logp, a >= 0, msg="a >= 0")
Collaborator:

In the spirit of the above, this check could be a > 0, but honestly we shouldn't ever really get there.

Collaborator (Author):

Same as above

@@ -220,7 +199,7 @@ def logp_ddm(
z: float,
t: float,
err: float = 1e-15,
k_terms: int = 20,
k_terms: int = 7,
epsilon: float = 1e-15,
Collaborator:

I don't know what was used for testing or what is used as the actual value for inference, but I guess it is this default?

The epsilon for the rt part should rather be on the order of 1e-3, or even 1e-2.

If we are reusing the same epsilon in multiple places, we should probably separate it out.

Collaborator (Author):

Was playing around. It seems that changing k_terms to 7 did not improve speed or computation

@@ -262,15 +241,17 @@ def logp_ddm(
z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response
rt = rt - t

p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
negative_rt = rt <= epsilon
Collaborator:

OK, reflecting on this a bit, the logic that we want should probably look something like:

  • flag all rts lower than epsilon
  • go through with ftt01w
  • then set all flagged rts to LOGP_LB

This should actually cut the gradient for problematic rts (see the sketch below).
Potentially we put this as a logp_ddm_2 and compare results / gradients.
Alternatively, if any rt breaches epsilon, directly send logp to -infty (this is probably not preferable).
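A minimal PyTensor sketch of the flag-then-overwrite logic proposed above. The threshold and LOGP_LB values are illustrative, and this is not the code that ended up in the PR:

```python
import pytensor.tensor as pt

LOGP_LB = -66.1   # illustrative lower bound on the log-density (the module defines its own)
EPSILON = 1e-3    # illustrative threshold for "too small" rts

def mask_small_rts(rt, logp):
    """Overwrite the log-density of flagged rts after the full computation."""
    too_small = pt.lt(rt, EPSILON)       # 1. flag all rts lower than epsilon
    # 2. `logp` is assumed to be the result of running ftt01w and the rest of
    #    the density on all rts as usual;
    # 3. flagged entries are then replaced, with the intent of cutting their
    #    gradient contribution.
    return pt.where(too_small, LOGP_LB, logp)
```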

Collaborator (Author):

We were doing this. I think the problem is that the gradient is computed anyway, and the over/underflow was still happening.

@AlexanderFengler (Collaborator):

@digicosmos86 is this stale for now?

@digicosmos86 (Collaborator, Author):

> @digicosmos86 is this stale for now?

There doesn't seem to be a solution for really small RTs in the denominator, which can blow up.

@frankmj (Collaborator) commented Jul 9, 2024 via email.

@digicosmos86 (Collaborator, Author):

@frankmj I ran a few more tests, and the RT-hack did do the trick. It might be hard for us to implement this trick in our code, though, mostly because people use arviz functions instead of the convenience functions we provide, which would give us some control over the output. We could note this trick somewhere in our documentation so that users can implement it themselves and have full control.

@frankmj (Collaborator) commented Jul 9, 2024 via email.

@digicosmos86 (Collaborator, Author):

@frankmj That's a great idea! I also noticed that the RT-hack only worked when float64 was used, which points to some other numerical-stability issues we might have. I'll look deeper into this.

@cpaniaguam (Collaborator) left a comment:

A few things here I might end up picking up myself.

src/hssm/likelihoods/analytical.py (outdated review thread, resolved and collapsed)
Comment on lines +36 to +38
_a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1
_b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err))
_c = pt.sqrt(rt) + 1
Collaborator:

The fundamental operation is pt.sqrt(rt). It's better to do this first and reuse the result to avoid computing it again.

Collaborator:

For numerical stability, it's better to group the constant factor C = 2 * pt.sqrt(2 * np.pi) * err and compare each member of sqrt_rt = pt.sqrt(rt) against 1/C.
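A sketch combining this and the previous suggestion (compute pt.sqrt(rt) once, group the constant factor). The names sqrt_rt, C, mask, and ks_terms are ours, and this is not the code as merged:

```python
import numpy as np
import pytensor.tensor as pt

def ks_terms(rt, err):
    sqrt_rt = pt.sqrt(rt)                  # computed once and reused below
    C = 2 * np.sqrt(2 * np.pi) * err       # constant factor, grouped outside the graph
    mask = sqrt_rt < 1 / C                 # same condition as 2*sqrt(2*pi*rt)*err < 1
    ks_bound = 2 + pt.sqrt(-2 * rt * pt.log(C * sqrt_rt))
    ks_floor = sqrt_rt + 1
    # elementwise maximum replaces pt.max(pt.stack([...]), axis=0)
    return mask, pt.maximum(ks_bound, ks_floor)
```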

Collaborator (Author):

Sure! Feel free to change this

ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err))
ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0)
ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2)
_a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1
Collaborator:

What would a better name for this boolean array be, maybe mask or sieve?

Collaborator:

Should pt.lt be used here as done elsewhere in this PR?

Collaborator (Author):

It's actually equivalent but I was just playing around

_b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err))
_c = pt.sqrt(rt) + 1
_d = pt.max(pt.stack([_b, _c]), axis=0)
ks = _a * _d + (1 - _a) * 2
Collaborator:

Because _a is boolean, I think it's better to treat it as such and use pt.switch.

Suggested change
ks = _a * _d + (1 - _a) * 2
ks = pt.switch(mask, _d, 2) # having renamed `_a` to `mask`, for example

Collaborator (Author):

Please see comment below

_b = 1.0 / (np.pi * pt.sqrt(rt))
_c = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt))
_d = pt.max(pt.stack([_b, _c]), axis=0)
kl = _a * _b + (1 - _a) * _b
Collaborator:

_c and _d are not used. Should _d be used in the second term instead of _b? Otherwise kl will be _b.

Suggested change
kl = _a * _b + (1 - _a) * _b
kl = pt.switch(mask, _b, _d)

Collaborator (Author):

Please see comment below

src/hssm/likelihoods/analytical.py (three more outdated review threads, resolved and collapsed)
logp = pt.where(
rt <= epsilon,
LOGP_LB,
tt = negative_rt * epsilon + (1 - negative_rt) * rt
Collaborator:

Suggested change
tt = negative_rt * epsilon + (1 - negative_rt) * rt
tt = pt.switch(negative_rt, epsilon, rt)

Collaborator (Author):

This is actually done on purpose; pt.switch can cause some weird errors.
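For context, a small sketch of the two forms being compared here. Both produce the same values, but the arithmetic blend avoids emitting a switch Op, which is what the author later reports causing float64-only switch_sink errors. The variable setup is illustrative:

```python
import pytensor.tensor as pt

rt = pt.vector("rt")
epsilon = 1e-3  # illustrative threshold

negative_rt = pt.le(rt, epsilon)                            # boolean mask, as in the diff above
tt_arith = negative_rt * epsilon + (1 - negative_rt) * rt   # arithmetic blend kept in the PR
tt_switch = pt.switch(negative_rt, epsilon, rt)             # same values, but emits a switch Op
```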

Comment on lines 324 to +331
+ (
(a * z_flipped * sv) ** 2
- 2 * a * v_flipped * z_flipped
- (v_flipped**2) * rt
- (v_flipped**2) * tt
)
/ (2 * (sv**2) * rt + 2)
- 0.5 * pt.log(sv**2 * rt + 1)
- 2 * pt.log(a),
/ (2 * (sv**2) * tt + 2)
- 0.5 * pt.log(sv**2 * tt + 1)
- 2 * pt.log(pt.maximum(epsilon, a)),
Collaborator:

Evaluate this separately, providing a meaningful name.

Collaborator (Author):

We are probably not going to keep this one. I just tried it to see whether keeping the argument of the log positive gets us anywhere. It seems to help a bit, but the culprit is not this one.

@digicosmos86 (Collaborator, Author):

@cpaniaguam Thanks for the suggestions! I committed all of them except those involving pt.switch. I thought pt.switch was a more readable alternative, but it caused some switch_sink errors that show up only when float64 is used. Actually, removing switch Ops allowed me to sample with float64 without errors.

Please feel free to take this further. This PR wasn't final; it was just a placeholder for some of my experiments.

@AlexanderFengler (Collaborator):

@digicosmos86 let's use this PR to switch to float64 overall?

Also, the latest state of affairs with all the changes in this PR is that it's still breaking, right?

@digicosmos86 (Collaborator, Author):

You are correct; it is still broken. This PR is kind of my mess, though. I'd rather start a new one and just switch out all the switch Ops, which should get us over the float64 issue.

@AlexanderFengler (Collaborator):

@digicosmos86 I am good with that approach.

@digicosmos86 (Collaborator, Author):

Since this is still in the works, I am going to convert it to a draft PR

@digicosmos86 marked this pull request as draft on August 21, 2024, 14:30.
@AlexanderFengler (Collaborator):

@digicosmos86 to be closed now that the other PR is up?

Successfully merging this pull request may close these issues.

nan grads when running find_MAP() on analytic, ddm