Skip to content

Commit

Permalink
Get energy targets and predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
nerkulec committed Aug 21, 2024
1 parent 29fab9a commit 24f9f62
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 1 deletion.
160 changes: 160 additions & 0 deletions mala/network/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,166 @@ def _calculate_energy_errors(
)
return errors

def _calculate_energy_targets_and_predictions(
self, actual_outputs, predicted_outputs, energy_types, snapshot_number
):
"""
Calculate the energies corresponding to actual and predicted outputs.
Parameters
----------
actual_outputs : numpy.ndarray
Actual outputs.
predicted_outputs : numpy.ndarray
Predicted outputs.
energy_types : list
List of energy types to calculate.
snapshot_number : int
Snapshot number for which the energies are calculated.
"""
target_calculator = self.data.target_calculator
output_file = self.data.get_snapshot_calculation_output(
snapshot_number
)
if not output_file:
raise Exception(
"Output file needed for energy calculations."
)
target_calculator.read_additional_calculation_data(output_file)

targets = {}
predictions = {}
fe_actual = None
fe_predicted = None
try:
fe_actual = target_calculator.get_self_consistent_fermi_energy(
actual_outputs
)
except ValueError:
targets = {
energy_type: np.nan for energy_type in energy_types
}
predictions = {
energy_type: np.nan for energy_type in energy_types
}
printout(
"CAUTION! LDOS ground truth is so wrong that the "
"estimation of the self consistent Fermi energy fails."
)
return targets, predictions
try:
fe_predicted = target_calculator.get_self_consistent_fermi_energy(
predicted_outputs
)
except ValueError:
targets = {
energy_type: np.nan for energy_type in energy_types
}
predictions = {
energy_type: np.nan for energy_type in energy_types
}
printout(
"CAUTION! LDOS prediction is so wrong that the "
"estimation of the self consistent Fermi energy fails."
)
return targets, predictions
for energy_type in energy_types:
if energy_type == "fermi_energy":
targets[energy_type] = fe_actual
predictions[energy_type] = fe_predicted
elif energy_type == "band_energy":
if not isinstance(target_calculator, LDOS) and not isinstance(
target_calculator, DOS
):
raise Exception(
"Cannot calculate the band energy from this observable."
)
try:
target_calculator.read_from_array(actual_outputs)
be_actual = target_calculator.get_band_energy(
fermi_energy=fe_actual
)
target_calculator.read_from_array(predicted_outputs)
be_predicted = target_calculator.get_band_energy(
fermi_energy=fe_predicted
)
targets[energy_type] = be_actual * 1000 / len(target_calculator.atoms)
predictions[energy_type] = be_predicted * 1000 / len(target_calculator.atoms)
except ValueError:
targets[energy_type] = np.nan
predictions[energy_type] = np.nan
elif energy_type == "band_energy_actual_fe":
if not isinstance(target_calculator, LDOS) and not isinstance(
target_calculator, DOS
):
raise Exception(
"Cannot calculate the band energy from this observable."
)
try:
target_calculator.read_from_array(predicted_outputs)
be_predicted_actual_fe = (
target_calculator.get_band_energy(
fermi_energy=fe_actual
)
)
targets[energy_type] = be_actual * 1000 / len(target_calculator.atoms)
predictions[energy_type] = be_predicted_actual_fe * 1000 / len(target_calculator.atoms)
except ValueError:
targets[energy_type] = np.nan
predictions[energy_type] = np.nan
elif energy_type == "total_energy":
if not isinstance(target_calculator, LDOS):
raise Exception(
"Cannot calculate the total energy from this "
"observable."
)
try:
target_calculator.read_additional_calculation_data(
self.data.get_snapshot_calculation_output(
snapshot_number
)
)
target_calculator.read_from_array(actual_outputs)
te_actual = target_calculator.get_total_energy(
fermi_energy=fe_actual
)
target_calculator.read_from_array(predicted_outputs)
te_predicted = target_calculator.get_total_energy(
fermi_energy=fe_predicted
)
targets[energy_type] = te_actual * 1000 / len(target_calculator.atoms)
predictions[energy_type] = te_predicted * 1000 / len(target_calculator.atoms)
except ValueError:
targets[energy_type] = np.nan
predictions[energy_type] = np.nan
elif energy_type == "total_energy_actual_fe":
if not isinstance(target_calculator, LDOS):
raise Exception(
"Cannot calculate the total energy from this "
"observable."
)
try:
target_calculator.read_from_array(predicted_outputs)
te_predicted_actual_fe = (
target_calculator.get_total_energy(
fermi_energy=fe_actual
)
)

targets[energy_type] = te_actual * 1000 / len(target_calculator.atoms)
predictions[energy_type] = te_predicted_actual_fe * 1000 / len(target_calculator.atoms)
except ValueError:
targets[energy_type] = np.nan
predictions[energy_type] = np.nan
else:
raise Exception(
f"Invalid energy type ({energy_type}) requested."
)
return targets, predictions

def save_run(
self,
run_name,
Expand Down
32 changes: 31 additions & 1 deletion mala/network/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,40 @@ def test_snapshot(self, snapshot_number, data_type="te"):
snapshot_number,
)
return results

def get_energy_targets_and_predictions(self, snapshot_number, data_type="te"):
"""
Get the energy targets and predictions for a single snapshot.
Parameters
----------
snapshot_number : int
Snapshot which to test.
data_type : str
'tr', 'va', or 'te' indicating the partition to be tested
Returns
-------
results : dict
A dictionary containing the errors for the selected observables.
"""
actual_outputs, predicted_outputs = self.predict_targets(
snapshot_number, data_type=data_type
)

energy_metrics = [metric for metric in self.observables_to_test if "energy" in metric]
targets, predictions = self._calculate_energy_targets_and_predictions(
actual_outputs,
predicted_outputs,
energy_metrics,
snapshot_number,
)
return targets, predictions

def predict_targets(self, snapshot_number, data_type="te"):
"""
Get actual and predicted output for a snapshot.
Get actual and predicted energy outputs for a snapshot.
Parameters
----------
Expand Down

0 comments on commit 24f9f62

Please sign in to comment.