Skip to content

Commit

Permalink
add environmental pairing feature (#614)
Browse files: browse the repository at this point in the history
* add environmental pairing feature

* Refactor env pairing

---------

Co-authored-by: Milot Mirdita <[email protected]>
  • Loading branch information
endixk and milot-mirdita authored May 2, 2024
1 parent 1653605 commit 07644a8
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions colabfold/mmseqs/search.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -176,7 +176,9 @@ def mmseqs_search_pair(
dbbase: Path,
base: Path,
uniref_db: Path = Path("uniref30_2302_db"),
spire_db: Path = Path("spire_ctg10_2401_db"),
mmseqs: Path = Path("mmseqs"),
pair_env: bool = True,
prefilter_mode: int = 0,
s: float = 8,
threads: int = 64,
Expand All @@ -200,6 +202,13 @@ def mmseqs_search_pair(
dbSuffix1 = ".idx"
dbSuffix2 = ".idx"

if pair_env:
db = spire_db
output = ".env.paired.a3m"
else:
db = uniref_db
output = ".paired.a3m"

# fmt: off
# @formatter:off
search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000",]
Expand All @@ -209,16 +218,14 @@ def mmseqs_search_pair(
else:
search_param += ["--k-score", "'seq:96,prof:80'"]
expand_param = ["--expansion-mode", "0", "-e", "inf", "--expand-filter-clusters", "0", "--max-seq-id", "0.95",]
run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads),] + search_param,)
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads),] + expand_param,)
run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", "0.001", "--max-accept", "1000000", "--threads", str(threads), "-c", "0.5", "--cov-mode", "1",],)
run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}"), base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_pair"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "0", "--threads", str(threads), ],)
run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],)
run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],)
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],)
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".paired.a3m",],)
run_mmseqs(mmseqs, ["rmdb", base.joinpath("qdb")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("qdb_h")])
run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads),] + search_param,)
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads),] + expand_param,)
run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", "0.001", "--max-accept", "1000000", "--threads", str(threads), "-c", "0.5", "--cov-mode", "1",],)
run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_pair"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "0", "--threads", str(threads), ],)
run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],)
run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],)
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],)
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],)
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")])
Expand All @@ -230,7 +237,6 @@ def mmseqs_search_pair(
# @formatter:on
# fmt: on


def main():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument(
Expand Down Expand Up @@ -271,11 +277,15 @@ def main():
default=Path("colabfold_envdb_202108_db"),
help="Environmental database",
)
parser.add_argument("--db4", type=Path, default=Path("spire_ctg10_2401_db"), help="Environmental pairing database")

# poor man's boolean arguments
parser.add_argument(
"--use-env", type=int, default=1, choices=[0, 1], help="Use --db3"
)
parser.add_argument(
"--use-env-pairing", type=int, default=0, choices=[0, 1], help="Use --db4"
)
parser.add_argument(
"--use-templates", type=int, default=0, choices=[0, 1], help="Use --db2"
)
Expand Down Expand Up @@ -418,7 +428,22 @@ def main():
db_load_mode=args.db_load_mode,
threads=args.threads,
pairing_strategy=args.pairing_strategy,
pair_env=False,
)
if args.use_env_pairing:
mmseqs_search_pair(
mmseqs=args.mmseqs,
dbbase=args.dbbase,
base=args.base,
uniref_db=args.db1,
spire_db=args.db4,
prefilter_mode=args.prefilter_mode,
s=args.s,
db_load_mode=args.db_load_mode,
threads=args.threads,
pairing_strategy=args.pairing_strategy,
pair_env=True,
)

id = 0
for job_number, (
Expand All @@ -434,6 +459,14 @@ def main():
with args.base.joinpath(f"{id}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()

if args.use_env_pairing:
with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
while chunk := file_pair_env.read(10 * 1024 * 1024):
file_pair.write(chunk)
args.base.joinpath(f"{id}.env.paired.a3m").unlink()

if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
Expand Down

0 comments on commit 07644a8

Please sign in to comment.