Skip to content

Commit

Permalink
Merge branch 'beta' into beta
Browse files Browse the repository at this point in the history
  • Loading branch information
Dohyun-s authored Mar 1, 2023
2 parents 525a481 + b0425cd commit a339459
Show file tree
Hide file tree
Showing 12 changed files with 906 additions and 518 deletions.
4 changes: 4 additions & 0 deletions AlphaFold2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@
" # install dependencies\n",
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold@beta\"\n",
" pip uninstall -yq jax jaxlib\n",
" pip install -q \"jax[cuda]==0.3.25\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
" # for debugging\n",
" ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold\n",
Expand Down Expand Up @@ -270,6 +272,7 @@
"num_recycles = \"auto\" #@param [\"auto\", \"0\", \"1\", \"3\", \"6\", \"12\", \"24\", \"48\"]\n",
"recycle_early_stop_tolerance = \"auto\" #@param [\"auto\", \"0.0\", \"0.5\", \"1.0\"]\n",
"#@markdown - if `auto` will use `num_recycles=20 tol=0.5` for `model_type=alphafold2_multimer_v3`, else `num_recyles=3 tol=0.5`\n",
"set_cyclic_offset = False #@param {type:\"boolean\"}\n",
"\n",
"#@markdown ### Sample settings\n",
"max_msa = \"auto\" #@param [\"auto\", \"512:1024\", \"256:512\", \"64:128\", \"32:64\", \"16:32\"]\n",
Expand Down Expand Up @@ -382,6 +385,7 @@
" inputs_callback=inputs_callback,\n",
" outputs_callback=outputs_callback,\n",
" save_recycles=save_recycles,\n",
" cyclic=set_cyclic_offset\n",
")\n",
"results_zip = f\"{jobname}.result.zip\"\n",
"os.system(f\"zip -r {results_zip} {jobname}\")\n",
Expand Down
6 changes: 3 additions & 3 deletions AlphaFold2_batch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"id": "G4yBrceuFbf3"
},
"source": [
"#ColabFold: AlphaFold2 w/ MMseqs2 BATCH\n",
"#ColabFold v1.6.0: AlphaFold2 w/ MMseqs2 BATCH\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"256\" align=\"right\" style=\"height:256px\">\n",
"\n",
Expand Down Expand Up @@ -120,8 +120,8 @@
" # install dependencies\n",
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold@beta\"\n",
" # high risk high gain\n",
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
" pip uninstall -yq jax jaxlib\n",
" pip install -q \"jax[cuda]==0.3.25\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
" touch COLABFOLD_READY\n",
"fi\n",
"\n",
Expand Down
16 changes: 14 additions & 2 deletions colabfold/alphafold/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from alphafold.model.modules import AlphaFold
from alphafold.model.modules_multimer import AlphaFold as AlphaFoldMultimer


def load_models_and_params(
num_models: int,
use_templates: bool,
num_recycles: Optional[int] = None,
recycle_early_stop_tolerance: Optional[float] = None,
num_ensemble: int = 1,
model_order: Optional[List[int]] = None,
model_suffix: str = "_ptm",
data_dir: Path = Path("."),
stop_at_score: float = 100,
Expand All @@ -23,6 +23,7 @@ def load_models_and_params(
use_fuse: bool = True,
use_bfloat16: bool = True,
use_dropout: bool = False,
use_masking: bool = True,
save_all: bool = False,
) -> List[Tuple[str, model.RunModel, haiku.Params]]:
"""We use only two actual models and swap the parameters to avoid recompiling.
Expand All @@ -34,7 +35,11 @@ def load_models_and_params(
# Use only two model and later swap params to avoid recompiling
model_runner_and_params: [Tuple[str, model.RunModel, haiku.Params]] = []

model_order = [1, 2, 3, 4, 5]
if model_order is None:
model_order = [1, 2, 3, 4, 5]
else:
model_order.sort()

model_build_order = [3, 4, 5, 1, 2]
if "multimer" in model_suffix:
models_need_compilation = [3]
Expand Down Expand Up @@ -77,6 +82,13 @@ def load_models_and_params(
model_config.model.embeddings_and_evoformer.num_extra_msa = max_extra_seq
else:
model_config.data.common.max_extra_msa = max_extra_seq

# disable masking
if not use_masking:
if "multimer" in model_suffix:
model_config.model.embeddings_and_evoformer.masked_msa.replace_fraction = 0.0
else:
model_config.data.eval.masked_msa_replace_fraction = 0.0

# disable some outputs if not being saved
if not save_all:
Expand Down
22 changes: 16 additions & 6 deletions colabfold/batch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os,sys
ENV = {"TF_FORCE_UNIFIED_MEMORY":"1", "XLA_PYTHON_CLIENT_MEM_FRACTION":"4.0"}
for k,v in ENV.items():
if k not in os.environ: os.environ[k] = v
Expand All @@ -13,7 +13,7 @@
from pathlib import Path
import random

from colabfold.run_alphafold import run
from colabfold.run_alphafold import run, set_model_type
from colabfold.utils import (
DEFAULT_API_SERVER, ACCEPT_DEFAULT_TERMS,
get_commit, setup_logging
Expand Down Expand Up @@ -136,7 +136,7 @@ def main():
choices=["auto", "plddt", "ptm", "iptm", "multimer"],
)
parser.add_argument("--pair-mode",
help="rank models by auto, unpaired, paired, unpaired_paired",
help="how to generate MSA for multimeric inputs: unpaired, paired, unpaired_paired",
type=str,
default="unpaired_paired",
choices=["unpaired", "paired", "unpaired_paired"],
Expand All @@ -157,12 +157,16 @@ def main():
action="store_true",
help="saves the pair representation embeddings of all models",
)
parser.add_argument(
"--use-dropout",
parser.add_argument("--use-dropout",
default=False,
action="store_true",
help="activate dropouts during inference to sample from uncertainity of the models",
)
parser.add_argument("--disable-masking",
default=False,
action="store_true",
help='by default, 15% of the input MSA is randomly masked, set this flag to disable this',
)
parser.add_argument("--max-seq",
help="number of sequence clusters to use",
type=int,
Expand Down Expand Up @@ -203,6 +207,9 @@ def main():
parser.add_argument("--interaction-scan", default=False, action="store_true")
parser.add_argument("--disable-cluster-profile", default=False, action="store_true")

parser.add_argument("--cyclic", default=False, action="store_true")
parser.add_argument("--save-best", default=False, action="store_true")

# backward compatability
parser.add_argument('--training', default=False, action="store_true", help=argparse.SUPPRESS)
parser.add_argument('--templates', default=False, action="store_true", help=argparse.SUPPRESS)
Expand Down Expand Up @@ -283,12 +290,15 @@ def main():
save_single_representations=args.save_single_representations,
save_pair_representations=args.save_pair_representations,
use_dropout=args.use_dropout,
use_masking=not args.disable_masking,
max_seq=args.max_seq,
max_extra_seq=args.max_extra_seq,
use_cluster_profile=not args.disable_cluster_profile,
use_gpu_relax = args.use_gpu_relax,
use_gpu_relax=args.use_gpu_relax,
save_all=args.save_all,
save_recycles=args.save_recycles,
cyclic=args.cyclic,
save_best=args.save_best,
)

if args.interaction_scan:
Expand Down
Loading

0 comments on commit a339459

Please sign in to comment.