Question on the Complexity of CKY #99
So, just to unify terminology: the way we compute the linear-chain CRF in O(log N) time is by viewing it as a single balanced binary tree of height O(log N). For CKY, we need to at least consider all trees of height O(N), since we may have fully right-branching trees. As far as I know, there is no way to do better than O(N) serial operations, one for each layer (i.e. width of span). The other nasty thing about CKY is that the shape of the operations changes drastically as you go up the tree. At the bottom layer there are N spans, each with 1 possible child size, whereas at the top layer there is 1 span with N possible child sizes. This is an extremely non-GPU-friendly operation. (Still haven't figured out the ideal way to compute it.)
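For concreteness, here is a minimal pure-Python sketch of the CKY inside pass in the max semiring (my own illustration with a hypothetical `span_score` function, not pytorch-struct's actual implementation). It makes the layer structure explicit: all spans of one width can be computed in parallel, but each width depends on every smaller width, giving the O(N) serial layers described above.

```python
def cky_inside_max(span_score, n):
    # chart[i][j] = best score of any tree covering words i..j (inclusive)
    chart = [[None] * n for _ in range(n)]
    for i in range(n):
        chart[i][i] = span_score(i, i)        # width-1 layer: N spans, no splits
    for width in range(1, n):                 # one *serial* step per layer
        for i in range(n - width):            # spans in a layer are independent
            j = i + width
            # a span of width w+1 has w possible split points k
            chart[i][j] = span_score(i, j) + max(
                chart[i][k] + chart[k + 1][j] for k in range(i, j)
            )
    return chart[0][n - 1]
```

Note how the inner loops shrink as `width` grows: many small max-reductions at the bottom, one large one at the top, which is the shape mismatch that makes the computation GPU-unfriendly.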
Thanks for the clarification; now I understand the situation. One more question regarding the decoding procedure for the linear-chain CRF.
My question: is it true that we can't avoid the O(N) cost of back-tracking? For example, we use an O(log N) algorithm to obtain the matrix with size
You can do O(log N) [parallel] backtracking. The trick in the code is that we never implement the backward / backpointer step. We rely on the fact that in PyTorch the (sub)gradient used for the max operator is the 1-hot argmax vector. Therefore, if you compute the max score, then calling .backward will give you the argmax / Viterbi sequence. If you look at the code here (https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/distributions.py#L123) you will see that this is exactly what it does. (This trick is really cool, and should be documented much better in the codebase.)
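A toy snippet (my own illustration, not the library's code) showing the core of this trick: autograd's subgradient for `max` places all the gradient mass on the argmax entry, so `.backward()` on the max score returns a one-hot indicator of the best choice with no explicit backpointers.

```python
import torch

# hypothetical scores for a 3-way choice
scores = torch.tensor([1.0, 5.0, 3.0], requires_grad=True)
best = scores.max()   # forward pass: just compute the max score
best.backward()       # backward pass: subgradient is one-hot at the argmax
print(scores.grad)    # tensor([0., 1., 0.])
```

In the structured case, the same mechanism applied to the max-semiring partition computation yields indicator variables for every edge on the Viterbi path.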
Thanks. That helps a lot. I think I need some more time to figure out the details. What I did on my side at the moment is to implement the complete parallel-scan algorithm, for both inference and backtracking (https://github.com/allanj/pytorch_lstmcrf/blob/40f8980dde/src/model/module/fast_linear_crf_inferencer.py). Maybe this is stupid :(, but it just helps me better understand the procedure. Right now, I think in my back-tracking code the memory could be larger, because during the parallel-scan "forward/backward" pass over the tree, I store an N-length vector to represent the best sequence at each node.
Neat! Yeah, this should result in the same outcome. My assertion, though, is that you do not need to write the backward pass manually. Doing so might lead to some speed-ups in practice.
Btw, your repo looks really nice. If you wanted to build some more transformer-backed structured prediction models, I would be happy to collaborate. It would be nice to have a single repo with a BERT tagger / NER / CKY, all backed by Hugging Face and HF datasets.
Sure. That would be great. One of my goals is exactly to build these models on top of the current HF backend. CKY is something that I'm really looking forward to. Another thing is the general hypergraph framework (from a research perspective), though I think it is still pretty challenging to implement the general framework in PyTorch in practice. Thus, I came to learn from this repo as well.
Neat! Me too. General hypergraphs would be really interesting, but I think we would need to write that in CUDA or TVM manually. The one thing I care a lot about is testing: I want to be sure that when people implement a CRF they are really computing the partition function.
So, if I view the linear-chain CRF as a specific case of a tree, where the height H = sequence length N, then the complexity can be re-written as O(log H).