Added multi-chain permutation steps, multimer datamodule, and training code for multimer #336
Conversation
…optimal_transform
added multi-chain permutation to AlphaFoldMultimerLoss
Modify assignment stage
added openfold multimer dataloader class and overwrite batch processing
…nto a pytorch tensor
Added Multimer dataloader and training scripts
Initial review
README.md
Outdated
Revert this to the original OF readme
Sure thing. Done in commit b8d0069
openfold/config.py
Outdated
@@ -163,6 +163,9 @@ def model_config(
    for k,v in multimer_model_config_update.items():
Good catch. Can you change this so that the model components are in a 'model' dict, so that it matches the loss dict: multimer_model_config_update['model']?
Sure. I have added a key called 'model' inside multimer_model_config_update.
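For illustration, the resulting update dict would be shaped roughly like this (the keys inside are placeholders, not the actual OpenFold multimer settings):

```python
# Hypothetical sketch of the nesting requested above: the model-component
# overrides live under a "model" key so the update dict mirrors the
# existing "loss" section. The inner keys are placeholders.
multimer_model_config_update = {
    "model": {
        "is_multimer": True,      # placeholder model override
    },
    "loss": {
        "fape": {"weight": 0.5},  # placeholder loss override
    },
}

for k, v in multimer_model_config_update.items():
    print(k, v)
```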
openfold/data/data_pipeline.py
Outdated
@@ -830,6 +831,15 @@ def read_template(start, size):
    with open(path, "r") as fp:
        hits = parsers.parse_hhr(fp.read())
        all_hits[f] = hits
        fp.close()
fp is closed automatically when the with statement exits, so this can be removed.
Sure, it's removed.
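For reference, a minimal demonstration (reusing the names from the diff above) of why the explicit close was redundant:

```python
# The with statement closes the file when the block exits, even if
# parsers.parse_hhr raises, so no explicit fp.close() is needed.
with open(path, "r") as fp:
    hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits

assert fp.closed  # True once the with block has exited
```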
openfold/data/data_pipeline.py
Outdated
@@ -830,6 +831,15 @@ def read_template(start, size):
    with open(path, "r") as fp:
        hits = parsers.parse_hhr(fp.read())
        all_hits[f] = hits
        fp.close()

elif (ext =='.sto') and (f.startswith("pdb")):
The template file would be named hmm_output.sto, right? Is the startswith("pdb") in reference to something else?
Sorry, my mistake. I saw that there is a pdb70_hits.hhr file in the test_data and assumed the HMM results would be named the same way. I've changed the second condition so that it checks whether the file starts with 'hmm'.
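A sketch of the corrected dispatch, assuming the hmmsearch output is written to a file named like hmm_output.sto:

```python
import os

# Hypothetical file names; only hmm_output.sto should reach the second branch.
for f in ["pdb70_hits.hhr", "hmm_output.sto"]:
    ext = os.path.splitext(f)[-1]
    if ext == ".hhr":
        pass  # parse HHsearch hits (parsers.parse_hhr)
    elif ext == ".sto" and f.startswith("hmm"):
        pass  # parse hmmsearch template hits
```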
openfold/model/model.py
Outdated
@@ -535,7 +535,10 @@ def forward(self, batch):

    # Enable grad iff we're training and it's the final recycling layer
    is_final_iter = cycle_no == (num_iters - 1) or early_stop
    with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
    enable_grad = is_grad_enabled and is_final_iter
    if (type(enable_grad)!=bool) and (type(enable_grad)==torch.Tensor):
Both is_grad_enabled and is_final_iter should be of type bool. I'm going to check what is causing it to become a tensor, but this should be removed once that is fixed.
Sure, I have reverted this part to the original. I couldn't figure out how this happened: it was a boolean for the first couple of iterations, then at some point became a tensor wrapping the boolean value, which raised TypeError: enabled must be a bool (got Tensor).
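If the root cause can't be tracked down, one defensive workaround (a sketch, not what was merged) is to coerce the flag back to a plain bool before handing it to torch.set_grad_enabled:

```python
import torch

enable_grad = is_grad_enabled and is_final_iter  # names from the diff above
# torch.set_grad_enabled requires a plain bool; a 0-d tensor triggers
# "TypeError: enabled must be a bool (got Tensor)".
if isinstance(enable_grad, torch.Tensor):
    enable_grad = bool(enable_grad.item())
with torch.set_grad_enabled(enable_grad):
    ...
```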
For the test data alignments, can we compress them and include fewer examples in general?
openfold/utils/loss.py
Outdated
@@ -298,10 +310,10 @@ def fape_loss(
    interface_bb_loss = backbone_loss(
        traj=traj,
        pair_mask=1. - intra_chain_mask,
        **{**batch, **config.interface_backbone},
        **{**batch, **config.intra_chain_backbone},
This is the interface backbone loss; why is it using the intra_chain_backbone config?
I see, you're right, but in config.py there is no key called "interface_backbone" in the loss config dict. I suppose it should be "interface" at line 850 of config.py? I have changed this part of loss.py to use config.interface.
Lines 850 to 853 in bc35ef1
"interface": { | |
"clamp_distance": 30.0, | |
"loss_unit_distance": 20.0, | |
"weight": 0.5, |
…hain backbone weights
…t be a bool (got Tensor)' error still persists
…the new permutation unittest
chains = asym_id.unique()
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64)-1, # have to reduce asym_id by one because class values must be smaller than num_classes
Sorry, I kept my modifications when I resolved this conflict. Since asym_id starts from 1, we should subtract 1 so that the class values are always smaller than the number of classes; otherwise, PyTorch throws an error.
@christinaflo
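A self-contained illustration of the off-by-one behavior (the chain ids here are made up):

```python
import torch
import torch.nn.functional as F

asym_id = torch.tensor([1, 1, 2, 2, 3])  # hypothetical: chain ids start at 1
chains = asym_id.unique()                # tensor([1, 2, 3]) -> 3 classes

# F.one_hot(asym_id, num_classes=3) would raise, because the class value 3
# is not smaller than num_classes. Shifting the ids to start at 0 fixes it.
one_hot = F.one_hot(asym_id.to(torch.int64) - 1, num_classes=chains.shape[0])
print(one_hot.shape)  # torch.Size([5, 3])
```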
Oh yeah, I was going to merge your changes in with this PR; I only wanted to remove return_counts=True because I wasn't using the returned counts.
@@ -529,9 +541,9 @@ def lddt_loss(
    cutoff=cutoff,
    eps=eps
)

score = torch.nan_to_num(score,nan=torch.nanmean(score))
Here I added this NaN check on the predicted lDDT scores because the ground-truth structure I used for the initial unit test was completely unrelated to the fake features I generated. As a result, I got NaN or negative values here. It would probably be better to remove this part in the real training code.
@christinaflo
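In isolation, the added line behaves like this (synthetic scores, just to show the replacement):

```python
import torch

score = torch.tensor([0.8, float("nan"), 0.6])
# Replace NaNs with the mean over the non-NaN entries: (0.8 + 0.6) / 2 = 0.7
score = torch.nan_to_num(score, nan=torch.nanmean(score))
print(score)  # tensor([0.8000, 0.7000, 0.6000])
```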
Yeah, we should remove this. I was going to reformat some things after merging this PR in, so I can just remove it then.
The CUDA error: device-side assert triggered error still persists when slicing a tensor.
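For what it's worth, device-side asserts are reported asynchronously, so the Python line that appears to fail is often not the one that triggered the assert. Two standard ways to localize it (general PyTorch debugging practice, not specific to this PR):

```python
import os

# 1. Make kernel launches synchronous so the traceback points at the real
#    offending op. Must be set before CUDA is first initialized.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# 2. Reproduce on CPU, where out-of-range indexing (e.g. an asym_id that
#    was never shifted to start at 0) raises a plain IndexError instead.
# batch = {k: v.cpu() for k, v in batch.items()}
```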