diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..9fe17bceb --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ \ No newline at end of file diff --git a/README.md b/README.md index ad8157923..d4e6c60c3 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ $DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB) mmcif_files/ # About 180,000 .cif files. obsolete.dat - small_fbd/ # ~ 17 GB (download: 9.6 GB) + small_bfd/ # ~ 17 GB (download: 9.6 GB) bfd-first_non_consensus_sequences.fasta uniclust30/ # ~ 86 GB (download: 24.9 GB) uniclust30_2018_08/ @@ -273,6 +273,10 @@ The contents of each output file are as follows: serve for a visualisation of domain packing confidence within the structure. +The pLDDT confidence measure is stored in the B-factor field of the output PDB +files (although unlike a B-factor, higher pLDDT is better, so care must be taken +when using for tasks such as molecular replacement). 
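+
+For example, the per-residue pLDDT can be read back from the B-factor column
+(columns 61-66 of each ATOM record) with a few lines of plain Python. This is
+a minimal sketch, assuming a single-chain prediction written to
+`unrelaxed_model_1.pdb`:
+
+```python
+plddt_per_residue = {}
+with open('unrelaxed_model_1.pdb') as f:
+    for line in f:
+        # Every atom of a residue carries the same pLDDT, so CA atoms suffice.
+        if line.startswith('ATOM') and line[12:16].strip() == 'CA':
+            res_id = int(line[22:26])              # residue sequence number
+            plddt_per_residue[res_id] = float(line[60:66])  # B-factor field
+```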
+ This code has been tested to match mean top-1 accuracy on a CASP14 test set with pLDDT ranking over 5 model predictions (some CASP targets were run with earlier versions of AlphaFold and some had manual interventions; see our forthcoming @@ -319,7 +323,7 @@ For genetics: For templates: * PDB: (downloaded 2020-05-14) -* PDB70: (downloaded 2020-05-13) +* PDB70: [2020-05-13](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200513.tar.gz) An alternative for templates is to use the latest PDB and PDB70, but pass the flag `--max_template_date=2020-05-14`, which restricts templates only to diff --git a/alphafold/common/protein.py b/alphafold/common/protein.py index 314d40160..2848f5bbc 100644 --- a/alphafold/common/protein.py +++ b/alphafold/common/protein.py @@ -194,7 +194,7 @@ def ideal_atom_mask(prot: Protein) -> np.ndarray: `Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function computes a mask according to heavy atoms - that should be present in the given seqence of amino acids. + that should be present in the given sequence of amino acids. Args: prot: `Protein` whose fields are `numpy.ndarray` objects. diff --git a/alphafold/data/pipeline_mod.py b/alphafold/data/pipeline_mod.py new file mode 100644 index 000000000..38fdfae68 --- /dev/null +++ b/alphafold/data/pipeline_mod.py @@ -0,0 +1,134 @@ +"""Modular version of alphafold.data.pipeline""" + +import os +from dataclasses import dataclass +from typing import Mapping, Optional, Sequence +from absl import logging +from alphafold.data import parsers +from alphafold.data import templates +from alphafold.data.tools.cli import * +from alphafold.data.pipeline import make_sequence_features +import numpy as np + +# Internal import (7716). 
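+
+# A minimal usage sketch (hypothetical paths; mirrors how run_alphafold.py
+# constructs the pipeline further down in this change):
+#
+#   data_pipeline = ModularDataPipeline(
+#       jackhmmer_binary_path='/usr/bin/jackhmmer',
+#       hhblits_binary_path='/usr/bin/hhblits',
+#       hhsearch_binary_path='/usr/bin/hhsearch',
+#       uniref90_database_path='/data/uniref90/uniref90.fasta',
+#       mgnify_database_path='/data/mgnify/mgy_clusters.fa',
+#       pdb70_database_path='/data/pdb70/pdb70',
+#       use_small_bfd=True,
+#       small_bfd_database_path='/data/small_bfd.fasta',
+#       mmcif_dir='/data/pdb_mmcif/mmcif_files',
+#       max_template_date='2020-05-14',
+#       max_hits=20,
+#       kalign_binary_path='/usr/bin/kalign')
+#   features = data_pipeline.process('query.fasta', msa_output_dir='msas/')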
+
+FeatureDict = Mapping[str, np.ndarray]
+
+
+@dataclass
+class ModularDataPipeline:
+  """Modular version of alphafold.data.pipeline.DataPipeline"""
+  jackhmmer_binary_path: str
+  hhblits_binary_path: str
+  hhsearch_binary_path: str
+  uniref90_database_path: str
+  mgnify_database_path: str
+  pdb70_database_path: str
+  use_small_bfd: bool
+
+  # for construction of TemplateHitFeaturizer, replacing
+  # template_featurizer: templates.TemplateHitFeaturizer
+  mmcif_dir: str
+  max_template_date: str
+  max_hits: int
+  kalign_binary_path: str
+  release_dates_path: Optional[str] = None
+  obsolete_pdbs_path: Optional[str] = None
+  strict_error_check: bool = False
+
+  mgnify_max_hits: int = 501
+  uniref_max_hits: int = 10000
+  bfd_database_path: Optional[str] = None
+  uniclust30_database_path: Optional[str] = None
+  small_bfd_database_path: Optional[str] = None
+
+  def jackhmmer_uniref90(self, input_fasta_path: str):
+    return jackhmmer(
+        input_fasta_path=input_fasta_path,
+        jackhmmer_binary_path=self.jackhmmer_binary_path,
+        database_path=self.uniref90_database_path,
+        fname='uniref90_hits.sto',
+        output_dir=self.msa_output_dir
+    )
+
+  def jackhmmer_mgnify(self, input_fasta_path: str):
+    return jackhmmer(
+        input_fasta_path=input_fasta_path,
+        jackhmmer_binary_path=self.jackhmmer_binary_path,
+        database_path=self.mgnify_database_path,
+        fname='mgnify.sto',
+        output_dir=self.msa_output_dir
+    )
+
+  def hhsearch_pdb70(self, jackhmmer_uniref90_hits_path):
+    return hhsearch_pdb70(
+        jackhmmer_uniref90_hits_path=jackhmmer_uniref90_hits_path,
+        hhsearch_binary_path=self.hhsearch_binary_path,
+        pdb70_database_path=self.pdb70_database_path,
+        uniref_max_hits=self.uniref_max_hits,
+        output_dir=self.msa_output_dir
+    )
+
+  def jackhmmer_small_bfd(self, input_fasta_path):
+    return jackhmmer(
+        input_fasta_path=input_fasta_path,
+        jackhmmer_binary_path=self.jackhmmer_binary_path,
+        database_path=self.small_bfd_database_path,
+        fname='small_bfd_hits.sto',
+        output_dir=self.msa_output_dir
+    )
+
+  def hhblits(self, input_fasta_path):
+    return hhblits(
+        input_fasta_path=input_fasta_path,
+        hhblits_binary_path=self.hhblits_binary_path,
+        bfd_database_path=self.bfd_database_path,
+        uniclust30_database_path=self.uniclust30_database_path,
+        output_dir=self.msa_output_dir
+    )
+
+  def template_featurize(self, input_fasta_path, hhsearch_hits_path):
+    return template_featurize(
+        input_fasta_path=input_fasta_path,
+        hhsearch_hits_path=hhsearch_hits_path,
+        mmcif_dir=self.mmcif_dir,
+        max_template_date=self.max_template_date,
+        max_hits=self.max_hits,
+        kalign_binary_path=self.kalign_binary_path,
+        release_dates_path=self.release_dates_path,
+        obsolete_pdbs_path=self.obsolete_pdbs_path,
+        strict_error_check=self.strict_error_check
+    )
+
+  def make_msa_features(self, jackhmmer_uniref90_hits_path,
+                        jackhmmer_mgnify_hits_path, bfd_hits_path):
+    return make_msa_features(jackhmmer_uniref90_hits_path,
+                             jackhmmer_mgnify_hits_path, bfd_hits_path,
+                             mgnify_max_hits=self.mgnify_max_hits,
+                             use_small_bfd=self.use_small_bfd)
+
+  def make_sequence_features(self, input_fasta_path):
+    input_sequence, input_description, num_res = parse_fasta_path(
+        input_fasta_path)
+    return make_sequence_features(sequence=input_sequence,
+                                  description=input_description,
+                                  num_res=num_res)
+
+  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
+    """Runs alignment tools on the input sequence and creates features."""
+    self.msa_output_dir = msa_output_dir
+
+    jackhmmer_uniref90_hits_path = self.jackhmmer_uniref90(input_fasta_path)
+    hhsearch_hits_path =
self.hhsearch_pdb70(jackhmmer_uniref90_hits_path)
+    template_features = self.template_featurize(input_fasta_path,
+                                                hhsearch_hits_path)
+
+    if self.use_small_bfd:
+      bfd_hits_path = self.jackhmmer_small_bfd(input_fasta_path)
+    else:
+      bfd_hits_path = self.hhblits(input_fasta_path)
+
+    jackhmmer_mgnify_hits_path = self.jackhmmer_mgnify(input_fasta_path)
+    sequence_features = self.make_sequence_features(input_fasta_path)
+    msa_features = self.make_msa_features(jackhmmer_uniref90_hits_path,
+                                          jackhmmer_mgnify_hits_path,
+                                          bfd_hits_path)
+    return {**sequence_features, **msa_features, **template_features}
diff --git a/alphafold/data/templates.py b/alphafold/data/templates.py
index 24dd93963..a9fc45865 100644
--- a/alphafold/data/templates.py
+++ b/alphafold/data/templates.py
@@ -885,7 +885,7 @@ def get_templates(
     errors.append(result.error)

     # There could be an error even if there are some results, e.g. thrown by
-    # other unparseable chains in the same mmCIF file.
+    # other unparsable chains in the same mmCIF file.
     if result.warning:
       warnings.append(result.warning)
diff --git a/alphafold/data/tools/cache_utils.py b/alphafold/data/tools/cache_utils.py
new file mode 100644
index 000000000..e6c7c6307
--- /dev/null
+++ b/alphafold/data/tools/cache_utils.py
@@ -0,0 +1,123 @@
+import hashlib
+import json
+import logging
+import os
+import pickle
+from functools import lru_cache
+
+DEFAULT_CACHE_DIR = '.cache'
+
+
+@lru_cache(maxsize=32)
+def hash_fp(fp):
+  """Returns the MD5 digest of the file at path `fp` (memoized per path)."""
+  with open(fp, 'rb') as f:
+    return md5_hash(f.read())
+
+
+def md5_hash(s: bytes) -> str:
+  return hashlib.md5(s).hexdigest()
+
+
+def cache_key_with_hashed_paths(args, kwargs):
+  """Hashable key of function args and kwargs. For kwarg names ending in
+  '_path', hash the file contents and use that in place of the path.
+
+  Unused alternative to `cache_key` that keys on file contents rather than
+  on (normalized) path strings.
+  """
+  kw = dict()
+  for k in kwargs:
+    assert isinstance(k, str)
+    if k.endswith('_path'):
+      kw[k] = hash_fp(kwargs[k])
+    else:
+      kw[k] = kwargs[k]
+  return (args, frozenset(kw.items()))
+
+
+def order_dict(d: dict) -> tuple:
+  return tuple((k, d[k]) for k in sorted(d.keys()))
+
+
+def looks_like_path(s: str, suffixes: tuple = ('_path', '_dir')) -> bool:
+  if not isinstance(s, str):
+    return False
+  return any(s.endswith(suf) for suf in suffixes)
+
+
+def normpath(s: str) -> str:
+  return os.path.abspath(os.path.normpath(s))
+
+
+def cache_key(args, kwargs):
+  # Kwargs whose *names* end in '_path' or '_dir' are treated as filesystem
+  # paths and normalized to absolute form so equivalent paths hash equally.
+  kw = dict()
+  for k, v in kwargs.items():
+    if v is None:
+      kw[k] = v
+    elif looks_like_path(k):
+      kw[k] = normpath(v)
+    else:
+      kw[k] = v
+
+  obj = (args, order_dict(kw))
+
+  # JSON-serialize the ordered (kw)args and hash the result.
+  serialized = json.dumps(obj).encode('utf-8')
+  logging.debug(f"using serialized={serialized}")
+  return md5_hash(serialized)
+
+
+def cache_to_pckl(cache_dir=None, exclude_kw=None, use_pckl=True):
+  """Decorator factory that caches function results to a pickle file.
+
+  Pickled function results are cached to a path `cache_dir/func.__name__/hash`
+  where hash is hashed args and kwargs (except kwargs listed in `exclude_kw`).
+  If the cache file exists, return the unpickled result. If it does not exist,
+  or if environment variable `AF2_SKIP_PCKL_CACHE=1`, run the function and write
+  its result to the cache.
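+
+  Example (hypothetical `expensive_step`; any function whose result is fully
+  determined by its arguments can be wrapped the same way):
+
+    @cache_to_pckl(exclude_kw=['output_dir'])
+    def expensive_step(input_fasta_path: str, output_dir: str):
+      ...
+
+  The cache root defaults to `DEFAULT_CACHE_DIR` ('.cache'); set the
+  `AF2_CACHE_DIR` environment variable to relocate it.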
+ """ + if exclude_kw is None: + exclude_kw = list() + elif isinstance(exclude_kw, str): + exclude_kw = [exclude_kw] + + if cache_dir is None: + cache_dir = os.environ.get('AF2_CACHE_DIR', None) + if cache_dir is None: + cache_dir = DEFAULT_CACHE_DIR + + # whether to use pickle or plain text cache + cache_ext = 'pckl' if use_pckl is True else 'out' + AF2_SKIP_PCKL_CACHE = os.environ.get('AF2_SKIP_PCKL_CACHE', 0) + + def decorator(fn): + def wrapped(*args, **kwargs): + kw = {k: kwargs[k] for k in kwargs if k not in exclude_kw} + key = cache_key(args, kw) + cache_fp = os.path.join(cache_dir, fn.__name__, f"{key}.{cache_ext}") + + logging.debug(f"using cache_fp={cache_fp} (AF2_SKIP_PCKL_CACHE={AF2_SKIP_PCKL_CACHE})") + + if os.path.exists(cache_fp) and not AF2_SKIP_PCKL_CACHE: + logging.info(f"using cache at {cache_fp} instead of running {fn.__name__}") + if use_pckl: + with open(cache_fp, 'rb') as f: + return pickle.load(f) + else: + with open(cache_fp, 'r') as f: + return f.read() + else: + logging.info(f"no cache found at {cache_fp}") + + result = fn(*args, **kwargs) + + # write to cache file + os.makedirs(os.path.dirname(cache_fp), exist_ok=True) + logging.info(f"saving results to {cache_fp}") + if use_pckl: + with open(cache_fp, 'wb') as f: + pickle.dump(result, f, protocol=4) + else: + with open(cache_fp, 'w') as f: + f.write(result) + + return result + + return wrapped + + return decorator + diff --git a/alphafold/data/tools/cli.py b/alphafold/data/tools/cli.py new file mode 100644 index 000000000..15f20580d --- /dev/null +++ b/alphafold/data/tools/cli.py @@ -0,0 +1,154 @@ +# -------------------- Functional interface for Click CLI ---------------------- +import os +import json +import click +from collections import OrderedDict +import pickle +from functools import lru_cache +import hashlib +from typing import Mapping, Optional, Sequence +from alphafold.data.tools import hhblits +from alphafold.data.tools import hhsearch +from alphafold.data.tools import jackhmmer as jackhmmer_wrapper +from alphafold.data import parsers, templates, pipeline +from alphafold.data.tools.cache_utils import cache_to_pckl +import logging + +logging.basicConfig(level=logging.DEBUG) + +@click.group() +def cli(): + pass + + +def parse_fasta_path(input_fasta_path): + """Given fasta file at `input_fasta_path`, calls `parsers.parse_fasta` + returning tuple of input sequence, input description, and sequence length. 
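+
+  Raises:
+    ValueError: if more than one sequence is found in the file.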
+  """
+  with open(input_fasta_path) as f:
+    input_fasta_str = f.read()
+  input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+  if len(input_seqs) != 1:
+    raise ValueError(
+        f'More than one input sequence found in {input_fasta_path}.')
+  return (input_seqs[0], input_descs[0], len(input_seqs[0]))
+
+
+def write_output(data, fname, output_dir) -> str:
+  fp = os.path.join(output_dir, fname)
+  with open(fp, 'w') as f:
+    f.write(data)
+  return fp
+
+
+# ------------------------------------------------------------------------------
+
+@cli.command(name='jackhmmer')
+@click.option('--input-fasta-path', required=True, type=click.Path())
+@click.option('--jackhmmer-binary-path', required=True, type=click.Path())
+@click.option('--database-path', required=True, type=click.Path())
+@click.option('--fname', default='hits.sto', show_default=True,
+              help='Filename for the Stockholm-format hits.')
+@click.option('--output-dir', required=True, type=click.Path(file_okay=False))
+def jackhmmer_cli(*args, **kwargs):
+  return jackhmmer(*args, **kwargs)
+
+
+@cache_to_pckl(exclude_kw=['output_dir', 'fname'])
+def jackhmmer(input_fasta_path: str, jackhmmer_binary_path: str,
+              database_path: str, fname: str, output_dir: str):
+  jackhmmer_runner = jackhmmer_wrapper.Jackhmmer(
+      binary_path=jackhmmer_binary_path,
+      database_path=database_path)
+  result = jackhmmer_runner.query(input_fasta_path)[0]['sto']
+  return write_output(result, fname, output_dir=output_dir)
+
+
+@cache_to_pckl(exclude_kw='output_dir')
+def hhsearch_pdb70(jackhmmer_uniref90_hits_path, hhsearch_binary_path: str,
+                   pdb70_database_path: str, output_dir: str,
+                   uniref_max_hits):
+  with open(jackhmmer_uniref90_hits_path, 'r') as f:
+    jackhmmer_uniref90_hits = f.read()
+  hhsearch_pdb70_runner = hhsearch.HHSearch(
+      binary_path=hhsearch_binary_path,
+      databases=[pdb70_database_path])
+  uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
+      jackhmmer_uniref90_hits, max_sequences=uniref_max_hits)
+  result = hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
+  return write_output(result, 'pdb70_hits.hhr', output_dir=output_dir)
+
+
+@cache_to_pckl(exclude_kw='output_dir')
+def hhblits(input_fasta_path: str, hhblits_binary_path: str,
+            bfd_database_path: str, uniclust30_database_path: str,
+            output_dir: str):
+  # This function shadows the `hhblits` module imported at the top of the
+  # file, so re-import the wrapper under an alias; referring to
+  # `hhblits.HHBlits` here would resolve to this function and fail.
+  from alphafold.data.tools import hhblits as hhblits_wrapper
+  hhblits_bfd_uniclust_runner = hhblits_wrapper.HHBlits(
+      binary_path=hhblits_binary_path,
+      databases=[bfd_database_path, uniclust30_database_path])
+  result = hhblits_bfd_uniclust_runner.query(input_fasta_path)['a3m']
+  return write_output(result, 'bfd_uniclust_hits.a3m', output_dir=output_dir)
+
+
+@cache_to_pckl()
+def template_featurize(input_fasta_path, hhsearch_hits_path, mmcif_dir: str,
+                       max_template_date, max_hits, kalign_binary_path,
+                       release_dates_path, obsolete_pdbs_path,
+                       strict_error_check):
+  with open(hhsearch_hits_path, 'r') as f:
+    hhsearch_hits = parsers.parse_hhr(f.read())
+  template_featurizer = templates.TemplateHitFeaturizer(
+      mmcif_dir=mmcif_dir,
+      max_template_date=max_template_date,
+      max_hits=max_hits,
+      kalign_binary_path=kalign_binary_path,
+      release_dates_path=release_dates_path,
+      obsolete_pdbs_path=obsolete_pdbs_path,
+      strict_error_check=strict_error_check,
+  )
+  input_sequence, _, _ = parse_fasta_path(input_fasta_path)
+  features = template_featurizer.get_templates(
+      query_sequence=input_sequence,
+      query_pdb_code=None,
+      query_release_date=None,
+      hits=hhsearch_hits).features
+  logging.info('Total number of templates (NB: this can include bad '
+               'templates and is later filtered to top 4): %d.',
+               features['template_domain_names'].shape[0])
+  return features
+
+
+@cache_to_pckl()
+def
make_msa_features(jackhmmer_uniref90_hits_path, jackhmmer_mgnify_hits_path, + bfd_hits_path, mgnify_max_hits: int, use_small_bfd: bool): + with open(jackhmmer_uniref90_hits_path, 'r') as f: + uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(f.read()) + + with open(jackhmmer_mgnify_hits_path, 'r') as f: + mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(f.read()) + + with open(bfd_hits_path, 'r') as f: + bfd_hits = f.read() + + mgnify_msa = mgnify_msa[:mgnify_max_hits] + mgnify_deletion_matrix = mgnify_deletion_matrix[:mgnify_max_hits] + + if use_small_bfd: + bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(bfd_hits) + else: + bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(bfd_hits) + + msa_features = pipeline.make_msa_features( + msas=(uniref90_msa, bfd_msa, mgnify_msa), + deletion_matrices=(uniref90_deletion_matrix, + bfd_deletion_matrix, + mgnify_deletion_matrix)) + + logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) + logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) + logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) + logging.info('Final (deduplicated) MSA size: %d sequences.', + msa_features['num_alignments'][0]) + return msa_features + + +if __name__ == '__main__': + cli() \ No newline at end of file diff --git a/alphafold/model/all_atom.py b/alphafold/model/all_atom.py index 678e224c1..c8ebe8b08 100644 --- a/alphafold/model/all_atom.py +++ b/alphafold/model/all_atom.py @@ -578,10 +578,10 @@ def extreme_ca_ca_distance_violations( residue_index: jnp.ndarray, # (N) max_angstrom_tolerance=1.5 ) -> jnp.ndarray: - """Counts residues whose Ca is a large distance from its neighbor. + """Counts residues whose Ca is a large distance from its neighbour. - Measures the fraction of CA-CA pairs between consectutive amino acids that - are more than 'max_angstrom_tolerance' apart. + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. Args: pred_atom_positions: Atom positions in atom37/14 representation diff --git a/alphafold/model/tf/protein_features.py b/alphafold/model/tf/protein_features.py index 66fc258b1..c78cfa5ea 100644 --- a/alphafold/model/tf/protein_features.py +++ b/alphafold/model/tf/protein_features.py @@ -93,7 +93,7 @@ def shape(feature_name: str, Args: feature_name: String identifier for the feature. If the feature name ends - with "_unnormalized", theis suffix is stripped off. + with "_unnormalized", this suffix is stripped off. num_residues: The number of residues in the current domain - some elements of the shape can be dynamic and will be replaced by this value. msa_length: The number of sequences in the multiple sequence alignment, some diff --git a/docker/Dockerfile b/docker/Dockerfile index a2e081432..b40b68e8e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -42,7 +42,7 @@ RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh && popd \ && rm -rf /tmp/hh-suite -# Install Miniconda package manger. +# Install Miniconda package manager. 
RUN wget -q -P /tmp \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ diff --git a/job.slurm b/job.slurm new file mode 100644 index 000000000..1fb802828 --- /dev/null +++ b/job.slurm @@ -0,0 +1,24 @@ +#!/bin/bash +# job.slurm +# ----------------------------------------------------------------- +#SBATCH -J af2_eho # Job name +#SBATCH -o af2_eho.%j.out # Name of stdout output file +#SBATCH -e af2_eho.%j.err # Name of stderr output file +#SBATCH -p normal # Queue (partition) name +#SBATCH -N 1 # Total # of nodes +#SBATCH -n 1 # Total # of mpi tasks +#SBATCH -t 14:00:00 # Run time (hh:mm:ss) +#SBATCH -A SD2E-Community # Project/Allocation name +# ----------------------------------------------------------------- + +module unload xalt +module load tacc-singularity +module list + +export SIF='/scratch/projects/tacc/bio/alphafold/images/alphafold_2.0.0.sif' +export AF2_HOME='/scratch/projects/tacc/bio/alphafold/' + +singularity exec $SIF python3 run_alphafold.py --flagfile=$AF2_HOME/test-container/flags/reduced_dbs.ff \ + --fasta_paths=$AF2_HOME/test-container/input/sample.fasta \ + --output_dir=$SCRATCH/af2_reduced \ + --model_names=model_1 \ No newline at end of file diff --git a/notebooks/AlphaFold.ipynb b/notebooks/AlphaFold.ipynb index 85f621b39..2ce4ccf44 100644 --- a/notebooks/AlphaFold.ipynb +++ b/notebooks/AlphaFold.ipynb @@ -125,7 +125,10 @@ " %shell rm -rf alphafold\n", " %shell git clone {GIT_REPO} alphafold\n", " pbar.update(8)\n", - " %shell pip3 install ./alphafold\n", + " # Install the required versions of all dependencies.\n", + " %shell pip3 install -r ./alphafold/requirements.txt\n", + " # Run setup.py to install only AlphaFold.\n", + " %shell pip3 install --no-dependencies ./alphafold\n", " pbar.update(10)\n", "\n", " # Apply OpenMM patch.\n", diff --git a/run_alphafold.py b/run_alphafold.py index beebd4ced..d825b0823 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -26,7 +26,8 @@ from absl import flags from absl import logging from alphafold.common import protein -from alphafold.data import pipeline +from alphafold.common import residue_constants +from alphafold.data import pipeline, pipeline_mod from alphafold.data import templates from alphafold.model import data from alphafold.model import config @@ -158,15 +159,22 @@ def predict_structure( timings[f'predict_benchmark_{model_name}'] = time.time() - t_0 # Get mean pLDDT confidence metric. - plddts[model_name] = np.mean(prediction_result['plddt']) + plddt = prediction_result['plddt'] + plddts[model_name] = np.mean(plddt) # Save the model outputs. result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl') with open(result_output_path, 'wb') as f: pickle.dump(prediction_result, f, protocol=4) - unrelaxed_protein = protein.from_prediction(processed_feature_dict, - prediction_result) + # Add the predicted LDDT in the b-factor column. + # Note that higher predicted LDDT value means higher model confidence. 
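+    # plddt has shape (num_res,); np.repeat broadcasts it to
+    # (num_res, residue_constants.atom_type_num) so that every atom of a
+    # residue carries that residue's pLDDT, which protein.to_pdb then writes
+    # out as the per-atom B-factor.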
+ plddt_b_factors = np.repeat( + plddt[:, None], residue_constants.atom_type_num, axis=-1) + unrelaxed_protein = protein.from_prediction( + features=processed_feature_dict, + result=prediction_result, + b_factors=plddt_b_factors) unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') with open(unrelaxed_pdb_path, 'w') as f: @@ -226,15 +234,7 @@ def main(argv): if len(fasta_names) != len(set(fasta_names)): raise ValueError('All FASTA paths must have a unique basename.') - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=FLAGS.template_mmcif_dir, - max_template_date=FLAGS.max_template_date, - max_hits=MAX_TEMPLATE_HITS, - kalign_binary_path=FLAGS.kalign_binary_path, - release_dates_path=None, - obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) - - data_pipeline = pipeline.DataPipeline( + data_pipeline = pipeline_mod.ModularDataPipeline( jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, hhblits_binary_path=FLAGS.hhblits_binary_path, hhsearch_binary_path=FLAGS.hhsearch_binary_path, @@ -244,8 +244,16 @@ def main(argv): uniclust30_database_path=FLAGS.uniclust30_database_path, small_bfd_database_path=FLAGS.small_bfd_database_path, pdb70_database_path=FLAGS.pdb70_database_path, - template_featurizer=template_featurizer, - use_small_bfd=use_small_bfd) + use_small_bfd=use_small_bfd, + + # for construction of TemplateHitFeaturizer, replacing + # template_featurizer=template_featurizer, + mmcif_dir=FLAGS.template_mmcif_dir, + max_template_date=FLAGS.max_template_date, + max_hits=MAX_TEMPLATE_HITS, + kalign_binary_path=FLAGS.kalign_binary_path, + release_dates_path=None, + obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) model_runners = {} for model_name in FLAGS.model_names: diff --git a/run_alphafold_test.py b/run_alphafold_test.py index 91aec1f25..a9194a9c1 100644 --- a/run_alphafold_test.py +++ b/run_alphafold_test.py @@ -45,7 +45,7 @@ def test_end_to_end(self): 'predicted_lddt': { 'logits': np.ones((10, 50)), }, - 'plddt': np.zeros(10), + 'plddt': np.ones(10) * 42, 'ptm': np.array(0.), 'aligned_confidence_probs': np.zeros((10, 10, 50)), 'predicted_aligned_error': np.zeros((10, 10)), @@ -71,6 +71,22 @@ def test_end_to_end(self): benchmark=False, random_seed=0) + base_output_files = os.listdir(out_dir) + self.assertIn('target.fasta', base_output_files) + self.assertIn('test', base_output_files) + + target_output_files = os.listdir(os.path.join(out_dir, 'test')) + self.assertSequenceEqual( + ['features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json', + 'relaxed_model1.pdb', 'result_model1.pkl', 'timings.json', + 'unrelaxed_model1.pdb'], target_output_files) + + # Check that pLDDT is set in the B-factor column. + with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f: + for line in f: + if line.startswith('ATOM'): + self.assertEqual(line[61:66], '42.00') + if __name__ == '__main__': absltest.main() diff --git a/scripts/download_all_data.sh b/scripts/download_all_data.sh index cbe003cf3..c88581067 100755 --- a/scripts/download_all_data.sh +++ b/scripts/download_all_data.sh @@ -31,7 +31,7 @@ fi DOWNLOAD_DIR="$1" DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. -if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] +if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] then echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized." 
exit 1 diff --git a/setup.py b/setup.py index 762bd7be8..9729f039d 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,13 @@ 'numpy', 'scipy', 'tensorflow', + 'Click', ], + entry_points={ + 'console_scripts': [ + 'af2 = alphafold.data.tools.cli:cli', + ], + }, tests_require=['mock'], classifiers=[ 'Development Status :: 5 - Production/Stable',