-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add alignment CRF test. Fix missing fill_() #109
base: master
Are you sure you want to change the base?
Conversation
I realised I hadn't tested many of the distribution properties. I've tried tests for few more but it looks like there are at least two more issues to resolve. |
torch_struct/alignment.py
Outdated
charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_( | ||
charta[1][:, b, point:, 1, ind, :, :, Mid] | ||
) | ||
charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately this is not going to work.
We need to call
init = torch.zeros(charta[1].shape).bool()
init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
charta[1] = semiring.fill(charta[1], init, semiring.one)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(this should fix your other issues too)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, thanks for that. I have to admit I had just copied the code from the one_()
method before it was removed in #105. My assumption was that it was the correct code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still facing a few issues though. I fixed a few of them in the commits below but some remain. The main sticking point seems to be that the BandedMatrix
s are not correctly dispatched to multiply
rather than matmul()
in semirings.py
. The matmul
implementation only works for standard tensors. This affects dist.entropy, dist.sample(), dist.topk()
but not the partition, argmax, marginals
.
I tried to fix this rather naively by overloading the classmethod matmul
in some of the semirings but this broke the existing tests. I backed that out and am trying to understand how the code relates to the description in the torch struct paper so that I can make the correct fix.
PR following up discussion here.
For the tests to pass I also had to update
genbmm
. See PR here.Note that the tests only check the shape of the
argmax
andmarginals
. The values are not checked.