Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Oct 24, 2023
1 parent f684f54 commit d4a989a
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 76 deletions.
9 changes: 2 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import tempfile
from dataclasses import asdict
from pathlib import Path

import numpy as np
import parsl
import pytest
import requests
import yaml
from ase import Atoms
from ase.build import bulk
from ase.calculators.emt import EMT

import psiflow
from psiflow.data import Dataset, FlowAtoms
from psiflow.models import (AllegroConfig, AllegroModel, MACEConfig, MACEModel,
NequIPConfig, NequIPModel)
from psiflow.reference import EMTReference
from psiflow.models import AllegroConfig, MACEConfig, MACEModel, NequIPConfig


def pytest_addoption(parser):
Expand Down Expand Up @@ -104,7 +99,7 @@ def generate_emt_cu_data(nstates, amplitude):
pos = atoms.get_positions()
box = atoms.get_cell()
atoms_list = []
for i in range(nstates):
for _ in range(nstates):
atoms.set_positions(
pos + np.random.uniform(-amplitude, amplitude, size=(len(atoms), 3))
)
Expand Down
11 changes: 3 additions & 8 deletions tests/test_bias.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import tempfile

import numpy as np
import pytest
import requests
import yaml
from ase.build import bulk, make_supercell
from ase.build import make_supercell

