Conditional and joint span probabilities in TreeCRF #107
What a great question... If there are two specific spans that you need, then your "dirty trick" is the right way to do it. If you want to do it efficiently for any pair of spans, then there are some fun auto-diff tricks you can use. If I remember correctly, the Hessian of the log-partition (with respect to the log-potentials) gives the covariance of every pair of spans, and the pairwise joint is then the covariance plus the product of the two marginals. I don't think this is currently implemented in the library, but it wouldn't be that hard to add. Maybe look at https://pytorch.org/docs/stable/generated/torch.autograd.functional.hessian.html
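As a sanity check on this identity, here is a toy sketch (not the pytorch-struct API; the four "structures" and their indicator encoding below are invented purely for illustration). For any model with score θ·x over binary part indicators, the Hessian of the log-partition gives covariances, and adding the outer product of the marginals recovers the pairwise joints:

```python
import torch
from torch.autograd.functional import hessian

# Hypothetical toy model: 4 possible structures over 3 parts (e.g. spans),
# each structure encoded as a binary indicator vector.
structures = torch.tensor([
    [1., 1., 0.],
    [1., 0., 1.],
    [0., 1., 1.],
    [1., 1., 1.],
])

def log_Z(theta):
    # log-partition: logsumexp of the scores of all structures
    return torch.logsumexp(structures @ theta, dim=0)

theta = torch.zeros(3, requires_grad=True)

# First derivative of log Z = marginals E[x_i]
marginals = torch.autograd.grad(log_Z(theta), theta)[0]

# Hessian of log Z = covariance matrix Cov(x_i, x_j);
# pairwise joints are E[x_i x_j] = Cov(x_i, x_j) + E[x_i] E[x_j]
cov = hessian(log_Z, theta.detach())
joint = cov + torch.outer(marginals, marginals)

# Brute-force check by enumerating all structures
probs = torch.softmax(structures @ theta.detach(), dim=0)
joint_bf = (structures.T * probs) @ structures
assert torch.allclose(joint, joint_bf, atol=1e-5)
```

One Hessian call yields all pairwise joints at once, instead of re-running inference once per conditioning span.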
Thank you :-). I am more interested in an efficient way for any pair of spans.
I think this is a nice reference for Bayes nets: https://dl.acm.org/doi/pdf/10.1145/765568.765570. Alternatively, you can think of a CRF as an exponential family, in which case the log-partition generates moments: https://www.cs.cmu.edu/~epxing/Class/10708-14/scribe_notes/scribe_note_lecture6.pdf. I can't find a nice reference that explains the Hessian, but you can derive it by differentiating the log-partition twice. If you are feeling brave, it is also in this paper.
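Concretely, the moment-generating property works out as follows: for binary part indicators $x$ and score $\theta^\top x$,

```latex
Z(\theta) = \sum_{x} e^{\theta^\top x}, \qquad
\frac{\partial \log Z}{\partial \theta_i} = \mathbb{E}[x_i]
  \quad \text{(the marginal of part } i\text{)}, \qquad
\frac{\partial^2 \log Z}{\partial \theta_i \, \partial \theta_j}
  = \mathbb{E}[x_i x_j] - \mathbb{E}[x_i]\,\mathbb{E}[x_j]
  = \mathrm{Cov}(x_i, x_j),
```

so each pairwise joint $\mathbb{E}[x_i x_j]$ is the corresponding Hessian entry plus the product of the two marginals.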
That's super interesting, thank you for your help! I am going to try to use this for my project, but would you like me to make an attempt at a PR to add it to this library as well? In that case, do you have any pointers on where to add it?

Some last questions :-) : am I right to deduce that the 3rd-order partial derivatives of the partition function (divided by Z) then give the joint probabilities of 3 edges (assuming the 2nd-order derivatives are differentiable)? Unless I'm making a mistake, this can be proved in the same way as the relation between the first- and second-order derivatives and the marginals and pairwise joints, respectively.
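That deduction checks out the same way: each derivative of the (un-logged) partition function pulls down one indicator,

```latex
\frac{1}{Z}\,
\frac{\partial^3 Z}{\partial \theta_i \,\partial \theta_j \,\partial \theta_k}
  = \frac{1}{Z} \sum_{x} x_i \, x_j \, x_k \, e^{\theta^\top x}
  = \mathbb{E}[x_i x_j x_k],
```

which is the joint probability that all three parts appear, since the $x$'s are 0/1 indicators and the product is 1 exactly when all three are present.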
Yes, would love a PR. I think you can do this in the StructDistribution class: https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/distributions.py. It should in theory work for all the distributions in the library.

The main thing I would suggest, though, is that everything in torch-struct needs to be general (not tree- or sequence-specific) and tested. The testing part is done by enumerating all trees up to a certain size and checking that the pairwise marginals are the same.

However, do not feel compelled. I would be happy just to hear whether this works for your use case.
I want to use the TreeCRF class to learn latent tree distributions over constituency trees for sentences. I noticed you can easily obtain the text-span marginals with `.marginals`. However, I am interested in computing more probabilities under the tree distribution, such as the conditional probability that one span occurs in the tree given that another one occurs, or the joint probability of two spans. Is there an easy way to compute these probabilities from the marginals, or using other torch-struct functionality?

A 'dirty' trick for the conditional probability could be to compute the marginals again with the potential of the span you want to condition on set to a very high value; the new marginals would then actually be conditional probabilities. But that requires running the parsing algorithm once per condition, which I would ideally like to avoid.