-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
169 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
# spice-models | ||
Models trained on the SPICE dataset | ||
# SPICE-Models | ||
Models trained on the SPICE dataset. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
This directory contains the five equivariant transformer models described in (insert reference when available). | ||
|
||
The script `createSpiceDataset.py` converts the dataset file SPICE.hdf5 downloaded from https://github.com/openmm/spice-dataset/releases | ||
to the format used by [TorchMD-Net](https://github.com/torchmd/torchmd-net). It generates a new file SPICE-processed.hdf5 | ||
which was used for training. | ||
|
||
The file `hparams.yaml` contains the configuration used for training the models. All models used identical settings | ||
except that `seed` was set to a different value for each one (the numbers 1 through 5). Note that although the file | ||
specifies `num_epochs: 1000`, training was halted after 24 hours (when the training job reached the end of its allocated | ||
time). This corresponded to 118 epochs. | ||
|
||
The files ending in `.ckpt` are checkpoint files for TorchMD-Net 0.2.4 containing the trained models. They are | ||
expected to work with later versions as well, but this is not guaranteed. They can be loaded like this: | ||
|
||
```python | ||
from torchmdnet.models.model import load_model | ||
model = load_model('model1.ckpt') | ||
``` | ||
|
||
The `device` argument to `load_model()` specifies the device on which to load the model. For example: | ||
|
||
```python | ||
import torch | ||
model = load_model('model1.ckpt', device=torch.device('cuda:0')) | ||
``` | ||
|
||
To compute energy and forces for a molecular conformation, invoke the model's `forward()` method. It takes two arguments: | ||
a tensor of length `n_atoms` and dtype `long` containing the atom types, and a tensor of shape `(n_atoms, 3)` and dtype | ||
`float32` containing the atom positions in angstroms. It returns two values: the potential energy in kJ/mol, and | ||
the force on each atom in kJ/mol/angstrom. Atom types are defined by the element and formal charge of each atom. The | ||
mapping is defined in `createSpiceDataset.py` with this dictionary: | ||
|
||
```python | ||
typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6, | ||
('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13, | ||
('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20, | ||
('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27} | ||
``` | ||
|
||
For example, the following computes the energy and forces for a pair of ions (Cl- and Na+) positioned 3 angstroms apart. | ||
|
||
```python | ||
types = torch.tensor([6, 19], dtype=torch.long) | ||
pos = torch.tensor([[0, 0, 0], [0, 3, 0]], dtype=torch.float32) | ||
energy, forces = model.forward(types, pos) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import numpy as np | ||
from openff.toolkit.topology import Molecule | ||
from openmm.unit import * | ||
from collections import defaultdict | ||
import h5py | ||
|
||
# Integer atom type assigned to each (element symbol, formal charge) pair.
# The indices are written to the 'types' dataset of the output file, so the
# numbering must stay stable for the processed data to remain meaningful.
typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6,
            ('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13,
            ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
            ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}


def groupSamplesByAtomCount(infile):
    """Group the top-level entries of the input file by total number of atoms.

    Parameters
    ----------
    infile : mapping
        Maps each entry name to a group that has an 'atomic_numbers' member
        (an open h5py.File of the SPICE dataset in normal use).

    Returns
    -------
    defaultdict(list)
        Maps an atom count to the list of groups having that many atoms, in
        the order in which they were encountered.
    """
    groupsByAtomCount = defaultdict(list)
    for name in infile:
        group = infile[name]
        groupsByAtomCount[len(group['atomic_numbers'])].append(group)
    return groupsByAtomCount


def asText(value):
    """Return value as a str, decoding it as UTF-8 if it is a bytes object.

    h5py 3.x returns variable-length strings as bytes when a dataset is read,
    whereas older releases returned str. Normalizing here lets the rest of the
    script work with either.
    """
    if isinstance(value, bytes):
        return value.decode('utf-8')
    return value


def main():
    """Convert SPICE.hdf5 into SPICE-processed.hdf5.

    The output contains one group per atom count, named 'samples<count>',
    holding one row per conformation in the datasets 'smiles', 'types', 'pos',
    'energy' and 'forces'.

    Raises
    ------
    ValueError
        If the atoms built from a molecule's mapped SMILES do not match the
        atomic numbers stored for it in the input file.
    KeyError
        If a molecule contains an (element, formal charge) pair that is not
        listed in typeDict.
    """
    # Unit conversion factors: bohr -> angstrom, hartree per particle -> kJ/mol.
    posScale = 1*bohr/angstrom
    energyScale = 1*hartree/item/(kilojoules_per_mole)
    forceScale = energyScale/posScale

    # Open the input explicitly read-only, and use context managers so both
    # files are closed (and the output flushed) even if the conversion fails.
    with h5py.File('SPICE.hdf5', 'r') as infile, h5py.File('SPICE-processed.hdf5', 'w') as outfile:

        # First pass: group the samples by total number of atoms.

        groupsByAtomCount = groupSamplesByAtomCount(infile)
        counts = sorted(groupsByAtomCount)
        print(counts)

        # One pass for each number of atoms, creating a group for it.

        for count in counts:
            print(count)
            smiles = []
            pos = []
            types = []
            energy = []
            forces = []
            for g in groupsByAtomCount[count]:
                molSmiles = asText(g['smiles'][0])
                mol = Molecule.from_mapped_smiles(molSmiles, allow_undefined_stereo=True)
                molTypes = [typeDict[(atom.element.symbol, atom.formal_charge/elementary_charge)]
                            for atom in mol.atoms]

                # Check that the SMILES describes the same atoms, in the same
                # order, as the arrays stored in the file. Raised explicitly
                # rather than asserted so the check survives `python -O`.
                if len(molTypes) != count:
                    raise ValueError(f'{molSmiles}: SMILES defines {len(molTypes)} atoms, expected {count}')
                atomicNumbers = g['atomic_numbers'][:]
                for i, atom in enumerate(mol.atoms):
                    if atom.atomic_number != atomicNumbers[i]:
                        raise ValueError(f'{molSmiles}: atom {i} has atomic number {atom.atomic_number} '
                                         f'in the SMILES but {atomicNumbers[i]} in the file')

                # Read each dataset once instead of once per conformation.
                conformations = g['conformations'][:]
                formationEnergies = g['formation_energy'][:]
                gradients = g['dft_total_gradient'][:]
                for i in range(conformations.shape[0]):
                    smiles.append(molSmiles)
                    pos.append(conformations[i])
                    types.append(molTypes)
                    energy.append(formationEnergies[i])
                    forces.append(gradients[i])

            group = outfile.create_group(f'samples{count}')
            group.create_dataset('smiles', data=smiles, dtype=h5py.string_dtype())
            group.create_dataset('types', data=np.array(types), dtype=np.int8)
            group.create_dataset('pos', data=np.array(pos)*posScale, dtype=np.float32)
            group.create_dataset('energy', data=np.array(energy)*energyScale, dtype=np.float32)
            # The input stores gradients; the force is the negative gradient.
            group.create_dataset('forces', data=-np.array(forces)*forceScale, dtype=np.float32)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
activation: silu | ||
aggr: add | ||
atom_filter: -1 | ||
attn_activation: silu | ||
batch_size: 128 | ||
charge: false | ||
conf: null | ||
coord_files: null | ||
cutoff_lower: 0.0 | ||
cutoff_upper: 10.0 | ||
dataset: HDF5 | ||
dataset_arg: null | ||
dataset_root: SPICE-processed.hdf5 | ||
derivative: true | ||
distance_influence: both | ||
distributed_backend: ddp | ||
early_stopping_patience: 20 | ||
ema_alpha_dy: 1.0 | ||
ema_alpha_y: 1.0 | ||
embed_files: null | ||
embedding_dimension: 128 | ||
energy_files: null | ||
energy_weight: 1.0 | ||
force_files: null | ||
force_weight: 1.0 | ||
inference_batch_size: 128 | ||
load_model: null | ||
log_dir: model1b | ||
lr: 0.0005 | ||
lr_factor: 0.5 | ||
lr_metric: train_loss | ||
lr_min: 1.0e-07 | ||
lr_patience: 0 | ||
lr_warmup_steps: 0 | ||
max_num_neighbors: 100 | ||
max_z: 28 | ||
model: equivariant-transformer | ||
neighbor_embedding: true | ||
ngpus: -1 | ||
num_epochs: 1000 | ||
num_heads: 8 | ||
num_layers: 6 | ||
num_nodes: 1 | ||
num_rbf: 64 | ||
num_workers: 16 | ||
output_model: Scalar | ||
precision: 32 | ||
prior_model: null | ||
rbf_type: expnorm | ||
redirect: true | ||
reduce_op: add | ||
reset_trainer: false | ||
save_interval: 1 | ||
seed: 1 | ||
spin: false | ||
splits: null | ||
standardize: false | ||
test_interval: 10 | ||
test_size: 0.0 | ||
train_size: null | ||
trainable_rbf: true | ||
val_size: 0.05 | ||
weight_decay: 0.0 |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.