from psiflow.data import Dataset, NullState
from psiflow.models import NequIPModel
from psiflow.walkers import PlumedBias, RandomWalker
from psiflow.walkers.bias import (
generate_external_grid,
Expand Down Expand Up @@ -111,7 +106,7 @@ def test_bias_evaluate(context, dataset):
assert np.allclose(volume, values[i, 0])
assert np.allclose(np.zeros(values[:, 1].shape), values[:, 1])
dataset_ = bias.evaluate(dataset, as_dataset=True)
for i, atoms in enumerate(dataset_.as_list().result()):
for atoms in dataset_.as_list().result():
assert np.allclose(atoms.get_volume(), atoms.info["CV1"])
state = dataset_[0].result()
state.reset()
Expand Down Expand Up @@ -142,7 +137,7 @@ def test_bias_external(context, dataset, tmp_path):
CV: VOLUME
external: EXTERNAL ARG=CV FILE=test_grid
"""
bias_function = lambda x: np.exp(-0.01 * (x - 150) ** 2)
bias_function = lambda x: np.exp(-0.01 * (x - 150) ** 2) # noqa: E731
variable = np.linspace(0, 300, 500)
grid = generate_external_grid(bias_function, variable, "CV", periodic=False)
data = {"EXTERNAL": grid}
Expand Down
25 changes: 12 additions & 13 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@

def test_flow_atoms(dataset, tmp_path):
atoms = dataset.get(index=0).result().copy() # copy necessary with HTEX!
assert type(atoms) == FlowAtoms
assert type(atoms) is FlowAtoms
atoms.reference_status = True
atoms_ = atoms.copy()
assert atoms_.reference_status == True
assert atoms_.reference_status
atoms_ = FlowAtoms.from_atoms(atoms)
assert atoms_.reference_status == True
assert atoms_.reference_status
for i in range(dataset.length().result()):
atoms = dataset[i].result()
assert type(atoms) == FlowAtoms
assert atoms.reference_status == True
assert type(atoms) is FlowAtoms
assert atoms.reference_status
assert dataset.labeled().length().result() == dataset.length().result()
dataset += Dataset([NullState])
assert dataset.length().result() == 1 + dataset.not_null().length().result()
assert atoms.reference_status == True
assert atoms.reference_status
atoms.reset()
atoms.cell[:] = np.array([[3, 1, 1], [1, 5, 0], [0, -1, 5]])
assert not "energy" in atoms.info
assert atoms.reference_status == False
assert "energy" not in atoms.info
assert not atoms.reference_status
assert tuple(sorted(atoms.elements)) == ("Cu", "H")
assert not is_reduced(atoms.cell)
atoms.canonical_orientation()
Expand All @@ -42,7 +42,6 @@ def test_dataset_empty(tmp_path):
assert isinstance(dataset.data_future, DataFuture)
path_xyz = tmp_path / "test.xyz"
dataset.save(path_xyz) # ensure the copy is executed before assert
assert not os.path.isfile(path_xyz)
psiflow.wait()
assert os.path.isfile(path_xyz)
with pytest.raises(ValueError): # cannot save outside cwd
Expand All @@ -53,8 +52,8 @@ def test_dataset_append(dataset):
assert 20 == dataset.length().result()
atoms_list = dataset.as_list().result()
assert len(atoms_list) == 20
assert type(atoms_list) == list
assert type(atoms_list[0]) == FlowAtoms
assert type(atoms_list) is list
assert type(atoms_list[0]) is FlowAtoms
empty = Dataset([]) # use [] instead of None
empty.append(dataset)
assert 20 == empty.length().result()
Expand Down Expand Up @@ -216,7 +215,7 @@ def test_data_elements(dataset):

def test_data_reset(dataset):
dataset = dataset.reset()
assert not "energy" in dataset[0].result().info
assert "energy" not in dataset[0].result().info


def test_nullstate(context):
Expand Down Expand Up @@ -249,7 +248,7 @@ def test_identifier(dataset):
for i in range(data.length().result()):
s = data[i].result()
if not s == NullState:
assert not "identifier" in s.info
assert "identifier" not in s.info
identifier = data.assign_identifiers(10)
assert identifier.result() == 10 # none are labeled
identifier = dataset.assign_identifiers(10)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_incremental_learning(gpu, tmp_path, mace_config, dataset):
walkers,
)
assert data.length().result() == len(walkers) # perform 1 iteration
for i, walker in enumerate(walkers):
for walker in walkers:
assert not walker.is_reset().result()
steps, kappas, centers = walker.bias.get_moving_restraint(variable="CV")
assert steps == 10
Expand All @@ -172,7 +172,7 @@ def test_temperature_ramp(context):
assert apply_temperature_ramp(100, 300, 1, 100) == 300
assert apply_temperature_ramp(100, 500, 3, 550) == 500
T = 100
for i in range(3):
for _ in range(3):
T = apply_temperature_ramp(100, 500, 5, T)
assert T == 1 / (1 / 100 - 3 * (1 / 100 - 1 / 500) / 4)
assert not T == 500
Expand Down
26 changes: 13 additions & 13 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import ast
import copy
import os
from dataclasses import asdict

import numpy as np
import pytest
import torch
from ase.data import chemical_symbols
from ase.io.extxyz import read_extxyz
from parsl.app.futures import DataFuture
from parsl.dataflow.futures import AppFuture

import psiflow
from psiflow.committee import Committee
from psiflow.data import Dataset
from psiflow.execution import ModelEvaluation
from psiflow.models import (AllegroModel, MACEConfig, MACEModel, NequIPConfig,
NequIPModel, load_model)
from psiflow.models import (
AllegroModel,
MACEConfig,
MACEModel,
NequIPConfig,
NequIPModel,
load_model,
)
from psiflow.reference import EMTReference


Expand Down Expand Up @@ -79,14 +79,14 @@ def test_nequip_save_load(nequip_config, dataset, tmp_path):
path_config = tmp_path / "config_after_init.yaml"
path_model = tmp_path / "model_undeployed.pth"
path_deploy = tmp_path / "model_deployed.pth"
futures = model.save(tmp_path, require_done=True)
model.save(tmp_path, require_done=True)
assert os.path.exists(path_config_raw)
assert os.path.exists(path_config)
assert os.path.exists(path_model)
assert os.path.exists(path_deploy)

model_ = load_model(tmp_path)
assert type(model_) == NequIPModel
assert type(model_) is NequIPModel
assert model_.model_future is not None
assert model_.deploy_future is not None
e1 = model_.evaluate(dataset.get(indices=[3]))[0].result().info["energy"]
Expand Down Expand Up @@ -199,14 +199,14 @@ def test_allegro_save_load(allegro_config, dataset, tmp_path):
path_config = tmp_path / "config_after_init.yaml"
path_model = tmp_path / "model_undeployed.pth"
path_deploy = tmp_path / "model_deployed.pth"
futures = model.save(tmp_path, require_done=True)
model.save(tmp_path, require_done=True)
assert os.path.exists(path_config_raw)
assert os.path.exists(path_config)
assert os.path.exists(path_model)
assert os.path.exists(path_deploy)

model_ = load_model(tmp_path)
assert type(model_) == AllegroModel
assert type(model_) is AllegroModel
assert model_.model_future is not None
assert model_.deploy_future is not None
e1 = model_.evaluate(dataset.get(indices=[3]))[0].result().info["energy"]
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_mace_save_load(mace_config, dataset, tmp_path):
assert os.path.exists(path_deployed)

model_ = load_model(tmp_path)
assert type(model_) == MACEModel
assert type(model_) is MACEModel
assert model_.model_future is not None
assert model_.deploy_future is not None
e1 = model_.evaluate(dataset.get(indices=[3]))[0].result().info["energy"]
Expand Down
27 changes: 11 additions & 16 deletions tests/test_reference.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import os
from pathlib import Path

import molmod
import numpy as np
import pytest
import requests
from ase import Atoms
from ase.io.extxyz import write_extxyz
from ase.units import Pascal
from parsl.app.futures import DataFuture
from parsl.dataflow.futures import AppFuture
from pymatgen.io.cp2k.inputs import Cp2kInput

import psiflow
from psiflow.data import Dataset, FlowAtoms, NullState
from psiflow.reference import CP2KReference, EMTReference, NWChemReference
from psiflow.reference._cp2k import (insert_atoms_in_input,
insert_filepaths_in_input)
from psiflow.reference._cp2k import insert_filepaths_in_input


@pytest.fixture
Expand Down Expand Up @@ -182,11 +177,11 @@ def test_reference_emt(context, dataset, tmp_path):
assert evaluated.length().result() == len(atoms_list)

atoms = reference.evaluate(dataset_[5]).result()
assert type(atoms) == FlowAtoms
assert atoms.reference_status == True
assert type(atoms) is FlowAtoms
assert atoms.reference_status
atoms = reference.evaluate(dataset_[6]).result()
assert type(atoms) == FlowAtoms
assert atoms.reference_status == False
assert type(atoms) is FlowAtoms
assert not atoms.reference_status


def test_cp2k_insert_filepaths(fake_cp2k_input):
Expand Down Expand Up @@ -242,7 +237,7 @@ def test_cp2k_success(context, cp2k_reference):
dataset = Dataset([atoms])
evaluated = cp2k_reference.evaluate(dataset[0])
assert isinstance(evaluated, AppFuture)
assert evaluated.result().reference_status == True
assert evaluated.result().reference_statuss
assert Path(evaluated.result().reference_stdout).is_file()
assert Path(evaluated.result().reference_stderr).is_file()
assert "energy" in evaluated.result().info.keys()
Expand Down Expand Up @@ -375,7 +370,7 @@ def test_cp2k_failure(context, cp2k_data, tmp_path):
)
evaluated = reference.evaluate(atoms)
assert isinstance(evaluated, AppFuture)
assert evaluated.result().reference_status == False
assert not evaluated.result().reference_status
assert "energy" not in evaluated.result().info.keys()
with open(evaluated.result().reference_stdout, "r") as f:
log = f.read()
Expand All @@ -393,7 +388,7 @@ def test_cp2k_timeout(context, cp2k_reference):
)
evaluated = cp2k_reference.evaluate(atoms)
assert isinstance(evaluated, AppFuture)
assert evaluated.result().reference_status == False
assert not evaluated.result().reference_status
assert "energy" not in evaluated.result().info.keys()


Expand Down Expand Up @@ -438,14 +433,14 @@ def test_nwchem_success(nwchem_reference):
dataset = Dataset([atoms])
evaluated = nwchem_reference.evaluate(dataset[0])
assert isinstance(evaluated, AppFuture)
assert evaluated.result().reference_status == True
assert evaluated.result().reference_status
assert Path(evaluated.result().reference_stdout).is_file()
assert Path(evaluated.result().reference_stderr).is_file()
assert "energy" in evaluated.result().info.keys()
assert not "stress" in evaluated.result().info.keys()
assert "stress" not in evaluated.result().info.keys()
assert "forces" in evaluated.result().arrays.keys()
assert evaluated.result().arrays["forces"][0, 0] < 0
assert evaluated.result().arrays["forces"][1, 0] > 0

energy_h2 = nwchem_reference.evaluate(dataset)
nwchem_reference.evaluate(dataset)
assert nwchem_reference.compute_atomic_energy("H").result() < 0
5 changes: 2 additions & 3 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import parsl

from psiflow.committee import Committee
from psiflow.data import Dataset, FlowAtoms
from psiflow.data import FlowAtoms
from psiflow.metrics import Metrics, log_dataset
from psiflow.models import MACEModel
from psiflow.reference import EMTReference
from psiflow.sampling import sample_with_committee, sample_with_model
from psiflow.walkers import (BiasedDynamicWalker, DynamicWalker, PlumedBias,
RandomWalker)
from psiflow.walkers import BiasedDynamicWalker, DynamicWalker, PlumedBias, RandomWalker


def test_sample_metrics(mace_model, dataset, tmp_path):
Expand Down
Loading

0 comments on commit d4a989a

Please sign in to comment.